def train_affinitynet(config: Config):
    """Train the affinity network named in ``config`` and track the run in wandb.

    The network comes from ``get_model(config.affinity_net_name)``. A wandb run
    is opened in project ``wass_affinity`` (named ``<sweep_id>_a_<net name>``),
    the model is watched, and ``training.train.train`` is given a single
    ``'train'`` dataloader. That loader reads ``affinity`` labels from
    ``Segmentation`` with ``affinity_root`` set to the artifact directory.
    The run is finished once training returns.

    Args:
        config: Project configuration; the ``affinity_net_*``,
            ``classifier_dataset_root`` and ``sweep_id`` fields are read.
    """
    config_json = config.toDictionary()
    print('train_affinitynet')
    print(config_json)
    from training.train import train
    from torch.utils.data.dataloader import DataLoader
    from data.loader_segmentation import Segmentation
    from artifacts.artifact_manager import artifact_manager

    model = get_model(config.affinity_net_name)

    run_name = config.sweep_id + '_a_' + config.affinity_net_name
    wandb.init(entity='kobus_wits', project='wass_affinity', name=run_name, config=config_json)
    wandb.watch(model)

    train_set = Segmentation(
        config.classifier_dataset_root,
        source='train',
        augmentation='train',
        image_size=config.affinity_net_image_size,
        requested_labels=['affinity'],
        affinity_root=artifact_manager.getDir(),
    )
    train_loader = DataLoader(
        train_set,
        batch_size=config.affinity_net_batch_size,
        shuffle=False,
        pin_memory=False,
        num_workers=4,
        prefetch_factor=4,
    )

    train(
        model=model,
        dataloaders={'train': train_loader},
        epochs=config.affinity_net_epochs,
        validation_mod=10,
    )

    wandb.finish()
def save_semseg(
    dataset_root,
    model_name,
    batch_size=8,
    image_size=256,
    use_gt_labels=False,
):
    """Run a segmentation model over the ``val`` split and save label images.

    For every image, the model output of ``model.event({'name': 'get_semseg', ...})``
    is processed in four steps:

    * it is resized to ``image_size`` x ``image_size``;
    * the top-left region whose aspect ratio matches the original image is
      cropped out;
    * that region is resized to the original ``width`` x ``height``;
    * it is converted with ``label_to_image``.

    The result is written, scaled by 255, to
    ``<artifact dir>/semseg_output/<image_name>.png``. The directory is
    cleared first.

    Args:
        dataset_root: Root directory passed to ``Segmentation``.
        model_name: Model name resolved by ``get_model``; weights are read with
            ``model.load()``.
        batch_size: Dataloader batch size.
        image_size: Square input size used by the loader and for the first resize.
        use_gt_labels: If True, replace each prediction with the ground-truth
            segmentation (channel 0 dropped), average-pooled to 32x32 and
            blurred twice with a 3x3 box filter. This tests the rest of the
            pipeline independently of the model.
    """
    print('Save semseg : ', locals())
    import shutil
    import cv2
    import os
    import numpy as np
    import torch.nn.functional as F
    import torch
    from models.get_model import get_model
    from torch.utils.data.dataloader import DataLoader
    from data.loader_segmentation import Segmentation
    from artifacts.artifact_manager import artifact_manager
    from data.voc2012 import label_to_image

    def _resize_chw(array, width, height):
        """Bilinearly resize a (C, H, W) array to (C, height, width)."""
        resized = cv2.resize(np.moveaxis(array, 0, -1), (int(width), int(height)),
                             interpolation=cv2.INTER_LINEAR)
        # cv2.resize drops a trailing channel axis of size 1; restore it so
        # single-channel maps keep their (C, H, W) layout.
        if resized.ndim == 2:
            resized = resized[:, :, np.newaxis]
        return np.moveaxis(resized, -1, 0)

    # Set up model in inference mode
    model = get_model(model_name)
    model.load()
    model.to(model.device)
    model.train(False)

    # Set up data loader
    dataloader = DataLoader(
        Segmentation(dataset_root,
                     source='val',
                     source_augmentation='val',
                     image_size=image_size,
                     requested_labels=['classification', 'segmentation']),
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
        prefetch_factor=4,
    )

    # Clear and create destination directory
    semseg_path = os.path.join(artifact_manager.getDir(), 'semseg_output')
    if os.path.exists(semseg_path):
        shutil.rmtree(semseg_path)
    os.makedirs(semseg_path)

    for batch_no, batch in enumerate(dataloader):
        inputs_in = batch[0]
        labels_in = batch[1]
        datapacket_in = batch[2]

        # Run images through the model
        with torch.no_grad():
            semsegs = model.event({
                'name': 'get_semseg',
                'inputs': inputs_in,
                'labels': labels_in,
                'batch': batch_no + 1
            })
            semsegs = semsegs.detach().cpu().numpy()

        for semseg_no, semseg in enumerate(semsegs):
            # Optionally substitute ground truth to test the rest of the system
            if use_gt_labels:
                semseg = labels_in['segmentation'][semseg_no][1:]
                semseg = F.adaptive_avg_pool2d(semseg, [32, 32]).numpy()

                for i in range(semseg.shape[0]):
                    semseg[i] = cv2.blur(semseg[i], (3, 3))
                    semseg[i] = cv2.blur(semseg[i], (3, 3))

            # Original image size as plain ints: cv2.resize's dsize and slice
            # bounds do not reliably accept 0-d numpy arrays / numpy floats.
            width = int(datapacket_in['width'][semseg_no])
            height = int(datapacket_in['height'][semseg_no])
            aspect_ratio = width / height

            # Size of the region to cut from the upscaled map: the longer side
            # spans the full input size.
            if aspect_ratio > 1:
                cut_width = image_size
                cut_height = int(round(image_size / aspect_ratio))
            else:
                cut_width = int(round(image_size * aspect_ratio))
                cut_height = image_size

            # Upscale to input size, cut the image region, upscale to original size
            semseg = _resize_chw(semseg, image_size, image_size)
            semseg = semseg[:, :cut_height, :cut_width]
            semseg = _resize_chw(semseg, width, height)

            semseg_as_label = label_to_image(semseg)

            # Write image
            img_no = datapacket_in['image_name'][semseg_no]
            cv2.imwrite(
                os.path.join(semseg_path, img_no) + '.png',
                semseg_as_label * 255)
            print('Save cam : ', img_no, end='\r')
    print('')
def save_cams_crf(config: Config):
    """Run ``_process_sample`` on every ``train`` image in a process pool and log to wandb.

    For each image of ``config.classifier_dataset_root`` (split ``train``), a
    payload dict is built. It holds the image name/path/size, the number of
    classification channels, the CAM root directory (``<artifact dir>/cam``),
    the output directories and ``config.cams_bg_alpha_low`` / ``_high``. The
    payloads are mapped over ``training.save_cams_crf._process_sample`` with a
    pool of 8 processes. Each returned dict is logged to wandb at step
    ``log['image_count']``.

    The ``cam_la`` and ``cam_ha`` directories under the artifact dir are
    recreated empty first. Presumably ``_process_sample`` writes the
    low/high-alpha results there (TODO confirm).
    """
    config_json = config.toDictionary()
    print('save_cams_crf')
    print(config_json)
    import os
    import shutil
    from torch.utils.data.dataloader import DataLoader
    from data.loader_segmentation import Segmentation
    from multiprocessing import Pool
    from artifacts.artifact_manager import artifact_manager

    cam_root_path = os.path.join(artifact_manager.getDir(), 'cam')

    # Set up data loader
    dataloader = DataLoader(
        Segmentation(
            config.classifier_dataset_root,
            source='train',
            image_size=config.classifier_image_size,
        ),
        batch_size=32,
        shuffle=False,
        num_workers=2,
        prefetch_factor=2,
    )

    # Create empty low/high alpha output dirs
    cam_la_path = os.path.join(artifact_manager.getDir(), 'cam_la')
    cam_ha_path = os.path.join(artifact_manager.getDir(), 'cam_ha')
    for output_dir in (cam_la_path, cam_ha_path):
        if os.path.exists(output_dir):
            shutil.rmtree(output_dir)
        os.makedirs(output_dir)

    wandb.init(entity='kobus_wits',
               project='wass_measure_cams_crfs',
               name=config.sweep_id + '_cam_' + config.classifier_name,
               config=config_json)
    count = 0

    from training.save_cams_crf import _process_sample

    # One pool for the whole run rather than spawning 8 workers per batch.
    with Pool(8) as pool:
        for batch in dataloader:
            labels = batch[1]
            datapacket = batch[2]
            channels = labels['classification'].shape[1]

            payloads = []
            for image_no, image_name in enumerate(datapacket['image_name']):
                payloads.append({
                    'image_name': image_name,
                    'count': count,
                    'image_width': datapacket['width'][image_no].numpy(),
                    'image_height': datapacket['height'][image_no].numpy(),
                    'channels': channels,
                    'image_path': datapacket['image_path'][image_no],
                    'cam_la_path': cam_la_path,
                    'cam_ha_path': cam_ha_path,
                    'cam_root_path': cam_root_path,
                    'alpha_low': config.cams_bg_alpha_low,
                    'alpha_high': config.cams_bg_alpha_high,
                })
                count += 1
                print('Save cam : ', count, end='\r')

            # pool.map preserves input order, so steps stay increasing.
            for log in pool.map(_process_sample, payloads):
                wandb.log(log, step=log['image_count'])

    print('')
    wandb.finish()
def measure_random_walk(config: Config):
    """Score the random-walk label images against ground truth and log to wandb.

    For every image of ``config.eval_dataset_root`` (split ``train``),
    ``_measure_sample`` is called with the ground-truth ``label_path`` and the
    prediction ``<artifact dir>/labels_rw/<image_name>.png``. It returns a dict
    that includes ``count``, ``accuracy``, ``mapr`` and ``miou``.

    Logging to wandb (project ``wass_measure_cams_rw``) works as follows:

    * per-image dicts are logged for the first 8 images only;
    * the running averages of the three metrics are logged after every batch.
    """
    config_json = config.toDictionary()
    print('measure_cams')
    print(config_json)
    import os

    from torch.utils.data.dataloader import DataLoader
    from data.loader_segmentation import Segmentation
    from artifacts.artifact_manager import artifact_manager

    # Set up data loader
    dataloader = DataLoader(Segmentation(
        config.eval_dataset_root,
        source='train',
        augmentation='val',
        image_size=config.classifier_image_size,
        requested_labels=['classification', 'segmentation']),
                            batch_size=config.affinity_net_batch_size,
                            shuffle=False,
                            num_workers=2,
                            prefetch_factor=2)

    # Directory holding the random-walk label images to score
    labels_rw_root_path = os.path.join(artifact_manager.getDir(), 'labels_rw')

    count = 0

    wandb.init(entity='kobus_wits',
               project='wass_measure_cams_rw',
               name=config.sweep_id + '_cam_' + config.classifier_name,
               config=config_json)
    avg_meter = AverageMeter('accuracy', 'mapr', 'miou')

    for batch in dataloader:
        datapacket_in = batch[2]

        logs = []
        for image_no, image_name in enumerate(datapacket_in['image_name']):
            logs.append(_measure_sample({
                'count': count,
                'label_path': datapacket_in['label_path'][image_no],
                'predi_path': os.path.join(labels_rw_root_path, image_name + '.png'),
            }))
            count += 1
            print('Measure cam RW : ', count, end='\r')

        for log in logs:
            avg_meter.add({
                'accuracy': log['accuracy'],
                'mapr': log['mapr'],
                'miou': log['miou'],
            })

            # Only a handful of per-image samples are logged individually
            if log['count'] < 8:
                wandb.log(log, step=log['count'])

        wandb.log({
            'accuracy': avg_meter.get('accuracy'),
            'mapr': avg_meter.get('mapr'),
            'miou': avg_meter.get('miou'),
        })

    wandb.finish()
def save_cams_random_walk(config: Config):
    """Refine saved CAMs by an affinity-driven random walk and save label images.

    For every ``train`` image (the loader uses batch size 1):

    1. The grayscale PNG ``<artifact dir>/cam/<image_name>.png`` is reshaped to
       (channels, H, W) and scaled to [0, 1]. A background channel
       ``(1 - max over channels) ** config.affinity_net_bg_alpha`` is
       prepended.
    2. Image and CAM are zero-padded on the right/bottom to multiples of 8.
    3. The affinity net's dense affinity matrix (``infer_aff_net_dense``) is
       raised to ``config.affinity_net_beta``. It is then column-normalised
       into a transition matrix, which is squared
       ``config.affinity_net_log_t`` times.
    4. The 8x average-pooled CAM is multiplied by the transition matrix,
       bilinearly upsampled, cropped to the original size and converted with
       ``label_to_image``.
    5. The result (x255) is written to
       ``<artifact dir>/labels_rw/<image_name>.png``. The directory is cleared
       first.

    Requires a CUDA device.

    Raises:
        FileNotFoundError: if a CAM image cannot be read.
    """
    config_json = config.toDictionary()
    print('save_cams_random_walk')
    print(config_json)
    import shutil
    import os
    import cv2
    import numpy as np
    import torch
    import torch.nn.functional as F
    from torch.utils.data.dataloader import DataLoader
    from data.loader_segmentation import Segmentation
    from artifacts.artifact_manager import artifact_manager

    # Set up model
    model = get_model(config.affinity_net_name)
    model.load()
    model.eval()
    model.to(model.device)

    # Set up data loader
    dataloader = DataLoader(Segmentation(
        config.classifier_dataset_root,
        source='train',
        augmentation='affinity_predict',
        image_size=config.affinity_net_image_size,
        requested_labels=['classification', 'segmentation']),
                            batch_size=1,
                            shuffle=False,
                            num_workers=2,
                            prefetch_factor=2)

    # Get cam source directory
    cam_path = os.path.join(artifact_manager.getDir(), 'cam')

    # Clear and create output directory
    labels_rw_path = os.path.join(artifact_manager.getDir(), 'labels_rw')
    if os.path.exists(labels_rw_path):
        shutil.rmtree(labels_rw_path)
    os.makedirs(labels_rw_path)

    count = 0

    for batch in dataloader:
        inputs = batch[0]
        labels = batch[1]
        datapacket = batch[2]

        # Invariant across the images of a batch
        image = inputs['image'].cuda(non_blocking=True)
        channels = labels['classification'].shape[1]

        for image_no, image_name in enumerate(datapacket['image_name']):
            # Plain ints so padding sizes, reshapes and slices are exact
            image_width = int(datapacket['width'][image_no])
            image_height = int(datapacket['height'][image_no])

            # Pad image to multiples of the 8x pooling stride
            image_width_padded = int(np.ceil(image_width / 8) * 8)
            image_height_padded = int(np.ceil(image_height / 8) * 8)
            image_padded = F.pad(image,
                                 (0, image_width_padded - image_width, 0,
                                  image_height_padded - image_height))

            image_width_pooled = image_width_padded // 8
            image_height_pooled = image_height_padded // 8

            # Load cam (stored as a single grayscale image of all channels)
            cam_file = os.path.join(cam_path, image_name + '.png')
            cam = cv2.imread(cam_file, cv2.IMREAD_GRAYSCALE)
            if cam is None:
                raise FileNotFoundError(f'Could not read CAM image: {cam_file}')
            cam = np.reshape(cam, (channels, image_height, image_width)) / 255.0

            # Prepend background channel
            cam_background = (
                1 - np.max(cam, 0, keepdims=True))**config.affinity_net_bg_alpha
            cam = np.concatenate((cam_background, cam), axis=0).astype(np.float32)
            # Number of CAM rows (foreground channels + background); previously
            # hard-coded as 21.
            num_maps = cam.shape[0]

            # Pad cam to the same padded size as the image
            cam_padded = np.pad(cam,
                                ((0, 0), (0, image_height_padded - image_height),
                                 (0, image_width_padded - image_width)),
                                mode='constant')

            with torch.no_grad():
                # Affinity matrix -> column-normalised transition matrix,
                # raised to the power 2**affinity_net_log_t by repeated squaring
                aff_mat = model.event({
                    'name': 'infer_aff_net_dense',
                    'image': image_padded,
                })
                aff_mat = torch.pow(aff_mat, config.affinity_net_beta)

                trans_mat = aff_mat / torch.sum(aff_mat, dim=0, keepdim=True)
                for _ in range(config.affinity_net_log_t):
                    trans_mat = torch.matmul(trans_mat, trans_mat)

                # Propagate the pooled CAM through the transition matrix
                cam_pooled = F.avg_pool2d(torch.from_numpy(cam_padded), 8, 8)
                cam_vec = cam_pooled.view(num_maps, -1)

                cam_rw = torch.matmul(cam_vec.cuda(), trans_mat)
                cam_rw = cam_rw.view(1, num_maps, image_height_pooled,
                                     image_width_pooled)

                cam_rw = F.interpolate(
                    cam_rw,
                    size=(image_height_padded, image_width_padded),
                    mode='bilinear',
                    align_corners=False)
                cam_rw = cam_rw.cpu()[0, :, :image_height, :image_width]

            label_rw = label_to_image(cam_rw)

            cv2.imwrite(os.path.join(labels_rw_path, image_name + '.png'),
                        label_rw * 255)

            count += 1
            print('Save cam : ', count, end='\r')

    print('')
def visualize(model, dataloader, output_dir, max_count=1_000_000):
    """Write model visualisations for ``dataloader`` via ``visualize_model``.

    The model is switched to training mode and moved to the module-level
    ``device`` before visualising. The output folder is
    ``artifact_manager.getDir() + output_dir``, built by plain string
    concatenation; presumably ``output_dir`` starts with a separator
    (TODO confirm). ``max_count`` is forwarded unchanged.
    """
    model.train()
    model.to(device)
    visualize_model(model, dataloader, artifact_manager.getDir() + output_dir, max_count)
    def event(self, event):
        """Handle a training event by alternating classifier and adversary updates.

        Every event is first forwarded to ``super().event``. Only
        ``'minibatch'`` events in the ``'train'`` phase are handled further:

        * the adversary is run on the batch;
        * an IoU score is computed between the label and a one-hot of the
          adversary segmentation's per-pixel maximum, excluding channel 0;
        * one optimiser step is taken for the network selected by
          ``self.step``, which switches after every 3 calls;
        * losses are logged to wandb and both optimisers' gradients are
          cleared;
        * debug views are shown with ``cv2.imshow``.

        Every 10th adversary step, the adversary's visual outputs on
        ``self.demo_inputs`` / ``self.demo_labels`` are written as PNGs to the
        artifact directory.
        """
        super().event(event)

        if event['name'] == 'minibatch' and event['phase'] == 'train':
            image = event['inputs']['image']
            label_c = event['labels']['classification']
            label_s = event['labels']['segmentation']

            # Run input through adversary
            adversary_result = self.adversary(image, label_c)

            # IoU of a one-hot per-pixel argmax (ties all set) against the
            # label, excluding channel 0 (presumably background)
            iou_label = label_s.clone().detach().cpu().numpy()
            iou_predi = adversary_result['segmentation'].clone().detach().cpu(
            ).numpy()
            max_indices = iou_predi.max(axis=1, keepdims=True) == iou_predi

            iou_predi = np.zeros(iou_predi.shape)
            iou_predi[max_indices] = 1
            score_iou = iou(iou_label[:, 1:], iou_predi[:, 1:])

            # Training controller: each network trains for 3 consecutive calls
            self.step_count += 1
            if self.step == 'classifier':
                self.step_count_c += 1
                if self.step_count == 3:
                    self.step = 'adversary'
                    self.step_count = 0

                # Train classifier on the clean or the adversary-erased image
                if random.random() > 0.5:
                    classification = self.classifier(image)
                else:
                    classification = self.classifier(adversary_result['erase'])
                loss_c_bce = self.classifier.loss_bce(classification, label_c)
                loss_c_bce.backward()
                self.classifier.optimizer.step()

                wandb.log({
                    "step_count_c": self.step_count_c,
                    "loss_c_bce": loss_c_bce,
                    "score_iou": score_iou
                })

            elif self.step == 'adversary':
                self.step_count_a += 1
                if self.step_count == 3:
                    self.step = 'classifier'
                    self.step_count = 0

                # Channel loss
                loss_a_channel = self.adversary.loss_bce(
                    adversary_result['classification'], label_c)

                # Constrain loss
                loss_a_mask = torch.mean(adversary_result['mask'])

                # Classifier loss
                c_spot = self.classifier(adversary_result['spot'])
                c_erase = self.classifier(adversary_result['erase'])
                loss_a_spot = self.classifier.loss_bce(c_spot, label_c)
                loss_a_erase = torch.mean(c_erase[label_c > 0.5])

                # Get adversary final loss (loss_a_spot is only logged, not
                # optimised — NOTE(review): confirm this is intentional)
                loss_a_final = loss_a_erase + loss_a_mask + loss_a_channel

                loss_a_final.backward()
                self.adversary.optimizer.step()

                wandb.log({
                    "step_count_a": self.step_count_a,
                    "loss_a_final": loss_a_final,
                    "loss_a_mask": loss_a_mask,
                    "loss_a_erase": loss_a_erase,
                    "loss_a_spot": loss_a_spot,
                    "loss_a_channel": loss_a_channel,
                    "score_iou": score_iou
                })

                # Visualize adversary progress
                if self.step_count_a % 10 == 0:
                    image = self.demo_inputs['image'].clone()
                    label = self.demo_labels['classification'].clone()
                    image = move_to(image, self.device)
                    label = move_to(label, self.device)

                    adversary_result = self.adversary(image, label)

                    for typez in [
                            'vis_output', 'vis_mask', 'vis_erase', 'vis_spot'
                    ]:
                        output = adversary_result[typez]
                        for i, o in enumerate(output):
                            cv2.imwrite(
                                artifact_manager.getDir() +
                                f'/{typez}_{i}_{self.step_count_vis}.png',
                                o * 255)

                    self.step_count_vis += 1

            # Clear gradients of both networks; a backward pass through one
            # network can also accumulate gradients in the other.
            self.classifier.optimizer.zero_grad()
            self.adversary.optimizer.zero_grad()

            cv2.imshow(
                'image',
                np.moveaxis(image[0].clone().detach().cpu().numpy(), 0, -1))
            cv2.imshow(
                'label',
                label_to_image(label_s[0].clone().detach().cpu().numpy()))
            cv2.imshow('output', adversary_result['vis_output'][0])
            cv2.imshow('mask', adversary_result['vis_mask'][0])
            cv2.imshow('erase', adversary_result['vis_erase'][0])
            cv2.imshow('spot', adversary_result['vis_spot'][0])

            cv2.waitKey(1)
Esempio n. 8
0
 def save(self, tag=""):
     """Save this module's ``state_dict`` to the artifact directory.

     The target file is ``artifact_manager.getDir() + self.name + "_weights"
     + tag + ".pt"``, built by plain string concatenation. The path is printed
     after saving.
     """
     weight_file = "".join((artifact_manager.getDir(), self.name, "_weights", tag, ".pt"))
     torch.save(self.state_dict(), weight_file)
     print('Saved Model: ', weight_file)
Esempio n. 9
0
 def load(self, tag=""):
     """Load weights from ``artifact_manager.getDir() + self.name + "_weights" + tag + ".pt"``.

     Tensors are mapped to CPU while reading. ``load_state_dict`` then copies
     them into this module's existing parameters on whatever device those
     live on. Checkpoints saved from a GPU can therefore also be loaded on a
     CPU-only machine.
     """
     weight_path = artifact_manager.getDir() + self.name + "_weights" + tag + ".pt"
     self.load_state_dict(torch.load(weight_path, map_location='cpu'))
Esempio n. 10
0
def save_labels(
    dataset_root,
    model_name,
    batch_size=8,
    image_size=256,
    use_gt_labels=False,
):
    """Save per-class CAMs for the ``trainval`` split as stacked grayscale PNGs.

    CAMs come from ``model.event({'name': 'get_cam', ...})``. Each image's
    CAM is then processed in these steps:

    * classes whose image-level label is <= 0.5 are zeroed;
    * the CAM is resized to ``image_size``, and the top-left region matching
      the original aspect ratio is cropped and resized to the original size;
    * each class map is divided by its maximum;
    * the maps are stacked vertically into a (C*H, W) image clipped to [0, 1].

    The result is written (x255) to ``<artifact dir>/cam/<image_name>.png``.
    The directory is cleared first.

    Args:
        dataset_root: Root directory passed to ``Segmentation``.
        model_name: Model name resolved by ``get_model``; weights are read with
            ``model.load()``.
        batch_size: Dataloader batch size.
        image_size: Square input size used by the loader and for the first resize.
        use_gt_labels: If True, replace each CAM with the ground-truth
            segmentation (channel 0 dropped), average-pooled to 32x32 and
            blurred twice with a 3x3 box filter. This tests the rest of the
            system.
    """
    print('Save cams : ', locals())
    import shutil
    import cv2
    import os
    import numpy as np
    import torch
    import torch.nn.functional as F
    from torch.utils.data.dataloader import DataLoader
    from data.loader_segmentation import Segmentation
    from artifacts.artifact_manager import artifact_manager

    def _resize_chw(array, width, height):
        """Bilinearly resize a (C, H, W) array to (C, height, width)."""
        resized = cv2.resize(np.moveaxis(array, 0, -1), (int(width), int(height)),
                             interpolation=cv2.INTER_LINEAR)
        # cv2.resize drops a trailing channel axis of size 1; restore it so
        # single-channel maps keep their (C, H, W) layout.
        if resized.ndim == 2:
            resized = resized[:, :, np.newaxis]
        return np.moveaxis(resized, -1, 0)

    # Set up model in inference mode (consistent with save_cams / save_semseg)
    model = get_model(model_name)
    model.load()
    model.eval()
    model.to(model.device)

    # Set up data loader
    dataloader = DataLoader(
        Segmentation(
            dataset_root,
            source='trainval',
            source_augmentation='val',
            image_size=image_size,
            requested_labels=['classification', 'segmentation']
        ),
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
    )

    # Clear and create destination directory
    cam_path = os.path.join(artifact_manager.getDir(), 'cam')
    if os.path.exists(cam_path):
        shutil.rmtree(cam_path)
    os.makedirs(cam_path)

    for batch_no, batch in enumerate(dataloader):
        inputs_in = batch[0]
        labels_in = batch[1]
        datapacket_in = batch[2]

        # Run images through model and get raw cams
        with torch.no_grad():
            cams = model.event({
                'name': 'get_cam',
                'inputs': inputs_in,
                'labels': labels_in,
                'batch': batch_no + 1
            })

        for cam_no, cam in enumerate(cams):
            # Save out ground truth labels for testing the rest of the system
            if use_gt_labels:
                cam = labels_in['segmentation'][cam_no][1:]
                cam = F.adaptive_avg_pool2d(cam, [32, 32]).numpy()

                for i in range(cam.shape[0]):
                    cam[i] = cv2.blur(cam[i], (3, 3))
                    cam[i] = cv2.blur(cam[i], (3, 3))

            # Disregard false positives. Build the 0/1 mask without writing
            # into the batch tensor (.numpy() shares its memory).
            classification = labels_in['classification'][cam_no].numpy()
            gt_mask = (classification > 0.5).astype(classification.dtype)
            cam *= gt_mask[:, np.newaxis, np.newaxis]

            # Original image size as plain ints (cv2 dsize / slice bounds)
            width = int(datapacket_in['width'][cam_no])
            height = int(datapacket_in['height'][cam_no])
            aspect_ratio = width / height

            # Size of the region to cut from the upscaled CAM
            if aspect_ratio > 1:
                cut_width = image_size
                cut_height = int(round(image_size / aspect_ratio))
            else:
                cut_width = int(round(image_size * aspect_ratio))
                cut_height = image_size

            # Upscale to input size, cut the image region, upscale to original size
            cam = _resize_chw(cam, image_size, image_size)
            cam = cam[:, :cut_height, :cut_width]
            cam = _resize_chw(cam, width, height)

            # Normalize each cam map to between 0 and 1
            cam_max = np.max(cam, (1, 2), keepdims=True)
            cam_norm = cam / (cam_max + 1e-5)

            # Collapse cam from 3d into long 2d (maps stacked vertically)
            cam_norm = np.reshape(
                cam_norm,
                (cam_norm.shape[0] * cam_norm.shape[1], cam_norm.shape[2]))
            cam_norm = np.clip(cam_norm, 0, 1)

            # Write image
            img_no = datapacket_in['image_name'][cam_no]
            cv2.imwrite(os.path.join(cam_path, img_no) + '.png', cam_norm * 255)
            print('Save cam : ', img_no, end='\r')
    print('')
Esempio n. 11
0
def save_cams(config: Config):
    """Save normalised CAMs and CAM label images for the ``train`` split.

    CAMs come from ``model.event({'name': 'get_cam', ...})`` of classifier
    ``config.classifier_name``. If ``config.cams_save_gt_labels`` is set, they
    instead come from the ground-truth segmentation with channel 0 dropped,
    pooled to 32x32 and blurred. Each image's CAM is then processed in these
    steps:

    * classes whose image-level label is <= 0.5 are zeroed;
    * the CAM is resized to the input size, cropped to the loader's
      ``content_width`` x ``content_height`` region and resized to the
      original image size;
    * each class map is divided by its maximum.

    Two PNGs (x255) are written per image:

    * ``<artifact dir>/cam/<name>.png``: the maps stacked vertically into a
      (C*H, W) image clipped to [0, 1];
    * ``<artifact dir>/labels_cam/<name>.png``: ``label_to_image`` of the maps
      with a prepended background channel ``(1 - max) ** config.cams_bg_alpha``,
      clipped to [0, 1].

    Both directories are cleared first.
    """
    config_json = config.toDictionary()
    print('save_cams')
    print(config_json)
    import shutil
    import cv2
    import os
    import numpy as np
    import torch
    import torch.nn.functional as F
    from torch.utils.data.dataloader import DataLoader
    from data.loader_segmentation import Segmentation
    from artifacts.artifact_manager import artifact_manager

    def _resize_chw(array, width, height):
        """Bilinearly resize a (C, H, W) array to (C, height, width)."""
        resized = cv2.resize(np.moveaxis(array, 0, -1), (int(width), int(height)),
                             interpolation=cv2.INTER_LINEAR)
        # cv2.resize drops a trailing channel axis of size 1; restore it so
        # single-channel maps keep their (C, H, W) layout.
        if resized.ndim == 2:
            resized = resized[:, :, np.newaxis]
        return np.moveaxis(resized, -1, 0)

    def _recreate_dir(path):
        """Delete ``path`` if it exists, then create it empty."""
        if os.path.exists(path):
            shutil.rmtree(path)
        os.makedirs(path)

    # Set up model
    model = get_model(config.classifier_name)
    model.load()
    model.eval()
    model.to(model.device)

    # Set up data loader
    dataloader = DataLoader(Segmentation(
        config.classifier_dataset_root,
        source='train',
        augmentation='val',
        image_size=config.classifier_image_size,
        requested_labels=['classification', 'segmentation']),
                            batch_size=config.cams_produce_batch_size,
                            shuffle=False,
                            num_workers=4,
                            prefetch_factor=4)

    # Clear and create destination directories
    cam_path = os.path.join(artifact_manager.getDir(), 'cam')
    _recreate_dir(cam_path)

    label_cam_path = os.path.join(artifact_manager.getDir(), 'labels_cam')
    _recreate_dir(label_cam_path)

    input_size = config.classifier_image_size

    for batch_no, batch in enumerate(dataloader):
        inputs_in = batch[0]
        labels_in = batch[1]
        datapacket_in = batch[2]

        # Run images through model and get raw cams
        with torch.no_grad():
            cams = model.event({
                'name': 'get_cam',
                'inputs': inputs_in,
                'labels': labels_in,
                'batch': batch_no + 1
            })

        for cam_no, cam in enumerate(cams):
            # Save out ground truth labels for testing the rest of the system
            if config.cams_save_gt_labels:
                cam = labels_in['segmentation'][cam_no][1:]
                cam = F.adaptive_avg_pool2d(cam, [32, 32]).numpy()

                for i in range(cam.shape[0]):
                    cam[i] = cv2.blur(cam[i], (3, 3))
                    cam[i] = cv2.blur(cam[i], (3, 3))

            # Disregard false positives. Build the 0/1 mask without writing
            # into the batch tensor (.numpy() shares its memory).
            classification = labels_in['classification'][cam_no].numpy()
            gt_mask = (classification > 0.5).astype(classification.dtype)
            cam *= gt_mask[:, np.newaxis, np.newaxis]

            # Scale CAM to match input size
            cam = _resize_chw(cam, input_size, input_size)

            # Cut the content region and upscale to original image size
            # (plain ints for cv2 dsize / slice bounds)
            width = int(datapacket_in['width'][cam_no])
            height = int(datapacket_in['height'][cam_no])
            content_width = int(datapacket_in['content_width'][cam_no])
            content_height = int(datapacket_in['content_height'][cam_no])
            cam = cam[:, :content_height, :content_width]
            cam = _resize_chw(cam, width, height)

            # Normalize each cam map to between 0 and 1
            cam_max = np.max(cam, (1, 2), keepdims=True)
            cam_norm = cam / (cam_max + 1e-5)

            # Label image with a background channel derived from the maps
            cam_bg = (
                1 -
                np.max(cam_norm, axis=0, keepdims=True))**config.cams_bg_alpha
            cam_with_bg = np.concatenate((cam_bg, cam_norm), axis=0)
            label_cam = np.clip(label_to_image(cam_with_bg), 0, 1)

            # Collapse cam from 3d into long 2d (maps stacked vertically)
            cam_norm = np.reshape(
                cam_norm,
                (cam_norm.shape[0] * cam_norm.shape[1], cam_norm.shape[2]))
            cam_norm = np.clip(cam_norm, 0, 1)

            # Write images
            img_no = datapacket_in['image_name'][cam_no]
            cv2.imwrite(
                os.path.join(cam_path, img_no) + '.png', cam_norm * 255)
            cv2.imwrite(
                os.path.join(label_cam_path, img_no) + '.png', label_cam * 255)
            print('Save cam : ', img_no, end='\r')
    print('')