def main(args):
    """Plot, side by side, the activation map produced by each CAM method for one image.

    Args:
        args: parsed CLI namespace. Attributes read here: ``device`` (None selects
            CUDA when available, else CPU), ``model`` (key used both in
            ``models.__dict__`` and in ``MODEL_CONFIG``), ``img`` (URL starting
            with ``http`` or a local file path), ``class_idx`` (None means use the
            top-scoring class) and ``savefig`` (optional output path).
    """
    if args.device is None:
        args.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    device = torch.device(args.device)

    # Pretrained imagenet model
    model = models.__dict__[args.model](pretrained=True).eval().to(device=device)
    conv_layer = MODEL_CONFIG[args.model]['conv_layer']
    input_layer = MODEL_CONFIG[args.model]['input_layer']
    fc_layer = MODEL_CONFIG[args.model]['fc_layer']

    # Image: download remote URLs, otherwise treat the argument as a local path
    if args.img.startswith('http'):
        img_path = BytesIO(requests.get(args.img).content)
    else:
        img_path = args.img
    pil_img = Image.open(img_path, mode='r').convert('RGB')

    # Preprocess image
    img_tensor = normalize(to_tensor(resize(pil_img, (224, 224))),
                           [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]).to(device=device)

    # Hook the corresponding layer in the model
    cam_extractors = [CAM(model, conv_layer, fc_layer), GradCAM(model, conv_layer),
                      GradCAMpp(model, conv_layer), SmoothGradCAMpp(model, conv_layer, input_layer),
                      ScoreCAM(model, conv_layer, input_layer), SSCAM(model, conv_layer, input_layer),
                      ISSCAM(model, conv_layer, input_layer)]

    # Don't trigger all hooks: each extractor is enabled only for its own pass below
    for extractor in cam_extractors:
        extractor._hooks_enabled = False

    _, axes = plt.subplots(1, len(cam_extractors), figsize=(7, 2))
    for idx, extractor in enumerate(cam_extractors):
        extractor._hooks_enabled = True
        model.zero_grad()
        scores = model(img_tensor.unsqueeze(0))

        # Select the class index
        class_idx = scores.squeeze(0).argmax().item() if args.class_idx is None else args.class_idx

        # Use the hooked data to compute activation map
        activation_map = extractor(class_idx, scores).cpu()
        # Clean data
        extractor.clear_hooks()
        extractor._hooks_enabled = False
        # Convert the float activation map to a single-channel ('F') PIL image
        heatmap = to_pil_image(activation_map, mode='F')
        # Plot the result
        result = overlay_mask(pil_img, heatmap)

        axes[idx].imshow(result)
        axes[idx].axis('off')
        axes[idx].set_title(extractor.__class__.__name__, size=8)

    plt.tight_layout()
    if args.savefig:
        plt.savefig(args.savefig, dpi=200, transparent=True, bbox_inches='tight', pad_inches=0)
    plt.show()
def test_overlay_mask(self):
    """A fully-on mask blended onto a black image yields uniform, known RGB values."""
    black_img = Image.fromarray(np.zeros((4, 4, 3), dtype=np.uint8))
    full_mask = Image.fromarray(np.full((4, 4), 255, dtype=np.uint8))
    blended = utils.overlay_mask(black_img, full_mask, alpha=0.7)

    # The helper must return a PIL image, not a raw array
    self.assertIsInstance(blended, Image.Image)

    # Every pixel of each channel must carry the expected intensity
    pixels = np.asarray(blended)
    for channel, expected in enumerate((0, 39, 76)):
        self.assertTrue(np.all(pixels[..., channel] == expected))
def visualize_cam_on_img(img_name, model):
    """Show a SmoothGrad-CAM++ overlay for the model's top-scoring class on an image.

    Args:
        img_name: path of the image to explain; converted with ``str()`` before
            being passed to ``read_image`` (so a ``pathlib.Path`` also works).
        model: classification model to explain. The input is sent to the device
            of its first parameter. NOTE(review): this function does not switch
            the model to ``eval()`` — presumably the caller does; confirm.
    """
    cam_extractor = SmoothGradCAMpp(model)
    try:
        img = read_image(str(img_name))
        # Follow the model's device instead of hard-coding CUDA, so CPU models
        # (or models on another GPU) work as well
        device = next(model.parameters()).device
        input_tensor = normalize(resize(img, (224, 224)) / 255.,
                                 [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]).to(device)
        out = model(input_tensor.unsqueeze(0))
        # Explain the highest-scoring class; bring the map back to CPU for PIL
        activation_map = cam_extractor(out.squeeze(0).argmax().item(), out).cpu()
    finally:
        # Remove the hooks registered on the caller's model, even on failure
        cam_extractor.clear_hooks()

    result = overlay_mask(to_pil_image(img), to_pil_image(activation_map, mode='F'), alpha=0.5)
    plt.imshow(result)
    plt.axis('off')
    plt.tight_layout()
    plt.show()
def main(args):
    """Render the input image followed by the CAM of each requested method in a grid.

    Args:
        args: parsed CLI namespace. Attributes read here: ``device`` (None selects
            CUDA when available, else CPU), ``model`` (key of ``models.__dict__``),
            ``img`` (URL starting with ``http`` or a local file path), ``method``
            (a single method name as ``str``, otherwise all methods are used),
            ``rows`` (number of grid rows), ``class_idx`` (None means top-scoring
            class), ``alpha`` (overlay transparency) and ``savefig``.
    """
    if args.device is None:
        args.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    device = torch.device(args.device)

    # Pretrained imagenet model, in inference mode so BatchNorm uses its running
    # statistics and Dropout is disabled (as the other demos do)
    model = models.__dict__[args.model](pretrained=True).eval().to(device=device)

    # Image: download remote URLs, otherwise treat the argument as a local path
    if args.img.startswith('http'):
        img_path = BytesIO(requests.get(args.img).content)
    else:
        img_path = args.img
    pil_img = Image.open(img_path, mode='r').convert('RGB')

    # Preprocess image
    img_tensor = normalize(to_tensor(resize(pil_img, (224, 224))),
                           [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]).to(device=device)

    if isinstance(args.method, str):
        methods = [args.method]
    else:
        methods = [
            'CAM',
            'GradCAM',
            'GradCAMpp',
            'SmoothGradCAMpp',
            'ScoreCAM',
            'SSCAM',
            'ISCAM',
            'XGradCAM',
        ]
    # Hook the corresponding layer in the model
    cam_extractors = [cams.__dict__[name](model) for name in methods]

    # Don't trigger all hooks: each extractor is enabled only for its own pass below
    for extractor in cam_extractors:
        extractor._hooks_enabled = False

    # Homogenize number of elements in each row (+1 cell for the input image)
    num_cols = math.ceil((len(cam_extractors) + 1) / args.rows)
    # squeeze=False always yields a 2D array of axes, so one indexing scheme covers
    # every layout (including rows > 1 with a single column)
    _, axes = plt.subplots(args.rows, num_cols, figsize=(6, 4), squeeze=False)

    # Display input
    axes[0, 0].imshow(pil_img)
    axes[0, 0].set_title("Input", size=8)

    for idx, extractor in enumerate(cam_extractors, start=1):
        extractor._hooks_enabled = True
        model.zero_grad()
        scores = model(img_tensor.unsqueeze(0))

        # Select the class index
        class_idx = scores.squeeze(0).argmax().item() if args.class_idx is None else args.class_idx

        # Use the hooked data to compute activation map
        activation_map = extractor(class_idx, scores).cpu()

        # Clean data
        extractor.clear_hooks()
        extractor._hooks_enabled = False
        # Convert the float activation map to a single-channel ('F') PIL image
        heatmap = to_pil_image(activation_map, mode='F')
        # Plot the result
        result = overlay_mask(pil_img, heatmap, alpha=args.alpha)

        ax = axes[idx // num_cols, idx % num_cols]
        ax.imshow(result)
        ax.set_title(extractor.__class__.__name__, size=8)

    # Clear axes (including unused cells of the last row)
    for ax in axes.flat:
        ax.axis('off')

    plt.tight_layout()
    if args.savefig:
        plt.savefig(args.savefig, dpi=200, transparent=True, bbox_inches='tight', pad_inches=0)
    plt.show()
def main():
    """Streamlit app: upload an image, choose model / layer / CAM method, display the CAM.

    The page has three columns: the uploaded image, the raw activation map, and
    the activation map overlaid on the image. The CAM is computed when the
    "Compute CAM" sidebar button is pressed.
    """
    # Wide mode
    st.set_page_config(layout="wide")

    # Designing the interface
    st.title("TorchCAM: class activation explorer")
    # For newline
    st.write('\n')
    # Set the columns
    cols = st.beta_columns((1, 1, 1))
    cols[0].header("Input image")
    cols[1].header("Raw CAM")
    cols[-1].header("Overlayed CAM")

    # Sidebar
    # File selection
    st.sidebar.title("Input selection")
    # Disabling warning
    st.set_option('deprecation.showfileUploaderEncoding', False)
    # Choose your own image
    uploaded_file = st.sidebar.file_uploader("Upload files", type=['png', 'jpeg', 'jpg'])
    if uploaded_file is not None:
        img = Image.open(BytesIO(uploaded_file.read()), mode='r').convert('RGB')

        cols[0].image(img, use_column_width=True)

    # Model selection
    st.sidebar.title("Setup")
    tv_model = st.sidebar.selectbox("Classification model", TV_MODELS)
    default_layer = ""
    if tv_model is not None:
        with st.spinner('Loading model...'):
            model = models.__dict__[tv_model](pretrained=True).eval()
        default_layer = cams.utils.locate_candidate_layer(model, (3, 224, 224))

    target_layer = st.sidebar.text_input("Target layer", default_layer)
    cam_method = st.sidebar.selectbox("CAM method", CAM_METHODS)
    if cam_method is not None:
        # An empty text field means "let the extractor pick the layer"
        cam_extractor = cams.__dict__[cam_method](
            model, target_layer=target_layer if len(target_layer) > 0 else None)

    class_choices = [f"{idx + 1} - {class_name}" for idx, class_name in enumerate(LABEL_MAP)]
    class_selection = st.sidebar.selectbox("Class selection", ["Predicted class (argmax)"] + class_choices)

    # For newline
    st.sidebar.write('\n')

    if st.sidebar.button("Compute CAM"):

        if uploaded_file is None:
            st.sidebar.error("Please upload an image first")

        else:
            with st.spinner('Analyzing...'):

                # Preprocess image
                img_tensor = normalize(to_tensor(resize(img, (224, 224))),
                                       [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

                # Forward the image to the model
                out = model(img_tensor.unsqueeze(0))
                # Select the target class
                if class_selection == "Predicted class (argmax)":
                    class_idx = out.squeeze(0).argmax().item()
                else:
                    # Parse the "<idx + 1> - <name>" prefix built in `class_choices`
                    # instead of looking the name up: if LABEL_MAP holds duplicate
                    # names, `list.index` would always return the first occurrence
                    class_idx = int(class_selection.partition(" - ")[0]) - 1
                # Retrieve the CAM
                activation_map = cam_extractor(class_idx, out)
                # Plot the raw heatmap
                fig, ax = plt.subplots()
                ax.imshow(activation_map.numpy())
                ax.axis('off')
                cols[1].pyplot(fig)
                # pyplot keeps every figure alive until closed; each click creates new ones
                plt.close(fig)

                # Overlayed CAM
                fig, ax = plt.subplots()
                result = overlay_mask(img, to_pil_image(activation_map, mode='F'), alpha=0.5)
                ax.imshow(result)
                ax.axis('off')
                cols[-1].pyplot(fig)
                plt.close(fig)