def init(self,
         dataset,
         criterion,
         sample_nums,
         path_graph,
         path_wnids,
         classes,
         Rules,
         tree_supervision_weight=1.):
    """
    Extra init method that makes explicit which arguments are ultimately
    necessary for this class to function. The constructor may generate
    some of these required arguments if they are initially missing.
    """
    self.dataset = dataset
    self.num_classes = len(classes)
    self.nodes = Node.get_nodes(path_graph, path_wnids, classes)
    self.rules = Rules(dataset, path_graph, path_wnids, classes)
    self.tree_supervision_weight = tree_supervision_weight
    self.criterion = criterion
    self.sample_nums = np.array(sample_nums)
    self.node_depths = defaultdict(list)
    self.node_weights = defaultdict(list)
    # Class-balanced re-weighting by the "effective number of samples"
    # (Cui et al., 2019) with beta = 0.999: w_c = (1 - beta) / (1 - beta^n_c).
    effective_num = 1.0 - np.power(0.999, self.sample_nums)
    weights = (1.0 - 0.999) / effective_num
    self.weights = weights
    for node in self.nodes:
        key = node.num_classes
        depth = node.get_depth()
        self.node_depths[key].append(depth)
        # Collect the per-class weights for each of this node's children.
        node_weight = []
        for new_label in range(node.num_classes):
            node_weight.append(weights[node.new_to_old_classes[new_label]])
        self.node_weights[key].append(node_weight)
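
For reference, a minimal standalone sketch of the class-balanced weighting used above; beta = 0.999 matches the hardcoded constant, and the toy class counts are invented for illustration:

import numpy as np

def class_balanced_weights(sample_nums, beta=0.999):
    """Effective-number re-weighting: w_c = (1 - beta) / (1 - beta^n_c)."""
    sample_nums = np.asarray(sample_nums, dtype=np.float64)
    return (1.0 - beta) / (1.0 - np.power(beta, sample_nums))

# Rarer classes receive larger weights.
print(class_balanced_weights([10, 100, 10000]))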
    def __init__(self, dataset, path_graph=None, path_wnids=None, classes=()):

        if not path_graph:
            path_graph = dataset_to_default_path_graph(dataset)
        if not path_wnids:
            path_wnids = dataset_to_default_path_wnids(dataset)
        if not classes:
            classes = dataset_to_dummy_classes(dataset)
        super().__init__()
        assert all([dataset, path_graph, path_wnids, classes])

        self.classes = classes

        self.nodes = Node.get_nodes(path_graph, path_wnids, classes)
        self.G = self.nodes[0].G
        self.wnid_to_node = {node.wnid: node for node in self.nodes}

        self.wnids = get_wnids(path_wnids)
        self.wnid_to_class = {
            wnid: cls
            for wnid, cls in zip(self.wnids, self.classes)
        }

        self.correct = 0
        self.total = 0

        self.I = torch.eye(len(classes))
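
The identity matrix stored above is commonly used for one-hot encoding: indexing rows of torch.eye yields one-hot vectors for the given labels. A tiny self-contained illustration:

import torch

I = torch.eye(4)                 # 4 classes
labels = torch.tensor([2, 0, 3])
one_hot = I[labels]              # shape (3, 4), one row per label
assert one_hot[0, 2].item() == 1.0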
Example #3
    def forward_with_decisions(self, outputs):
        outputs = self.forward(outputs)
        _, predicted = outputs.max(1)

        decisions = []
        # The root node's wnid list maps prediction indices to leaf wnids.
        node = self.nodes[0]
        leaf_to_path_nodes = Node.get_leaf_to_path(self.nodes)
        for index, prediction in enumerate(predicted):
            leaf = node.wnids[prediction]
            decision = leaf_to_path_nodes[leaf]
            for justification in decision:
                justification['prob'] = -1  # TODO(alvin): fill in prob
            decisions.append(decision)
        return outputs, decisions
Example #4
    def forward_with_decisions(self, outputs):
        # Run NBDT inference on every class except channel 0, which is
        # re-attached to the outputs below.
        outputs_ = outputs[:, 1:]
        outputs_ = self.forward(outputs_)
        _, predicted = outputs_.max(1)

        decisions = []
        node = self.nodes[0]
        leaf_to_path_nodes = Node.get_leaf_to_path(self.nodes)
        for index, prediction in enumerate(predicted):
            leaf = node.wnids[prediction]
            decision = leaf_to_path_nodes[leaf]
            for justification in decision:
                justification['prob'] = -1  # TODO(alvin): fill in prob
            decisions.append(decision)
        # Prepend channel 0's softmax score so the outputs keep all classes.
        outputs_score = F.softmax(outputs, dim=-1)
        outputs = torch.cat((outputs_score[:, 0].unsqueeze(1), outputs_),
                            dim=1)
        return outputs, decisions
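
A toy check of the channel bookkeeping in the variant above, assuming nothing beyond torch: drop channel 0, process the remaining channels, then re-attach channel 0's softmax score; the output shape is preserved:

import torch
import torch.nn.functional as F

outputs = torch.randn(2, 5)                      # batch of 2, 5 classes
rest = outputs[:, 1:]                            # channels 1..4 only
scores = F.softmax(outputs, dim=-1)
merged = torch.cat((scores[:, 0].unsqueeze(1), rest), dim=1)
assert merged.shape == outputs.shape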
Example #5
def init(self,
         dataset,
         criterion,
         path_graph,
         path_wnids,
         classes,
         Rules,
         tree_supervision_weight=1.):
    """
    Extra init method that makes explicit which arguments are ultimately
    necessary for this class to function. The constructor may generate
    some of these required arguments if they are initially missing.
    """
    self.dataset = dataset
    self.num_classes = len(classes)
    self.nodes = Node.get_nodes(path_graph, path_wnids, classes)
    self.rules = Rules(dataset, path_graph, path_wnids, classes)
    self.tree_supervision_weight = tree_supervision_weight
    self.criterion = criterion
Example #6
def main():
    args = parse_args()

    logger, final_output_dir, _ = create_logger(config, args.cfg,
                                                'vis_gradcam')

    logger.info(pprint.pformat(args))
    logger.info(pprint.pformat(config))

    # cudnn related setting
    cudnn.benchmark = config.CUDNN.BENCHMARK
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED

    # build model
    model = eval('models.' + config.MODEL.NAME + '.get_seg_model')(config)

    dump_input = torch.rand(
        (1, 3, config.TRAIN.IMAGE_SIZE[1], config.TRAIN.IMAGE_SIZE[0]))
    logger.info(get_model_summary(model.cuda(), dump_input.cuda()))

    if config.TEST.MODEL_FILE:
        model_state_file = config.TEST.MODEL_FILE
    else:
        model_state_file = os.path.join(final_output_dir, 'best.pth')
    logger.info('=> loading model from {}'.format(model_state_file))

    pretrained_dict = torch.load(model_state_file)
    model_dict = model.state_dict()
    # Strip the first 6 characters from each checkpoint key (typically a
    # 'model.' prefix) and keep only keys present in this model.
    pretrained_dict = {
        k[6:]: v
        for k, v in pretrained_dict.items() if k[6:] in model_dict.keys()
    }
    for k, _ in pretrained_dict.items():
        logger.info('=> loading {} from pretrained model'.format(k))
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)

    # Wrap original model with NBDT
    if config.NBDT.USE_NBDT:
        from nbdt.model import SoftSegNBDT
        model = SoftSegNBDT(config.NBDT.DATASET,
                            model,
                            hierarchy=config.NBDT.HIERARCHY,
                            classes=class_names)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device).eval()

    # Retrieve input image corresponding to args.image_index
    test_size = (config.TEST.IMAGE_SIZE[1], config.TEST.IMAGE_SIZE[0])
    test_dataset = eval('datasets.' + config.DATASET.DATASET)(
        root=config.DATASET.ROOT,
        list_path=config.DATASET.TEST_SET,
        num_samples=None,
        num_classes=config.DATASET.NUM_CLASSES,
        multi_scale=False,
        flip=False,
        ignore_label=config.TRAIN.IGNORE_LABEL,
        base_size=config.TEST.BASE_SIZE,
        crop_size=test_size,
        downsample_rate=1)

    # Define target layer as final convolution layer if not specified
    if args.target_layers:
        target_layers = args.target_layers.split(',')
    else:
        for name, module in list(model.named_modules())[::-1]:
            if isinstance(module, nn.Conv2d):
                target_layers = [name]
                break
    logger.info('Target layers set to {}'.format(str(target_layers)))

    # Prefix target layers with 'model.' when the model is wrapped with NBDT
    if config.NBDT.USE_NBDT:
        target_layers = ['model.' + layer for layer in target_layers]
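    # Sketch (not from this repo): a more defensive version of the fallback
    # above, using next() with a default so `target_layers` is never unbound
    # when the model contains no Conv2d:
    #     name = next((n for n, m in reversed(list(model.named_modules()))
    #                  if isinstance(m, nn.Conv2d)), None)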

    def generate_and_save_saliency(image_index,
                                   pixel_i=None,
                                   pixel_j=None,
                                   crop_size=None,
                                   normalize=False):
        """too lazy to move out to global lol"""
        nonlocal maximum, minimum, label
        # Generate GradCAM + save heatmap
        heatmaps = []
        raw_image = retrieve_raw_image(test_dataset, image_index)

        should_crop = crop_size is not None and pixel_i is not None and pixel_j is not None
        if should_crop:
            raw_image = crop(pixel_i,
                             pixel_j,
                             crop_size,
                             raw_image,
                             is_tensor=False)

        for layer in target_layers:
            gradcam_region = gradcam.generate(target_layer=layer,
                                              normalize=False)

            if should_crop:
                gradcam_region = crop(pixel_i,
                                      pixel_j,
                                      crop_size,
                                      gradcam_region,
                                      is_tensor=True)

            maximum = max(float(gradcam_region.max()), maximum)
            minimum = min(float(gradcam_region.min()), minimum)
            logger.info(f'=> Bounds: ({minimum}, {maximum})')

            heatmaps.append(gradcam_region)
            output_dir = generate_output_dir(final_output_dir, args.vis_mode,
                                             layer, config.NBDT.USE_NBDT,
                                             nbdt_node_wnid, args.crop_size,
                                             args.nbdt_node_wnids_for)
            save_path = generate_save_path(output_dir, gradcam_kwargs)
            logger.info('Saving {} heatmap at {}...'.format(
                args.vis_mode, save_path))

            if normalize:
                gradcam_region = GradCAM.normalize(gradcam_region)
                save_gradcam(save_path,
                             gradcam_region,
                             raw_image,
                             save_npy=not args.skip_save_npy)
            else:
                save_gradcam(save_path,
                             gradcam_region,
                             raw_image,
                             minimum=minimum,
                             maximum=maximum,
                             save_npy=not args.skip_save_npy)

            output_dir_original = output_dir + '_original'
            os.makedirs(output_dir_original, exist_ok=True)
            save_path_original = generate_save_path(output_dir_original,
                                                    gradcam_kwargs,
                                                    ext='jpg')
            logger.info('Saving {} original at {}...'.format(
                args.vis_mode, save_path_original))
            cv2.imwrite(save_path_original, raw_image)

            # Use should_crop rather than truthiness: pixel coordinate 0 is
            # falsy but valid.
            if should_crop:
                continue
            output_dir += '_overlap'
            os.makedirs(output_dir, exist_ok=True)
            save_path_overlap = generate_save_path(output_dir,
                                                   gradcam_kwargs,
                                                   ext='npy')
            save_path_plot = generate_save_path(output_dir,
                                                gradcam_kwargs,
                                                ext='jpg')
            if not args.skip_save_npy:
                logger.info('Saving {} overlap data at {}...'.format(
                    args.vis_mode, save_path_overlap))
            logger.info('Saving {} overlap plot at {}...'.format(
                args.vis_mode, save_path_plot))
            save_overlap(save_path_overlap,
                         save_path_plot,
                         gradcam_region,
                         label,
                         save_npy=not args.skip_save_npy)
        # Combine per-layer heatmaps with an elementwise product, rescaled to
        # a max of 1.
        if len(heatmaps) > 1:
            combined = torch.prod(torch.stack(heatmaps, dim=0), dim=0)
            combined /= combined.max()
            save_path = generate_save_path(final_output_dir, args.vis_mode,
                                           gradcam_kwargs, 'combined',
                                           config.NBDT.USE_NBDT,
                                           nbdt_node_wnid)
            logger.info('Saving combined {} heatmap at {}...'.format(
                args.vis_mode, save_path))
            save_gradcam(save_path, combined, raw_image)

    nbdt_node_wnids = args.nbdt_node_wnid or []
    cls = args.nbdt_node_wnids_for
    if cls:
        assert config.NBDT.USE_NBDT, 'Must be using NBDT'
        from nbdt.data.custom import Node
        assert hasattr(model, 'rules') and hasattr(model.rules, 'nodes'), \
            'NBDT must have rules with nodes'
        logger.info("Getting nodes leading up to class leaf {}...".format(cls))
        leaf_to_path_nodes = Node.get_leaf_to_path(model.rules.nodes)

        cls_index = class_names.index(cls)
        leaf = model.rules.nodes[0].wnids[cls_index]
        path_nodes = leaf_to_path_nodes[leaf]
        nbdt_node_wnids = [
            item['node'].wnid for item in path_nodes if item['node']
        ]
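        # path_nodes is the root-to-leaf decision chain for `cls`; the loop
        # near the bottom of main() launches one saliency run per node wnid
        # along that path.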

    def run():
        nonlocal maximum, minimum, label, gradcam_kwargs
        for image_index in get_image_indices(args.image_index,
                                             args.image_index_range):
            image, label, _, name = test_dataset[image_index]
            image = torch.from_numpy(image).unsqueeze(0).to(device)
            logger.info("Using image {}...".format(name))
            pred_probs, pred_labels = gradcam.forward(image)

            # Reset the running saliency bounds for this image.
            maximum, minimum = -1000, 0
            logger.info(f'=> Starting bounds: ({minimum}, {maximum})')

            if args.crop_for and class_names.index(args.crop_for) not in label:
                print(
                    f'Skipping image {image_index} because no {args.crop_for} found'
                )
                continue

            if getattr(Saliency, 'whole_image', False):
                assert not (
                        args.pixel_i or args.pixel_j or args.pixel_i_range
                        or args.pixel_j_range), \
                    'the "Whole" saliency method generates one map for the whole ' \
                    'image, not for specific pixels'
                gradcam_kwargs = {'image': image_index}
                if args.suffix:
                    gradcam_kwargs['suffix'] = args.suffix
                gradcam.backward(pred_labels[:, [0], :, :])

                generate_and_save_saliency(image_index)

                if args.crop_size <= 0:
                    continue

            if args.crop_for:
                cls_index = class_names.index(args.crop_for)
                label = torch.Tensor(label).to(pred_labels.device)
                # TODO: temporary. The commented checks below additionally
                # required the prediction to match the class:
                # is_right_class = pred_labels[0, 0, :, :] == cls_index
                # is_correct = pred_labels == label
                is_right_class = is_correct = label == cls_index
                pixels = (is_right_class * is_correct).nonzero()

                pixels = get_random_pixels(args.pixel_max_num_random,
                                           pixels,
                                           seed=cls_index)
            else:
                assert (args.pixel_i
                        or args.pixel_i_range) and (args.pixel_j
                                                    or args.pixel_j_range)
                pixels = get_pixels(args.pixel_i, args.pixel_j,
                                    args.pixel_i_range, args.pixel_j_range,
                                    args.pixel_cartesian_product)
            logger.info(f'Running on {len(pixels)} pixels.')

            for pixel_i, pixel_j in pixels:
                pixel_i, pixel_j = int(pixel_i), int(pixel_j)
                assert pixel_i < test_size[0] and pixel_j < test_size[1], \
                    "Pixel ({},{}) is out of bounds for image of size ({},{})".format(
                        pixel_i, pixel_j, test_size[0], test_size[1])

                # Run backward pass.
                # Note: computes backprop w.r.t. the most likely predicted
                # class rather than the ground-truth class.
                gradcam_kwargs = {
                    'image': image_index,
                    'pixel_i': pixel_i,
                    'pixel_j': pixel_j
                }
                if args.suffix:
                    gradcam_kwargs['suffix'] = args.suffix
                logger.info(
                    f'Running {args.vis_mode} on image {image_index} at pixel ({pixel_i},{pixel_j}). Using filename suffix: {args.suffix}'
                )
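                # compute_output_coord (imported elsewhere in this script)
                # maps an input-resolution pixel onto the network's
                # possibly-downsampled output grid, roughly
                # out_i = pixel_i * out_h // in_h.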
                output_pixel_i, output_pixel_j = compute_output_coord(
                    pixel_i, pixel_j, test_size, pred_probs.shape[2:])

                if not getattr(Saliency, 'whole_image', False):
                    gradcam.backward(pred_labels[:, [0], :, :], output_pixel_i,
                                     output_pixel_j)

                if args.crop_size <= 0:
                    generate_and_save_saliency(image_index)
                else:
                    generate_and_save_saliency(image_index, pixel_i, pixel_j,
                                               args.crop_size)

            logger.info(f'=> Final bounds are: ({minimum}, {maximum})')

    # Instantiate wrapper once, outside of loop
    Saliency = METHODS[args.vis_mode]
    gradcam = Saliency(model=model,
                       candidate_layers=target_layers,
                       use_nbdt=config.NBDT.USE_NBDT,
                       nbdt_node_wnid=None)

    maximum, minimum, label, gradcam_kwargs = -1000, 0, None, {}
    for nbdt_node_wnid in nbdt_node_wnids:
        if config.NBDT.USE_NBDT:
            logger.info("Using logits from node with wnid {}...".format(
                nbdt_node_wnid))
        gradcam.set_nbdt_node_wnid(nbdt_node_wnid)
        run()

    if not nbdt_node_wnids:
        nbdt_node_wnid = None
        run()
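
Condensed, the per-pixel flow above reduces to this sketch of a helper (the name pixel_saliency is hypothetical; METHODS, compute_output_coord, and the argument values all come from the script itself):

def pixel_saliency(model, image, pixel_i, pixel_j, test_size, target_layers,
                   vis_mode, use_nbdt=False):
    # Hypothetical condensation of the loop body in main() above.
    saliency = METHODS[vis_mode](model=model, candidate_layers=target_layers,
                                 use_nbdt=use_nbdt, nbdt_node_wnid=None)
    pred_probs, pred_labels = saliency.forward(image)
    out_i, out_j = compute_output_coord(pixel_i, pixel_j, test_size,
                                        pred_probs.shape[2:])
    saliency.backward(pred_labels[:, [0], :, :], out_i, out_j)  # top-1 class
    return saliency.generate(target_layer=target_layers[0], normalize=False)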
Example #7
def main():
    random.seed(11)
    args = parse_args()

    logger, final_output_dir, _ = create_logger(config, args.cfg,
                                                'vis_gradcam')

    # logger.info(pprint.pformat(args))
    # logger.info(pprint.pformat(config))

    # cudnn related setting
    cudnn.benchmark = config.CUDNN.BENCHMARK
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED

    # build model
    model = eval('models.' + config.MODEL.NAME + '.get_seg_model')(config)

    dump_input = torch.rand(
        (1, 3, config.TRAIN.IMAGE_SIZE[1], config.TRAIN.IMAGE_SIZE[0]))
    # logger.info(get_model_summary(model.cuda(), dump_input.cuda()))

    if config.TEST.MODEL_FILE:
        model_state_file = config.TEST.MODEL_FILE
    else:
        model_state_file = os.path.join(final_output_dir, 'best.pth')
    # NOTE: hardcoded checkpoint override; the config-derived paths above
    # are ignored.
    model_state_file = 'pretrained_models/hrnet_w18_small_model_v1.pth'
    logger.info('=> loading model from {}'.format(model_state_file))
    pretrained_dict = torch.load(model_state_file)
    model_dict = model.state_dict()
    # Strip the first 6 characters from each checkpoint key (typically a
    # 'model.' prefix) and keep only keys present in this model.
    pretrained_dict = {
        k[6:]: v
        for k, v in pretrained_dict.items() if k[6:] in model_dict.keys()
    }
    # for k, _ in pretrained_dict.items():
    #     logger.info(
    #         '=> loading {} from pretrained model'.format(k))
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)

    # Wrap original model with NBDT
    if config.NBDT.USE_NBDT:
        from nbdt.model import SoftSegNBDT
        model = SoftSegNBDT(config.NBDT.DATASET,
                            model,
                            hierarchy=config.NBDT.HIERARCHY,
                            classes=class_names)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # device = 'cpu'
    model = model.to(device).eval()
    # Retrieve input image corresponding to args.image_index
    test_size = (config.TEST.IMAGE_SIZE[1], config.TEST.IMAGE_SIZE[0])
    test_dataset = eval('datasets.' + config.DATASET.DATASET)(
        root=config.DATASET.ROOT,
        list_path=config.DATASET.TEST_SET,
        num_samples=None,
        num_classes=config.DATASET.NUM_CLASSES,
        multi_scale=False,
        flip=False,
        ignore_label=config.TRAIN.IGNORE_LABEL,
        base_size=config.TEST.BASE_SIZE,
        crop_size=test_size,
        downsample_rate=1)

    # Define target layer as final convolution layer if not specified
    if args.target_layers:
        target_layers = args.target_layers.split(',')
    else:
        for name, module in list(model.named_modules())[::-1]:
            if isinstance(module, nn.Conv2d):
                target_layers = [name]
                break
    logger.info('Target layers set to {}'.format(str(target_layers)))

    # Prefix target layers with 'model.' when the model is wrapped with NBDT
    if config.NBDT.USE_NBDT:
        target_layers = ['model.' + layer for layer in target_layers]

    dmg_by_cls = [[] for _ in class_names]

    def generate_and_save_saliency(image_index,
                                   pixel_i=None,
                                   pixel_j=None,
                                   crop_size=None,
                                   cls_index=0,
                                   normalize=False,
                                   previous_pred=1):
        """too lazy to move out to global lol"""
        nonlocal maximum, minimum, label
        # Generate GradCAM; the raw-image saving and cropping from the
        # previous example are disabled in this occlusion-damage variant.

        for layer in target_layers:
            gradcam_region = gradcam.generate(target_layer=layer,
                                              normalize=False)

            maximum = max(float(gradcam_region.max()), maximum)
            minimum = min(float(gradcam_region.min()), minimum)
            logger.info(f'=> Bounds: ({minimum}, {maximum})')

            image = test_dataset[image_index][0]
            # (A disabled exploratory analysis lived here: it measured what
            # fraction of the saliency mass and pixels fell within ~200px of
            # the target pixel, i.e. inside an assumed 400px receptive field.)

            image = torch.from_numpy(image).unsqueeze(0).to(
                gradcam_region.device)

            # Zero out pixels marked by the saliency map
            coords_to_zero = gradcam_region != 0
            image[:, 0:1, :, :][coords_to_zero] = 0
            image[:, 1:2, :, :][coords_to_zero] = 0
            image[:, 2:3, :, :][coords_to_zero] = 0

            pred_probs, pred_labels = gradcam.forward(image)
            # New location of the target class after occlusion
            new_index = np.where(
                pred_labels[0, :, pixel_i,
                            pixel_j].cpu().numpy() == cls_index)[0][0]
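            # Occlusion damage: confidence in the target class before masking
            # minus confidence after the salient pixels are zeroed out.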
            damage = previous_pred - pred_probs[0, new_index, pixel_i,
                                                pixel_j].item()
            dmg_by_cls[cls_index].append(
                (damage, image_index, [pixel_i, pixel_j]))
            del image
            del pred_probs
            del pred_labels
            torch.cuda.empty_cache()

            # (Disabled here: debugging probes, an occlusion-sensitivity
            # exploration, and the heatmap/overlay saving logic from the
            # previous example.)

    nbdt_node_wnids = args.nbdt_node_wnid or []
    cls = args.nbdt_node_wnids_for
    if cls:
        assert config.NBDT.USE_NBDT, 'Must be using NBDT'
        from nbdt.data.custom import Node
        assert hasattr(model, 'rules') and hasattr(model.rules, 'nodes'), \
            'NBDT must have rules with nodes'
        logger.info("Getting nodes leading up to class leaf {}...".format(cls))
        leaf_to_path_nodes = Node.get_leaf_to_path(model.rules.nodes)

        cls_index = class_names.index(cls)
        leaf = model.rules.nodes[0].wnids[cls_index]
        path_nodes = leaf_to_path_nodes[leaf]
        nbdt_node_wnids = [
            item['node'].wnid for item in path_nodes if item['node']
        ]

    def run():
        nonlocal maximum, minimum, label, gradcam_kwargs
        np.random.seed(13)
        for image_index in get_image_indices(args.image_index,
                                             args.image_index_range):
            image, label, _, name = test_dataset[image_index]
            image = torch.from_numpy(image).unsqueeze(0).to(device)
            logger.info("Using image {}...".format(name))

            pred_probs, pred_labels = gradcam.forward(image)

            maximum, minimum = -1000, 0
            logger.info(f'=> Starting bounds: ({minimum}, {maximum})')

            if args.crop_for and class_names.index(args.crop_for) not in label:
                print(
                    f'Skipping image {image_index} because no {args.crop_for} found'
                )
                continue

            if getattr(Saliency, 'whole_image', False):
                assert not (
                        args.pixel_i or args.pixel_j or args.pixel_i_range
                        or args.pixel_j_range), \
                    'the "Whole" saliency method generates one map for the whole ' \
                    'image, not for specific pixels'
                gradcam_kwargs = {'image': image_index}
                if args.suffix:
                    gradcam_kwargs['suffix'] = args.suffix
                gradcam.backward(pred_labels[:, [0], :, :])

                generate_and_save_saliency(image_index)

                if args.crop_size <= 0:
                    continue

            if args.crop_for:
                cls_index = class_names.index(args.crop_for)
                label = torch.Tensor(label).to(pred_labels.device)
                # TODO: temporary. The commented checks below additionally
                # required the prediction to match the class:
                # is_right_class = pred_labels[0, 0, :, :] == cls_index
                # is_correct = pred_labels == label
                is_right_class = is_correct = label == cls_index
                pixels = (is_right_class * is_correct).nonzero()

                pixels = get_random_pixels(args.pixel_max_num_random,
                                           pixels,
                                           seed=cls_index)
            else:
                assert (args.pixel_i
                        or args.pixel_i_range) and (args.pixel_j
                                                    or args.pixel_j_range)
                pixels = get_pixels(args.pixel_i, args.pixel_j,
                                    args.pixel_i_range, args.pixel_j_range,
                                    args.pixel_cartesian_product)
            logger.info(f'Running on {len(pixels)} pixels.')

            # Earlier sweeps iterated over sampled pixels or over every
            # class; this run hardcodes a single class instead.
            for _ in range(1):
                cls_index = 17  # hardcoded: the "person" class
                # All coords whose label matches the current class, shuffled.
                matching_coords = np.random.permutation(
                    np.array(np.where(label == cls_index)).T)
                if matching_coords.shape[0] == 0:
                    continue
                pixel_i, pixel_j = 0, 0
                found_matching = False
                for coord in matching_coords:
                    pixel_i, pixel_j = coord[0], coord[1]
                    if pred_labels[0, 0, pixel_i, pixel_j].item() == cls_index:
                        found_matching = True
                        break

                if not found_matching:
                    continue
                assert pred_labels[0, 0, pixel_i, pixel_j].item() == cls_index

                assert pixel_i < test_size[0] and pixel_j < test_size[1], \
                    "Pixel ({},{}) is out of bounds for image of size ({},{})".format(
                        pixel_i, pixel_j, test_size[0], test_size[1])

                # Get the current prediction confidence for the top class.
                # Will use to compute occlusion damage later.
                curr_pred = pred_probs[0, 0, pixel_i, pixel_j].item()
                # Run backward pass.
                # Note: computes backprop w.r.t. the most likely predicted
                # class rather than the ground-truth class.
                output_pixel_i, output_pixel_j = compute_output_coord(
                    pixel_i, pixel_j, test_size, pred_probs.shape[2:])
                target_index = 0
                if not getattr(Saliency, 'whole_image', False):
                    gradcam.backward(pred_labels[:, [target_index], :, :],
                                     output_pixel_i, output_pixel_j)

                if args.crop_size <= 0:
                    generate_and_save_saliency(image_index,
                                               pixel_i,
                                               pixel_j,
                                               cls_index=cls_index,
                                               previous_pred=curr_pred)
                else:
                    generate_and_save_saliency(image_index,
                                               pixel_i,
                                               pixel_j,
                                               args.crop_size,
                                               cls_index=cls_index,
                                               previous_pred=curr_pred)


    # Instantiate wrapper once, outside of loop
    Saliency = METHODS[args.vis_mode]
    gradcam = Saliency(model=model,
                       candidate_layers=target_layers,
                       use_nbdt=config.NBDT.USE_NBDT,
                       nbdt_node_wnid=None)

    maximum, minimum, label, gradcam_kwargs = -1000, 0, None, {}
    for nbdt_node_wnid in nbdt_node_wnids:
        if config.NBDT.USE_NBDT:
            logger.info("Using logits from node with wnid {}...".format(
                nbdt_node_wnid))
        gradcam.set_nbdt_node_wnid(nbdt_node_wnid)
        run()

    if not nbdt_node_wnids:
        nbdt_node_wnid = None
        run()
        # Report the collected occlusion damage per class.
        for i in range(len(dmg_by_cls)):
            if len(dmg_by_cls[i]) > 0:
                logger.info('{}, {}'.format(i, dmg_by_cls[i]))
        print('done')