Exemplo n.º 1
0
    def forward_with_decisions(self, outputs):
        """Run the forward pass and also return, per sample, the list of
        tree decisions (path-node dicts) leading to the predicted leaf.

        Returns:
            (outputs, decisions): the model outputs and one decision path
            per prediction; each step's 'prob' is a -1 placeholder.
        """
        outputs = self.forward(outputs)
        _, predicted = outputs.max(1)

        root = self.nodes[0]
        wnid_to_path = Node.get_leaf_to_path(self.nodes)
        decisions = []
        for prediction in predicted:
            path = wnid_to_path[root.wnids[prediction]]
            for step in path:
                step['prob'] = -1  # TODO(alvin): fill in prob
            decisions.append(path)
        return outputs, decisions
    def forward_with_decisions(self, outputs):
        """Forward pass that treats column 0 of `outputs` separately.

        Columns 1..N are fed through the model/tree; column 0 is taken from
        a softmax over the *original* logits and prepended back onto the
        tree outputs. Also returns one decision path per prediction, with
        'prob' as a -1 placeholder.
        """
        tree_outputs = self.forward(outputs[:, 1:])
        _, predicted = tree_outputs.max(1)

        root = self.nodes[0]
        wnid_to_path = Node.get_leaf_to_path(self.nodes)
        decisions = []
        for prediction in predicted:
            path = wnid_to_path[root.wnids[prediction]]
            for step in path:
                step['prob'] = -1  # TODO(alvin): fill in prob
            decisions.append(path)
        scores = F.softmax(outputs, dim=-1)
        combined = torch.cat((scores[:, 0].unsqueeze(1), tree_outputs),
                             dim=1)
        return combined, decisions
Exemplo n.º 3
0
def main():
    """Load a segmentation model (optionally wrapped in an NBDT) and generate
    saliency/GradCAM heatmaps for the requested images, pixels, and
    (optionally) NBDT tree nodes, saving heatmaps/originals/overlaps to disk.
    """
    args = parse_args()

    logger, final_output_dir, _ = create_logger(config, args.cfg,
                                                'vis_gradcam')

    logger.info(pprint.pformat(args))
    logger.info(pprint.pformat(config))

    # cudnn related setting
    cudnn.benchmark = config.CUDNN.BENCHMARK
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED

    # build model
    model = eval('models.' + config.MODEL.NAME + '.get_seg_model')(config)

    dump_input = torch.rand(
        (1, 3, config.TRAIN.IMAGE_SIZE[1], config.TRAIN.IMAGE_SIZE[0]))
    logger.info(get_model_summary(model.cuda(), dump_input.cuda()))

    if config.TEST.MODEL_FILE:
        model_state_file = config.TEST.MODEL_FILE
    else:
        model_state_file = os.path.join(final_output_dir, 'best.pth')
    logger.info('=> loading model from {}'.format(model_state_file))

    # Checkpoint keys carry a 6-char prefix; strip it and keep only entries
    # that match the freshly built model's state dict.
    pretrained_dict = torch.load(model_state_file)
    model_dict = model.state_dict()
    pretrained_dict = {
        k[6:]: v
        for k, v in pretrained_dict.items() if k[6:] in model_dict.keys()
    }
    for k, _ in pretrained_dict.items():
        logger.info('=> loading {} from pretrained model'.format(k))
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)

    # Wrap original model with NBDT
    if config.NBDT.USE_NBDT:
        from nbdt.model import SoftSegNBDT
        model = SoftSegNBDT(config.NBDT.DATASET,
                            model,
                            hierarchy=config.NBDT.HIERARCHY,
                            classes=class_names)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device).eval()

    # Retrieve input image corresponding to args.image_index
    test_size = (config.TEST.IMAGE_SIZE[1], config.TEST.IMAGE_SIZE[0])
    test_dataset = eval('datasets.' + config.DATASET.DATASET)(
        root=config.DATASET.ROOT,
        list_path=config.DATASET.TEST_SET,
        num_samples=None,
        num_classes=config.DATASET.NUM_CLASSES,
        multi_scale=False,
        flip=False,
        ignore_label=config.TRAIN.IGNORE_LABEL,
        base_size=config.TEST.BASE_SIZE,
        crop_size=test_size,
        downsample_rate=1)

    # Define target layer as final convolution layer if not specified
    if args.target_layers:
        target_layers = args.target_layers.split(',')
    else:
        # BUGFIX: the break must live inside the `if`. Previously it executed
        # unconditionally, so only the *last* named module was ever inspected
        # and `target_layers` stayed unbound (NameError below) unless that
        # module happened to be a Conv2d.
        for name, module in list(model.named_modules())[::-1]:
            if isinstance(module, nn.Conv2d):
                target_layers = [name]
                break
    logger.info('Target layers set to {}'.format(str(target_layers)))

    # Append model. to target layers if using nbdt
    if config.NBDT.USE_NBDT:
        target_layers = ['model.' + layer for layer in target_layers]

    def generate_and_save_saliency(image_index,
                                   pixel_i=None,
                                   pixel_j=None,
                                   crop_size=None,
                                   normalize=False):
        """Generate saliency maps for every target layer of the current
        backward pass and save heatmap, original image, and overlap plots.
        Cropping applies only when crop_size and both pixel coords are given.
        """
        nonlocal maximum, minimum, label
        # Generate GradCAM + save heatmap
        heatmaps = []
        raw_image = retrieve_raw_image(test_dataset, image_index)

        should_crop = crop_size is not None and pixel_i is not None and pixel_j is not None
        if should_crop:
            raw_image = crop(pixel_i,
                             pixel_j,
                             crop_size,
                             raw_image,
                             is_tensor=False)

        for layer in target_layers:
            gradcam_region = gradcam.generate(target_layer=layer,
                                              normalize=False)

            if should_crop:
                gradcam_region = crop(pixel_i,
                                      pixel_j,
                                      crop_size,
                                      gradcam_region,
                                      is_tensor=True)

            # Track running bounds across layers/pixels for this image.
            maximum = max(float(gradcam_region.max()), maximum)
            minimum = min(float(gradcam_region.min()), minimum)
            logger.info(f'=> Bounds: ({minimum}, {maximum})')

            heatmaps.append(gradcam_region)
            output_dir = generate_output_dir(final_output_dir, args.vis_mode,
                                             layer, config.NBDT.USE_NBDT,
                                             nbdt_node_wnid, args.crop_size,
                                             args.nbdt_node_wnids_for)
            save_path = generate_save_path(output_dir, gradcam_kwargs)
            logger.info('Saving {} heatmap at {}...'.format(
                args.vis_mode, save_path))

            if normalize:
                gradcam_region = GradCAM.normalize(gradcam_region)
                save_gradcam(save_path,
                             gradcam_region,
                             raw_image,
                             save_npy=not args.skip_save_npy)
            else:
                save_gradcam(save_path,
                             gradcam_region,
                             raw_image,
                             minimum=minimum,
                             maximum=maximum,
                             save_npy=not args.skip_save_npy)

            output_dir_original = output_dir + '_original'
            os.makedirs(output_dir_original, exist_ok=True)
            save_path_original = generate_save_path(output_dir_original,
                                                    gradcam_kwargs,
                                                    ext='jpg')
            logger.info('Saving {} original at {}...'.format(
                args.vis_mode, save_path_original))
            cv2.imwrite(save_path_original, raw_image)

            # Skip the overlap plots when cropping.
            # BUGFIX: match `should_crop` above — pixel coordinate 0 is valid
            # but falsy, so the old truthiness test (`crop_size and pixel_i
            # and pixel_j`) wrongly fell through for row/column 0.
            if should_crop:
                continue
            output_dir += '_overlap'
            os.makedirs(output_dir, exist_ok=True)
            save_path_overlap = generate_save_path(output_dir,
                                                   gradcam_kwargs,
                                                   ext='npy')
            save_path_plot = generate_save_path(output_dir,
                                                gradcam_kwargs,
                                                ext='jpg')
            if not args.skip_save_npy:
                logger.info('Saving {} overlap data at {}...'.format(
                    args.vis_mode, save_path_overlap))
            logger.info('Saving {} overlap plot at {}...'.format(
                args.vis_mode, save_path_plot))
            save_overlap(save_path_overlap,
                         save_path_plot,
                         gradcam_region,
                         label,
                         save_npy=not args.skip_save_npy)
        if len(heatmaps) > 1:
            # Combine per-layer heatmaps by elementwise product, renormalized.
            combined = torch.prod(torch.stack(heatmaps, dim=0), dim=0)
            combined /= combined.max()
            save_path = generate_save_path(final_output_dir, args.vis_mode,
                                           gradcam_kwargs, 'combined',
                                           config.NBDT.USE_NBDT,
                                           nbdt_node_wnid)
            logger.info('Saving combined {} heatmap at {}...'.format(
                args.vis_mode, save_path))
            save_gradcam(save_path, combined, raw_image)

    nbdt_node_wnids = args.nbdt_node_wnid or []
    cls = args.nbdt_node_wnids_for
    if cls:
        # Expand a class name into the wnids of all tree nodes on the path
        # from the root to that class's leaf.
        assert config.NBDT.USE_NBDT, 'Must be using NBDT'
        from nbdt.data.custom import Node
        assert hasattr(model, 'rules') and hasattr(model.rules, 'nodes'), \
            'NBDT must have rules with nodes'
        logger.info("Getting nodes leading up to class leaf {}...".format(cls))
        leaf_to_path_nodes = Node.get_leaf_to_path(model.rules.nodes)

        cls_index = class_names.index(cls)
        leaf = model.rules.nodes[0].wnids[cls_index]
        path_nodes = leaf_to_path_nodes[leaf]
        nbdt_node_wnids = [
            item['node'].wnid for item in path_nodes if item['node']
        ]

    def run():
        """Iterate requested images/pixels, run forward+backward passes, and
        call generate_and_save_saliency for each."""
        nonlocal maximum, minimum, label, gradcam_kwargs
        for image_index in get_image_indices(args.image_index,
                                             args.image_index_range):
            image, label, _, name = test_dataset[image_index]
            image = torch.from_numpy(image).unsqueeze(0).to(device)
            logger.info("Using image {}...".format(name))
            pred_probs, pred_labels = gradcam.forward(image)

            maximum, minimum = -1000, 0
            logger.info(f'=> Starting bounds: ({minimum}, {maximum})')

            if args.crop_for and class_names.index(args.crop_for) not in label:
                print(
                    f'Skipping image {image_index} because no {args.crop_for} found'
                )
                continue

            if getattr(Saliency, 'whole_image', False):
                assert not (
                        args.pixel_i or args.pixel_j or args.pixel_i_range
                        or args.pixel_j_range), \
                    'the "Whole" saliency method generates one map for the whole ' \
                    'image, not for specific pixels'
                gradcam_kwargs = {'image': image_index}
                if args.suffix:
                    gradcam_kwargs['suffix'] = args.suffix
                gradcam.backward(pred_labels[:, [0], :, :])

                generate_and_save_saliency(image_index)

                if args.crop_size <= 0:
                    continue

            if args.crop_for:
                cls_index = class_names.index(args.crop_for)
                label = torch.Tensor(label).to(pred_labels.device)
                # is_right_class = pred_labels[0,0,:,:] == cls_index
                # is_correct = pred_labels == label
                is_right_class = is_correct = label == cls_index  #TODO:tmp
                pixels = (is_right_class * is_correct).nonzero()

                pixels = get_random_pixels(args.pixel_max_num_random,
                                           pixels,
                                           seed=cls_index)
            else:
                assert (args.pixel_i
                        or args.pixel_i_range) and (args.pixel_j
                                                    or args.pixel_j_range)
                pixels = get_pixels(args.pixel_i, args.pixel_j,
                                    args.pixel_i_range, args.pixel_j_range,
                                    args.pixel_cartesian_product)
            logger.info(f'Running on {len(pixels)} pixels.')

            for pixel_i, pixel_j in pixels:
                pixel_i, pixel_j = int(pixel_i), int(pixel_j)
                assert pixel_i < test_size[0] and pixel_j < test_size[1], \
                    "Pixel ({},{}) is out of bounds for image of size ({},{})".format(
                        pixel_i,pixel_j,test_size[0],test_size[1])

                # Run backward pass
                # Note: Computes backprop wrt most likely predicted class rather than gt class
                gradcam_kwargs = {
                    'image': image_index,
                    'pixel_i': pixel_i,
                    'pixel_j': pixel_j
                }
                if args.suffix:
                    gradcam_kwargs['suffix'] = args.suffix
                logger.info(
                    f'Running {args.vis_mode} on image {image_index} at pixel ({pixel_i},{pixel_j}). Using filename suffix: {args.suffix}'
                )
                output_pixel_i, output_pixel_j = compute_output_coord(
                    pixel_i, pixel_j, test_size, pred_probs.shape[2:])

                if not getattr(Saliency, 'whole_image', False):
                    gradcam.backward(pred_labels[:, [0], :, :], output_pixel_i,
                                     output_pixel_j)

                if args.crop_size <= 0:
                    generate_and_save_saliency(image_index)
                else:
                    generate_and_save_saliency(image_index, pixel_i, pixel_j,
                                               args.crop_size)

            logger.info(f'=> Final bounds are: ({minimum}, {maximum})')

    # Instantiate wrapper once, outside of loop
    Saliency = METHODS[args.vis_mode]
    gradcam = Saliency(model=model,
                       candidate_layers=target_layers,
                       use_nbdt=config.NBDT.USE_NBDT,
                       nbdt_node_wnid=None)

    maximum, minimum, label, gradcam_kwargs = -1000, 0, None, {}
    for nbdt_node_wnid in nbdt_node_wnids:
        if config.NBDT.USE_NBDT:
            logger.info("Using logits from node with wnid {}...".format(
                nbdt_node_wnid))
        gradcam.set_nbdt_node_wnid(nbdt_node_wnid)
        run()

    if not nbdt_node_wnids:
        nbdt_node_wnid = None
        run()
Exemplo n.º 4
0
def main():
    """Occlusion-damage experiment: for each test image, pick a pixel of a
    target class, compute its saliency map, zero out the salient pixels of
    the input, re-run the model, and record the drop in predicted
    probability ("damage") per class.
    """
    random.seed(11)
    args = parse_args()

    logger, final_output_dir, _ = create_logger(config, args.cfg,
                                                'vis_gradcam')

    # cudnn related setting
    cudnn.benchmark = config.CUDNN.BENCHMARK
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED

    # build model
    model = eval('models.' + config.MODEL.NAME + '.get_seg_model')(config)

    dump_input = torch.rand(
        (1, 3, config.TRAIN.IMAGE_SIZE[1], config.TRAIN.IMAGE_SIZE[0]))

    if config.TEST.MODEL_FILE:
        model_state_file = config.TEST.MODEL_FILE
    else:
        model_state_file = os.path.join(final_output_dir, 'best.pth')
    # NOTE(review): hard-coded path silently overrides the checkpoint
    # selection above — presumably a debugging leftover; confirm intent.
    model_state_file = 'pretrained_models/hrnet_w18_small_model_v1.pth'
    logger.info('=> loading model from {}'.format(model_state_file))

    # Checkpoint keys carry a 6-char prefix; strip it and keep only entries
    # that match the freshly built model's state dict.
    pretrained_dict = torch.load(model_state_file)
    model_dict = model.state_dict()
    pretrained_dict = {
        k[6:]: v
        for k, v in pretrained_dict.items() if k[6:] in model_dict.keys()
    }
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)

    # Wrap original model with NBDT
    if config.NBDT.USE_NBDT:
        from nbdt.model import SoftSegNBDT
        model = SoftSegNBDT(config.NBDT.DATASET,
                            model,
                            hierarchy=config.NBDT.HIERARCHY,
                            classes=class_names)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device).eval()

    # Retrieve input image corresponding to args.image_index
    test_size = (config.TEST.IMAGE_SIZE[1], config.TEST.IMAGE_SIZE[0])
    test_dataset = eval('datasets.' + config.DATASET.DATASET)(
        root=config.DATASET.ROOT,
        list_path=config.DATASET.TEST_SET,
        num_samples=None,
        num_classes=config.DATASET.NUM_CLASSES,
        multi_scale=False,
        flip=False,
        ignore_label=config.TRAIN.IGNORE_LABEL,
        base_size=config.TEST.BASE_SIZE,
        crop_size=test_size,
        downsample_rate=1)

    # Define target layer as final convolution layer if not specified
    if args.target_layers:
        target_layers = args.target_layers.split(',')
    else:
        # BUGFIX: the break must live inside the `if`. Previously it executed
        # unconditionally, so only the *last* named module was ever inspected
        # and `target_layers` stayed unbound (NameError below) unless that
        # module happened to be a Conv2d.
        for name, module in list(model.named_modules())[::-1]:
            if isinstance(module, nn.Conv2d):
                target_layers = [name]
                break
    logger.info('Target layers set to {}'.format(str(target_layers)))

    # Append model. to target layers if using nbdt
    if config.NBDT.USE_NBDT:
        target_layers = ['model.' + layer for layer in target_layers]

    # Per-class list of (damage, image_index, [pixel_i, pixel_j]) records.
    dmg_by_cls = [[] for _ in class_names]

    def generate_and_save_saliency(image_index,
                                   pixel_i=None,
                                   pixel_j=None,
                                   crop_size=None,
                                   cls_index=0,
                                   normalize=False,
                                   previous_pred=1):
        """Generate a saliency map for each target layer, occlude (zero) the
        salient pixels in the input image, re-run the model, and record the
        probability drop for `cls_index` at (pixel_i, pixel_j) in
        `dmg_by_cls`."""
        nonlocal maximum, minimum, label

        for layer in target_layers:
            gradcam_region = gradcam.generate(target_layer=layer,
                                              normalize=False)

            # Track running bounds across layers/pixels for this image.
            maximum = max(float(gradcam_region.max()), maximum)
            minimum = min(float(gradcam_region.min()), minimum)
            logger.info(f'=> Bounds: ({minimum}, {maximum})')

            image = test_dataset[image_index][0]
            image = torch.from_numpy(image).unsqueeze(0).to(
                gradcam_region.device)

            # Zero out pixels marked by the saliency map
            coords_to_zero = gradcam_region != 0
            image[:, 0:1, :, :][coords_to_zero] = 0
            image[:, 1:2, :, :][coords_to_zero] = 0
            image[:, 2:3, :, :][coords_to_zero] = 0

            pred_probs, pred_labels = gradcam.forward(image)
            # New location of the target class after occlusion
            new_index = np.where(
                pred_labels[0, :, pixel_i,
                            pixel_j].cpu().numpy() == cls_index)[0][0]
            damage = previous_pred - pred_probs[0, new_index, pixel_i,
                                                pixel_j].item()
            dmg_by_cls[cls_index].append(
                (damage, image_index, [pixel_i, pixel_j]))
            # Release the occluded image/predictions before the next layer.
            del image
            del pred_probs
            del pred_labels
            torch.cuda.empty_cache()

    nbdt_node_wnids = args.nbdt_node_wnid or []
    cls = args.nbdt_node_wnids_for
    if cls:
        # Expand a class name into the wnids of all tree nodes on the path
        # from the root to that class's leaf.
        assert config.NBDT.USE_NBDT, 'Must be using NBDT'
        from nbdt.data.custom import Node
        assert hasattr(model, 'rules') and hasattr(model.rules, 'nodes'), \
            'NBDT must have rules with nodes'
        logger.info("Getting nodes leading up to class leaf {}...".format(cls))
        leaf_to_path_nodes = Node.get_leaf_to_path(model.rules.nodes)

        cls_index = class_names.index(cls)
        leaf = model.rules.nodes[0].wnids[cls_index]
        path_nodes = leaf_to_path_nodes[leaf]
        nbdt_node_wnids = [
            item['node'].wnid for item in path_nodes if item['node']
        ]

    def run():
        """Iterate requested images, pick one correctly-predicted pixel of
        the hard-coded target class per image, and measure occlusion
        damage for it."""
        nonlocal maximum, minimum, label, gradcam_kwargs
        np.random.seed(13)
        for image_index in get_image_indices(args.image_index,
                                             args.image_index_range):
            image, label, _, name = test_dataset[image_index]
            image = torch.from_numpy(image).unsqueeze(0).to(device)
            logger.info("Using image {}...".format(name))

            pred_probs, pred_labels = gradcam.forward(image)

            maximum, minimum = -1000, 0
            logger.info(f'=> Starting bounds: ({minimum}, {maximum})')

            if args.crop_for and class_names.index(args.crop_for) not in label:
                print(
                    f'Skipping image {image_index} because no {args.crop_for} found'
                )
                continue

            if getattr(Saliency, 'whole_image', False):
                assert not (
                        args.pixel_i or args.pixel_j or args.pixel_i_range
                        or args.pixel_j_range), \
                    'the "Whole" saliency method generates one map for the whole ' \
                    'image, not for specific pixels'
                gradcam_kwargs = {'image': image_index}
                if args.suffix:
                    gradcam_kwargs['suffix'] = args.suffix
                gradcam.backward(pred_labels[:, [0], :, :])

                generate_and_save_saliency(image_index)

                if args.crop_size <= 0:
                    continue

            if args.crop_for:
                cls_index = class_names.index(args.crop_for)
                label = torch.Tensor(label).to(pred_labels.device)
                is_right_class = is_correct = label == cls_index  #TODO:tmp
                pixels = (is_right_class * is_correct).nonzero()

                pixels = get_random_pixels(args.pixel_max_num_random,
                                           pixels,
                                           seed=cls_index)
            else:
                assert (args.pixel_i
                        or args.pixel_i_range) and (args.pixel_j
                                                    or args.pixel_j_range)
                pixels = get_pixels(args.pixel_i, args.pixel_j,
                                    args.pixel_i_range, args.pixel_j_range,
                                    args.pixel_cartesian_product)
            logger.info(f'Running on {len(pixels)} pixels.')

            for _ in range(1):
                cls_index = 17  # hardcode person
                # Randomly ordered coords where the ground-truth label
                # matches the current class index.
                matching_coords = np.random.permutation(
                    np.array(np.where(label == cls_index)).T)
                if matching_coords.shape[0] == 0:
                    continue
                pixel_i, pixel_j = 0, 0
                found_matching = False
                # Require the model's top prediction to agree with the
                # ground truth at the chosen pixel.
                for coord in matching_coords:
                    pixel_i, pixel_j = coord[0], coord[1]
                    if pred_labels[0, 0, pixel_i, pixel_j].item() == cls_index:
                        found_matching = True
                        break

                if not found_matching:
                    continue
                assert pred_labels[0, 0, pixel_i, pixel_j].item() == cls_index

                assert pixel_i < test_size[0] and pixel_j < test_size[1], \
                    "Pixel ({},{}) is out of bounds for image of size ({},{})".format(
                        pixel_i,pixel_j,test_size[0],test_size[1])

                # Get the current prediction confidence for the top class.
                # Will use to compute occlusion damage later.
                curr_pred = pred_probs[0, 0, pixel_i, pixel_j].item()
                # Run backward pass
                # Note: Computes backprop wrt most likely predicted class rather than gt class
                output_pixel_i, output_pixel_j = compute_output_coord(
                    pixel_i, pixel_j, test_size, pred_probs.shape[2:])
                target_index = 0
                if not getattr(Saliency, 'whole_image', False):
                    gradcam.backward(pred_labels[:, [target_index], :, :],
                                     output_pixel_i, output_pixel_j)

                if args.crop_size <= 0:
                    generate_and_save_saliency(image_index,
                                               pixel_i,
                                               pixel_j,
                                               cls_index=cls_index,
                                               previous_pred=curr_pred)
                else:
                    generate_and_save_saliency(image_index,
                                               pixel_i,
                                               pixel_j,
                                               args.crop_size,
                                               cls_index=cls_index,
                                               previous_pred=curr_pred)

    # Instantiate wrapper once, outside of loop
    Saliency = METHODS[args.vis_mode]
    gradcam = Saliency(model=model,
                       candidate_layers=target_layers,
                       use_nbdt=config.NBDT.USE_NBDT,
                       nbdt_node_wnid=None)

    maximum, minimum, label, gradcam_kwargs = -1000, 0, None, {}
    for nbdt_node_wnid in nbdt_node_wnids:
        if config.NBDT.USE_NBDT:
            logger.info("Using logits from node with wnid {}...".format(
                nbdt_node_wnid))
        gradcam.set_nbdt_node_wnid(nbdt_node_wnid)
        run()

    if not nbdt_node_wnids:
        nbdt_node_wnid = None
        run()
        # Dump the collected occlusion-damage records per class.
        for i in range(len(dmg_by_cls)):
            if len(dmg_by_cls[i]) > 0:
                logger.info('{}, {}'.format(i, dmg_by_cls[i]))
        print('done')