Exemplo n.º 1
0
def compute_gradcam_gif(cam, x, x_un_normalized):
    """Render a Grad-CAM overlay as an animated GIF held in memory.

    Args:
        cam: Raw class-activation map produced by Grad-CAM.
        x: Normalized input batch; x[0] supplies the target size for resizing.
        x_un_normalized: Un-normalized input batch (pixel values 0-255).

    Returns:
        BytesIO buffer containing the rendered GIF (not rewound).
    """
    gif_buffer = BytesIO()

    # Upsample the CAM to match the spatial dimensions of the input.
    resized_cam = util.resize(cam, x[0])

    # Channels-first -> channels-last, then scale pixels into [0, 1].
    # assumes x_un_normalized[0] is a 4-D channel-first volume — TODO confirm
    frames = np.float32(np.transpose(x_un_normalized[0], (1, 2, 3, 0))) / 255

    overlays = [
        Image.fromarray(frame)
        for frame in util.add_heat_map(frames, resized_cam)
    ]
    first, *rest = overlays
    first.save(
        gif_buffer,
        save_all=True,
        append_images=rest,
        format="GIF",
    )

    return gif_buffer
Exemplo n.º 2
0
def write_grad_cams(input_,
                    label,
                    grad_cam,
                    directory,
                    task_sequence,
                    only_competition=False,
                    only_top_task=False,
                    view_id=None):
    """Write one CAM overlay PNG per positively-labeled task, plus the original image.

        Args:
            input_: Image tensor with shape (3 x h x h), normalized with
                IMAGENET_MEAN / IMAGENET_STD.
            label: Per-task ground-truth tensor; a CAM is written only for
                tasks whose label equals 1.
            grad_cam: EnsembleCam Object wrapped around GradCam objects, which are wrapped around models.
            directory: the output folder for these set of cams (created if missing).
            task_sequence: Mapping whose keys name the tasks, in order.
            only_competition: If True, restrict output to tasks listed in
                TASK_SEQUENCES['competition'].
            only_top_task: NOTE(review): unused in this function — confirm with callers.
            view_id: Optional view identifier embedded in filenames; when a file
                for (task, view) already exists in `directory`, that task is skipped.
    """
    if only_competition:
        COMPETITION_TASKS = TASK_SEQUENCES['competition']

    # Get the original image by
    # unnormalizing (img pixels will be between 0 and 1)
    # img shape: c, h, w
    img = util.un_normalize(input_, IMAGENET_MEAN, IMAGENET_STD)

    # Move the RGB channel axis last: (c, h, w) -> (h, w, c).
    img = np.moveaxis(img, 0, 2)

    # Add the batch dimension
    # as the model requires it.
    input_ = input_.unsqueeze(0)
    _, channels, height, width = input_.shape
    num_tasks = len(task_sequence)

    # Create the directory for cams for this specific example
    if not os.path.exists(directory):
        os.makedirs(directory)

    #assert (inputs.shape[0] == 1), 'batch size must be equal to 1'
    # Gradients must be enabled for Grad-CAM's backward pass.
    with torch.set_grad_enabled(True):

        for task_id in range(num_tasks):
            task_name = list(task_sequence)[task_id]
            if only_competition:
                if task_name not in COMPETITION_TASKS:
                    continue

            # Slug used in output filenames, e.g. "Pleural Effusion" -> "pleural_effusion".
            task = task_name.lower()
            task = task.replace(' ', '_')
            task_label = int(label[task_id].item())
            # Skip if a CAM for this (task, view) already exists on disk,
            # or if the ground-truth label for this task is not positive.
            if any([((task in f) and (f'v-{view_id}' in f))
                    for f in os.listdir(directory)]) or task_label != 1:
                continue

            probs, idx, cam = grad_cam.get_cam(input_, task_id, task_name)

            # Resize cam and overlay on image.
            # NOTE(review): cv2.resize expects dsize=(width, height); passing
            # (height, width) is harmless only because inputs are square (h x h).
            resized_cam = cv2.resize(cam, (height, width))
            # We don't normalize since the grad cam class has already taken care of that
            img_with_cam = util.add_heat_map(img, resized_cam, normalize=False)

            # Save a cam for this task and image
            # using task, prob and groundtruth in file name
            # probs is indexed by class index; select the entry for this task.
            prob = probs[idx == task_id].item()
            if view_id is None:
                filename = f'{task}-p{prob:.3f}-gt{task_label}.png'
            else:
                filename = f'{task}-p{prob:.3f}-gt{task_label}-v-{view_id}.png'
            output_path = os.path.join(directory, filename)
            imsave(output_path, img_with_cam)

    # Save the original image in the same folder,
    # rescaled from [0, 1] floats back to uint8 pixels.
    output_path = os.path.join(directory, f'original_image-v-{view_id}.png')
    img = np.uint8(img * 255)
    imsave(output_path, img)
def get_cams(args):
    """Generate Grad-CAM (and optionally guided-backprop) GIFs for abnormal studies.

    Loads the checkpointed model, iterates the data loader, skips normal
    examples, and writes input / CAM-overlay / side-by-side GIFs to
    args.cam_dir until args.num_cams sets have been generated.

    Args:
        args: Namespace providing ckpt_path, gpu_ids, device, target_layer,
            phase, img_format, use_gbp, cam_dir and num_cams. args.start_epoch
            is set as a side effect from the checkpoint info.
    """
    print('Loading model...')
    model, ckpt_info = ModelSaver.load_model(args.ckpt_path, args.gpu_ids)
    model = model.to(args.device)
    args.start_epoch = ckpt_info['epoch'] + 1

    last_layer = list(model.module.encoders._modules.keys())[-1]
    print('Last layer in model.features is named "{}"...'.format(last_layer))
    print('Extracting feature maps from layer named "{}"...'.format(args.target_layer))

    grad_cam = GradCAM(model, args.device, is_binary=True, is_3d=True)
    print(grad_cam)
    gbp = GuidedBackPropagation(model, args.device, is_binary=True, is_3d=True)
    print(gbp)

    num_generated = 0
    data_loader = CTDataLoader(args, phase=args.phase, is_training=False)
    print(data_loader)
    study_count = 1
    for inputs, target_dict in data_loader:
        # Keep going until we get an abnormal (aneurysm) study.
        # Checking the label BEFORE running Grad-CAM avoids a wasted
        # forward/backward pass on examples that are skipped anyway.
        if target_dict['is_abnormal'].item() == 0:
            print('Skipping a normal example...')
            continue

        print('Generating CAM...')
        with torch.set_grad_enabled(True):
            probs, idx = grad_cam.forward(inputs)
            print(probs, idx)
            grad_cam.backward(idx=idx[0])  # Just take top prediction
            cam = grad_cam.get_cam(args.target_layer)

            guided_backprop = None
            if args.use_gbp:
                # requires_grad so gradients flow back to the input volume.
                inputs2 = torch.autograd.Variable(inputs, requires_grad=True)
                probs2, idx2 = gbp.forward(inputs2)
                gbp.backward(idx=idx2[0])
                guided_backprop = np.squeeze(gbp.generate())

        print('Overlaying CAM...')
        print(cam.shape)
        new_cam = util.resize(cam, inputs[0])
        print(new_cam.shape)

        # Un-normalize and move channels last: (c, d, h, w) -> (d, h, w, c),
        # then split the volume into per-slice frames for the GIF.
        input_np = util.un_normalize(inputs[0], args.img_format, data_loader.dataset.pixel_dict)
        input_np = np.transpose(input_np, (1, 2, 3, 0))
        input_frames = list(input_np)

        input_normed = np.float32(input_np) / 255
        cam_frames = list(util.add_heat_map(input_normed, new_cam))

        gbp_frames = None
        if args.use_gbp:
            gbp_np = util.normalize_to_image(guided_backprop * new_cam)
            # One frame per slice, with a trailing singleton channel axis.
            gbp_frames = [gbp_np[dim, :, :][..., None]
                          for dim in range(gbp_np.shape[0])]

        # Output paths: <study_num>_<study_count>_{input,cam,combined}_fn.gif
        study_num = target_dict['study_num']
        output_path_input = os.path.join(args.cam_dir, '{}_{}_input_fn.gif'.format(study_num, study_count))
        output_path_cam = os.path.join(args.cam_dir, '{}_{}_cam_fn.gif'.format(study_num, study_count))
        output_path_combined = os.path.join(args.cam_dir, '{}_{}_combined_fn.gif'.format(study_num, study_count))

        print('Writing set {}/{} of CAMs to {}...'.format(num_generated + 1, args.num_cams, args.cam_dir))

        input_clip = mpy.ImageSequenceClip(input_frames, fps=4)
        input_clip.write_gif(output_path_input, verbose=False)
        cam_clip = mpy.ImageSequenceClip(cam_frames, fps=4)
        cam_clip.write_gif(output_path_cam, verbose=False)
        # Input and CAM side by side in a single clip.
        combined_clip = mpy.clips_array([[input_clip, cam_clip]])
        combined_clip.write_gif(output_path_combined, verbose=False)

        if args.use_gbp:
            output_path_gcam = os.path.join(args.cam_dir, 'gbp_{}.gif'.format(num_generated + 1))
            gbp_clip = mpy.ImageSequenceClip(gbp_frames, fps=4)
            gbp_clip.write_gif(output_path_gcam, verbose=False)

        study_count += 1
        num_generated += 1
        if num_generated == args.num_cams:
            return