def main(args):

    if args.device is None:
        args.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    device = torch.device(args.device)

    # Pretrained imagenet model
    model = models.__dict__[args.model](pretrained=True).eval().to(device=device)
    conv_layer = MODEL_CONFIG[args.model]['conv_layer']
    input_layer = MODEL_CONFIG[args.model]['input_layer']
    fc_layer = MODEL_CONFIG[args.model]['fc_layer']

    # Image
    if args.img.startswith('http'):
        img_path = BytesIO(requests.get(args.img).content)
    else:
        # Fall back to a local file path when the input is not a URL
        img_path = args.img
    pil_img = Image.open(img_path, mode='r').convert('RGB')

    # Preprocess image
    img_tensor = normalize(to_tensor(resize(pil_img, (224, 224))),
                           [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]).to(device=device)

    # Hook the corresponding layer in the model
    cam_extractors = [CAM(model, conv_layer, fc_layer), GradCAM(model, conv_layer),
                      GradCAMpp(model, conv_layer), SmoothGradCAMpp(model, conv_layer, input_layer),
                      ScoreCAM(model, conv_layer, input_layer), SSCAM(model, conv_layer, input_layer),
                      ISSCAM(model, conv_layer, input_layer)]

    # Disable hooks on all extractors; each one is re-enabled in turn below
    for extractor in cam_extractors:
        extractor._hooks_enabled = False

    fig, axes = plt.subplots(1, len(cam_extractors), figsize=(7, 2))
    for idx, extractor in enumerate(cam_extractors):
        extractor._hooks_enabled = True
        model.zero_grad()
        scores = model(img_tensor.unsqueeze(0))

        # Select the class index
        class_idx = scores.squeeze(0).argmax().item() if args.class_idx is None else args.class_idx

        # Use the hooked data to compute activation map
        activation_map = extractor(class_idx, scores).cpu()
        # Clean data
        extractor.clear_hooks()
        extractor._hooks_enabled = False
        # Convert the raw CAM to a PIL image
        heatmap = to_pil_image(activation_map, mode='F')
        # Plot the result
        result = overlay_mask(pil_img, heatmap)

        axes[idx].imshow(result)
        axes[idx].axis('off')
        axes[idx].set_title(extractor.__class__.__name__, size=8)

    plt.tight_layout()
    if args.savefig:
        plt.savefig(args.savefig, dpi=200, transparent=True, bbox_inches='tight', pad_inches=0)
    plt.show()
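The snippet above omits its import block and the MODEL_CONFIG dict it reads layer names from. A minimal sketch of what it assumes, using the older torchcam.cams layout (the MODEL_CONFIG entry below is illustrative, not taken from the original script):

# Assumed imports for the main() above; not part of the original excerpt.
from io import BytesIO

import matplotlib.pyplot as plt
import requests
import torch
from PIL import Image
from torchvision import models
from torchvision.transforms.functional import normalize, resize, to_pil_image, to_tensor

from torchcam.cams import CAM, GradCAM, GradCAMpp, SmoothGradCAMpp, ScoreCAM, SSCAM, ISSCAM
from torchcam.utils import overlay_mask

# Hypothetical MODEL_CONFIG entry: maps a torchvision model name to its layer names.
MODEL_CONFIG = {
    'resnet18': {'conv_layer': 'layer4', 'input_layer': 'conv1', 'fc_layer': 'fc'},
}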
Example #2
    def test_overlay_mask(self):

        img = Image.fromarray(np.zeros((4, 4, 3)).astype(np.uint8))
        mask = Image.fromarray(255 * np.ones((4, 4)).astype(np.uint8))

        overlayed = utils.overlay_mask(img, mask, alpha=0.7)

        # Check object type
        self.assertIsInstance(overlayed, Image.Image)
        # Verify value
        self.assertTrue(np.all(np.asarray(overlayed)[..., 0] == 0))
        self.assertTrue(np.all(np.asarray(overlayed)[..., 1] == 39))
        self.assertTrue(np.all(np.asarray(overlayed)[..., 2] == 76))
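This test method is an excerpt from a larger test class; a minimal harness for running it standalone, assuming the class name and imports below (both are placeholders, not from the original file):

# Hypothetical harness; UtilsTester is an illustrative class name.
import unittest

import numpy as np
from PIL import Image

from torchcam import utils


class UtilsTester(unittest.TestCase):
    # place the test_overlay_mask method shown above here
    pass


if __name__ == '__main__':
    unittest.main()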
Example #3
def visualize_cam_on_img(img_name, model):
    cam_extractor = SmoothGradCAMpp(model)
    img = read_image(str(img_name))
    input_tensor = normalize(
        resize(img, (224, 224)) / 255., [0.485, 0.456, 0.406],
        [0.229, 0.224, 0.225]).cuda()
    out = model(input_tensor.unsqueeze(0))
    activation_map = cam_extractor(out.squeeze(0).argmax().item(), out)
    # Move the CAM back to CPU before converting it to a PIL image
    result = overlay_mask(to_pil_image(img),
                          to_pil_image(activation_map.cpu(), mode='F'),
                          alpha=0.5)
    plt.imshow(result)
    plt.axis('off')
    plt.tight_layout()
    plt.show()
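This helper moves the input tensor to CUDA, so it assumes the model is already on a GPU and in eval mode. A hedged usage sketch (the file name and model choice are illustrative):

# Hypothetical usage of visualize_cam_on_img; names are illustrative.
from torchvision.models import resnet18

model = resnet18(pretrained=True).eval().cuda()
visualize_cam_on_img('dog.jpg', model)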
Example #4
def main(args):

    if args.device is None:
        args.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    device = torch.device(args.device)

    # Pretrained imagenet model
    model = models.__dict__[args.model](pretrained=True).eval().to(device=device)

    # Image
    if args.img.startswith('http'):
        img_path = BytesIO(requests.get(args.img).content)
    else:
        img_path = args.img
    pil_img = Image.open(img_path, mode='r').convert('RGB')

    # Preprocess image
    img_tensor = normalize(to_tensor(resize(pil_img, (224, 224))),
                           [0.485, 0.456, 0.406],
                           [0.229, 0.224, 0.225]).to(device=device)

    if isinstance(args.method, str):
        methods = [args.method]
    else:
        methods = [
            'CAM',
            'GradCAM',
            'GradCAMpp',
            'SmoothGradCAMpp',
            'ScoreCAM',
            'SSCAM',
            'ISCAM',
            'XGradCAM',
        ]
    # Hook the corresponding layer in the model
    cam_extractors = [cams.__dict__[name](model) for name in methods]

    # Disable hooks on all extractors; each one is re-enabled in turn below
    for extractor in cam_extractors:
        extractor._hooks_enabled = False

    # Homogenize number of elements in each row
    num_cols = math.ceil((len(cam_extractors) + 1) / args.rows)
    _, axes = plt.subplots(args.rows, num_cols, figsize=(6, 4))
    # Display input
    ax = axes[0][0] if args.rows > 1 else axes[0] if num_cols > 1 else axes
    ax.imshow(pil_img)
    ax.set_title("Input", size=8)

    for idx, extractor in enumerate(cam_extractors, start=1):
        extractor._hooks_enabled = True
        model.zero_grad()
        scores = model(img_tensor.unsqueeze(0))

        # Select the class index
        class_idx = scores.squeeze(0).argmax().item() if args.class_idx is None else args.class_idx

        # Use the hooked data to compute activation map
        activation_map = extractor(class_idx, scores).cpu()

        # Clean data
        extractor.clear_hooks()
        extractor._hooks_enabled = False
        # Convert the raw CAM to a PIL image
        heatmap = to_pil_image(activation_map, mode='F')
        # Plot the result
        result = overlay_mask(pil_img, heatmap, alpha=args.alpha)

        if args.rows > 1:
            ax = axes[idx // num_cols][idx % num_cols]
        elif num_cols > 1:
            ax = axes[idx]
        else:
            ax = axes

        ax.imshow(result)
        ax.set_title(extractor.__class__.__name__, size=8)

    # Clear axes
    if num_cols > 1:
        for _axes in axes:
            if args.rows > 1:
                for ax in _axes:
                    ax.axis('off')
            else:
                _axes.axis('off')

    else:
        axes.axis('off')

    plt.tight_layout()
    if args.savefig:
        plt.savefig(args.savefig,
                    dpi=200,
                    transparent=True,
                    bbox_inches='tight',
                    pad_inches=0)
    plt.show()
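The args namespace consumed above is built elsewhere; a hedged argparse sketch of the fields this main() reads (flag names and defaults are assumptions, not taken from the original script):

# Hypothetical CLI wiring for the main() above; defaults are illustrative.
import argparse

parser = argparse.ArgumentParser(description='CAM visualization')
parser.add_argument('--img', type=str, required=True, help='image path or URL')
parser.add_argument('--model', type=str, default='resnet18', help='torchvision model name')
parser.add_argument('--method', type=str, default=None, help='single CAM method; None runs them all')
parser.add_argument('--class-idx', dest='class_idx', type=int, default=None, help='target class index (None = argmax)')
parser.add_argument('--device', type=str, default=None, help='computation device')
parser.add_argument('--rows', type=int, default=1, help='number of rows in the output figure')
parser.add_argument('--alpha', type=float, default=0.5, help='overlay transparency')
parser.add_argument('--savefig', type=str, default=None, help='path to save the figure, if any')
main(parser.parse_args())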
Example #5
def main():

    # Wide mode
    st.set_page_config(layout="wide")

    # Designing the interface
    st.title("TorchCAM: class activation explorer")
    # For newline
    st.write('\n')
    # Set the columns
    cols = st.beta_columns((1, 1, 1))
    cols[0].header("Input image")
    cols[1].header("Raw CAM")
    cols[-1].header("Overlayed CAM")

    # Sidebar
    # File selection
    st.sidebar.title("Input selection")
    # Disabling warning
    st.set_option('deprecation.showfileUploaderEncoding', False)
    # Choose your own image
    uploaded_file = st.sidebar.file_uploader("Upload files",
                                             type=['png', 'jpeg', 'jpg'])
    if uploaded_file is not None:
        img = Image.open(BytesIO(uploaded_file.read()),
                         mode='r').convert('RGB')

        cols[0].image(img, use_column_width=True)

    # Model selection
    st.sidebar.title("Setup")
    tv_model = st.sidebar.selectbox("Classification model", TV_MODELS)
    default_layer = ""
    if tv_model is not None:
        with st.spinner('Loading model...'):
            model = models.__dict__[tv_model](pretrained=True).eval()
        default_layer = cams.utils.locate_candidate_layer(model, (3, 224, 224))

    target_layer = st.sidebar.text_input("Target layer", default_layer)
    cam_method = st.sidebar.selectbox("CAM method", CAM_METHODS)
    if cam_method is not None:
        cam_extractor = cams.__dict__[cam_method](
            model,
            target_layer=target_layer if len(target_layer) > 0 else None)

    class_choices = [
        f"{idx + 1} - {class_name}" for idx, class_name in enumerate(LABEL_MAP)
    ]
    class_selection = st.sidebar.selectbox(
        "Class selection", ["Predicted class (argmax)"] + class_choices)

    # For newline
    st.sidebar.write('\n')

    if st.sidebar.button("Compute CAM"):

        if uploaded_file is None:
            st.sidebar.error("Please upload an image first")

        else:
            with st.spinner('Analyzing...'):

                # Preprocess image
                img_tensor = normalize(to_tensor(resize(img, (224, 224))),
                                       [0.485, 0.456, 0.406],
                                       [0.229, 0.224, 0.225])

                # Forward the image to the model
                out = model(img_tensor.unsqueeze(0))
                # Select the target class
                if class_selection == "Predicted class (argmax)":
                    class_idx = out.squeeze(0).argmax().item()
                else:
                    class_idx = LABEL_MAP.index(
                        class_selection.rpartition(" - ")[-1])
                # Retrieve the CAM
                activation_map = cam_extractor(class_idx, out)
                # Plot the raw heatmap
                fig, ax = plt.subplots()
                ax.imshow(activation_map.numpy())
                ax.axis('off')
                cols[1].pyplot(fig)

                # Overlayed CAM
                fig, ax = plt.subplots()
                result = overlay_mask(img,
                                      to_pil_image(activation_map, mode='F'),
                                      alpha=0.5)
                ax.imshow(result)
                ax.axis('off')
                cols[-1].pyplot(fig)
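The app relies on module-level globals (TV_MODELS, CAM_METHODS, LABEL_MAP) and an entry point that are defined elsewhere in the script; a hedged sketch of what they might look like, plus the usual Streamlit launch command (values and file name are illustrative):

# Hypothetical globals for the Streamlit app above; values are illustrative.
TV_MODELS = ['resnet18', 'resnet50']
CAM_METHODS = ['CAM', 'GradCAM', 'GradCAMpp', 'SmoothGradCAMpp', 'ScoreCAM']
LABEL_MAP = ['tench', 'goldfish', 'great white shark']  # full ImageNet class names, in index order

if __name__ == '__main__':
    main()

# Launch from a shell (script name is illustrative):
#   streamlit run app.py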