def forward(self, image, classification_label):
        """Segment ``image`` and derive mask, spot and erase products.

        ``self.segment`` is called with both arguments and must return a
        mapping holding a ``"segmentation"`` entry. The element-wise maximum
        over dim 1 of that entry, skipping index 0, is used as a mask; the
        image multiplied by the mask is stored as ``'spot'`` and the image
        multiplied by ``1 - mask`` as ``'erase'``.

        NumPy copies for visualisation are stored under ``'vis_output'``,
        ``'vis_mask'``, ``'vis_erase'`` and ``'vis_spot'``. The
        ``'vis_output'`` buffer is allocated with a fixed 256x256x3 size per
        sample.

        Returns:
            The mapping returned by ``self.segment`` with the extra entries.
        """

        def channels_last(tensor):
            # Detached CPU copy with axis 1 moved to the last position.
            return np.moveaxis(tensor.clone().detach().cpu().numpy(), 1, -1)

        result = self.segment(image, classification_label)
        segmentation = result["segmentation"]

        # Mask: strongest response over every channel except index 0.
        mask = torch.max(segmentation[:, 1:], dim=1, keepdim=True)[0]

        # Keep only the masked area (spot) / everything but it (erase).
        spot = image * mask
        erase = image * (1 - mask)

        result['spot'] = spot
        result['erase'] = erase
        result['mask'] = mask

        # Colour-coded label image for each sample in the batch.
        segmentation_np = segmentation.clone().detach().cpu().numpy()
        label_images = np.zeros((segmentation_np.shape[0], 256, 256, 3))
        for index, sample in enumerate(segmentation_np):
            label_images[index] = label_to_image(sample)

        result['vis_output'] = label_images
        result['vis_mask'] = channels_last(mask)
        result['vis_erase'] = channels_last(erase)
        result['vis_spot'] = channels_last(spot)

        return result
    def event(self, event):
        """Handle an event emitted by the training loop.

        ``'minibatch'`` events in the ``'train'`` phase run one optimiser
        step on the batch, display the first sample's image, label and
        prediction in OpenCV windows and log IoU metrics and the loss to
        wandb. ``'epoch_end'`` events print a newline and call
        ``self.save()``. Other events are ignored.
        """
        if event['name'] == 'minibatch' and event['phase'] == 'train':
            batch_inputs = event['inputs']
            batch_labels = event['labels']

            image_cu = batch_inputs['image'].cuda(non_blocking=True)
            label_cu = batch_labels['segmentation'].cuda(non_blocking=True)
            # The loss receives class indices taken along dim 1.
            label_cu = torch.argmax(label_cu, 1).long()

            segmentation_result = self.forward(image_cu)
            loss = self.loss_cce(segmentation_result, label_cu)

            # One optimisation step.
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Show the first sample of the batch: input, label, prediction.
            image = batch_inputs['image'].detach().numpy()
            cv2.imshow('image', np.moveaxis(image[0], 0, -1))

            label = batch_labels['segmentation'].detach().numpy()
            cv2.imshow('label', label_to_image(label[0]))

            prediction = torch.softmax(segmentation_result.detach(), 1)
            prediction = prediction.cpu().numpy()
            cv2.imshow('prediction', label_to_image(prediction[0]))

            cv2.waitKey(1)

            wandb.log({
                "class_iou_mean": class_iou(prediction, label).mean(),
                "iou_result": iou(prediction, label),
                "loss": loss.detach().cpu().numpy(),
            })

        if event['name'] == 'epoch_end':
            print('')
            self.save()
Example #3
0
def _measure_sample(payload):
    """Score one predicted label image against its ground truth.

    ``payload`` is a mapping with the keys ``'count'``, ``'image_path'``,
    ``'cam_path'``, ``'label_path'`` and ``'predi_path'``. The label and
    prediction images are read with OpenCV and converted with
    ``image_to_label``.

    Returns:
        A dict with ``'accuracy'``, ``'mapr'``, ``'miou'`` and ``'count'``.
        When ``count`` is below 8 the dict also carries wandb images keyed
        ``raw_<count>``, ``img_<count>`` and ``pred_<count>``, and
        ``'count'`` is set to 0.
    """
    count = payload['count']
    image_path = payload['image_path']
    cam_path = payload['cam_path']
    label_path = payload['label_path']
    predi_path = payload['predi_path']

    label_image = cv2.imread(label_path)
    predi_image = cv2.imread(predi_path)

    label = image_to_label(label_image)
    predi = image_to_label(predi_image)

    # Flattened per-pixel class indices, shared by accuracy and IoU.
    label_classes = np.argmax(label, 0).flatten()
    predi_classes = np.argmax(predi, 0).flatten()

    log = {
        'accuracy': metrics.accuracy_score(label_classes, predi_classes),
        # Average precision skips index 0 of the first axis.
        'mapr': metrics.average_precision_score(label[1:].flatten(),
                                                predi[1:].flatten()),
        'miou': metrics.jaccard_score(label_classes,
                                      predi_classes,
                                      average='macro'),
        'count': count,
    }

    if count < 8:
        # Attach visual samples for the first few items only.
        raw = cv2.imread(cam_path)
        image = cv2.cvtColor(
            cv2.imread(image_path).astype(np.float32), cv2.COLOR_BGR2RGB)
        predi_vis = cv2.cvtColor(
            label_to_image(predi).astype(np.float32), cv2.COLOR_BGR2RGB)

        log[f'raw_{count}'] = wandb.Image(raw)
        log[f'img_{count}'] = wandb.Image(image)
        log[f'pred_{count}'] = wandb.Image(predi_vis)
        log['count'] = 0

    return log
def save_semseg(
    dataset_root,
    model_name,
    batch_size=8,
    image_size=256,
    use_gt_labels=False,
):
    """Run a model over the 'val' split and write label images to disk.

    The model named ``model_name`` is loaded and queried through its
    ``event`` method with a ``'get_semseg'`` event for every batch. Each
    output is resized to ``image_size``, cropped to the area implied by the
    original aspect ratio, resized to the original width/height, converted
    with ``label_to_image`` and written as ``<image_name>.png`` into the
    ``semseg_output`` folder of the artifact directory. That folder is
    deleted and recreated on every call.

    Args:
        dataset_root: Passed to the ``Segmentation`` dataset.
        model_name: Passed to ``get_model``.
        batch_size: Batch size of the data loader.
        image_size: Side length the dataset is loaded at and the model
            output is first resized to.
        use_gt_labels: When true, the model output is replaced by the
            pooled and blurred ground-truth segmentation (index 0 of the
            first axis dropped), for testing downstream stages.
    """
    print('Save semseg : ', locals())
    import shutil
    import cv2
    import os
    import numpy as np
    import torch.nn.functional as F
    import torch
    from models.get_model import get_model
    from torch.utils.data.dataloader import DataLoader
    from data.loader_segmentation import Segmentation
    from artifacts.artifact_manager import artifact_manager
    from data.voc2012 import label_to_image

    # Set up model
    model = get_model(model_name)
    model.load()
    model.to(model.device)
    model.train(False)

    # Set up data loader
    dataloader = DataLoader(
        Segmentation(dataset_root,
                     source='val',
                     source_augmentation='val',
                     image_size=image_size,
                     requested_labels=['classification', 'segmentation']),
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
        prefetch_factor=4,
    )

    # Clear and create destination directory
    semseg_path = os.path.join(artifact_manager.getDir(), 'semseg_output')
    if os.path.exists(semseg_path):
        shutil.rmtree(semseg_path)
    os.makedirs(semseg_path)

    for batch_no, batch in enumerate(dataloader):
        inputs_in = batch[0]
        labels_in = batch[1]
        datapacket_in = batch[2]

        # Run images through model and get raw outputs
        with torch.no_grad():
            semsegs = model.event({
                'name': 'get_semseg',
                'inputs': inputs_in,
                'labels': labels_in,
                'batch': batch_no + 1
            })

            semsegs = semsegs.detach().cpu().numpy()

        # Save out every sample of the batch
        for semseg_no, semseg in enumerate(semsegs):
            # Save out ground truth labels for testing the rest of the system
            if use_gt_labels:
                semseg = labels_in['segmentation'][semseg_no][1:]
                semseg = F.adaptive_avg_pool2d(semseg, [32, 32]).numpy()

                # Two 3x3 box blurs per channel to soften the pooled labels
                for i in range(0, semseg.shape[0]):
                    semseg[i] = cv2.blur(semseg[i], (3, 3))
                    semseg[i] = cv2.blur(semseg[i], (3, 3))

            # Upsample to original image size
            # - Original dimensions as plain Python ints, so slicing and
            #   cv2.resize's dsize receive ints rather than 0-d arrays.
            width = int(datapacket_in['width'][semseg_no])
            height = int(datapacket_in['height'][semseg_no])
            aspect_ratio = width / height

            # - Calculate width and height to cut from the upscaled output
            if aspect_ratio > 1:
                cut_width = image_size
                cut_height = round(image_size / aspect_ratio)
            else:
                cut_width = round(image_size * aspect_ratio)
                cut_height = image_size

            # - Upscale to match input size (cv2 wants channels last)
            semseg = np.moveaxis(semseg, 0, -1)
            semseg = cv2.resize(semseg, (image_size, image_size),
                                interpolation=cv2.INTER_LINEAR)
            semseg = np.moveaxis(semseg, -1, 0)

            # - Cut content area and upscale to original image size
            semseg = semseg[:, 0:cut_height, 0:cut_width]
            semseg = np.moveaxis(semseg, 0, -1)
            semseg = cv2.resize(semseg, (width, height),
                                interpolation=cv2.INTER_LINEAR)
            semseg = np.moveaxis(semseg, -1, 0)

            semseg_as_label = label_to_image(semseg)

            # Write image
            img_no = datapacket_in['image_name'][semseg_no]
            cv2.imwrite(
                os.path.join(semseg_path, img_no) + '.png',
                semseg_as_label * 255)
            print('Save cam : ', img_no, end='\r')
    print('')
def save_cams_random_walk(config: Config):
    """Refine stored CAMs with an affinity random walk and save label images.

    For every sample of the 'train' split the CAM image
    ``<artifact dir>/cam/<image_name>.png`` is loaded, extended with a
    background map, pooled by a factor of 8 and multiplied with a transition
    matrix. That matrix is built from the affinity returned by the model's
    ``'infer_aff_net_dense'`` event, raised to ``affinity_net_beta``,
    normalised over dim 0 and squared ``affinity_net_log_t`` times. The
    result is upsampled, cropped to the image size, converted with
    ``label_to_image`` and written to ``<artifact dir>/labels_rw``, which is
    deleted and recreated on every call.

    Args:
        config: Configuration supplying ``affinity_net_name``,
            ``classifier_dataset_root``, ``affinity_net_image_size``,
            ``affinity_net_bg_alpha``, ``affinity_net_beta`` and
            ``affinity_net_log_t``.
    """
    config_json = config.toDictionary()
    print('save_cams_random_walk')
    print(config_json)
    import shutil
    import os
    from torch.utils.data.dataloader import DataLoader
    from data.loader_segmentation import Segmentation
    from artifacts.artifact_manager import artifact_manager

    # Set up model
    model = get_model(config.affinity_net_name)
    model.load()
    model.eval()
    model.to(model.device)

    # Set up data loader
    dataloader = DataLoader(Segmentation(
        config.classifier_dataset_root,
        source='train',
        augmentation='affinity_predict',
        image_size=config.affinity_net_image_size,
        requested_labels=['classification', 'segmentation']),
                            batch_size=1,
                            shuffle=False,
                            num_workers=2,
                            prefetch_factor=2)

    # Get cam source directory
    cam_path = os.path.join(artifact_manager.getDir(), 'cam')

    # Clear and create output directory
    labels_rw_path = os.path.join(artifact_manager.getDir(), 'labels_rw')
    if os.path.exists(labels_rw_path):
        shutil.rmtree(labels_rw_path)
    os.makedirs(labels_rw_path)

    count = 0

    for batch_no, batch in enumerate(dataloader):
        inputs = batch[0]
        labels = batch[1]
        datapacket = batch[2]

        for image_no, image_name in enumerate(datapacket['image_name']):
            image = inputs['image'].cuda(non_blocking=True)
            # Plain Python ints for padding, reshaping and slicing.
            image_width = int(datapacket['width'][image_no])
            image_height = int(datapacket['height'][image_no])
            channels = labels['classification'].shape[1]

            # Pad image up to the next multiple of 8
            image_width_padded = int(np.ceil(image_width / 8) * 8)
            image_height_padded = int(np.ceil(image_height / 8) * 8)
            image_padded = F.pad(image,
                                 (0, image_width_padded - image_width, 0,
                                  image_height_padded - image_height))

            image_width_pooled = int(np.ceil(image_width_padded / 8))
            image_height_pooled = int(np.ceil(image_height_padded / 8))

            # Load cam: one grayscale image reshaped into per-class maps
            cam_path_instance = os.path.join(cam_path, image_name + '.png')
            cam = cv2.imread(cam_path_instance, cv2.IMREAD_GRAYSCALE)
            cam = np.reshape(cam, ((channels, image_height, image_width)))
            cam = cam / 255.0

            # Build cam background
            cam_background = (
                1 - np.max(cam,
                           (0), keepdims=True))**config.affinity_net_bg_alpha
            cam = np.concatenate((cam_background, cam), axis=0)
            cam = cam.astype(np.float32)

            # Pad cam
            cam_padded_width = int(np.ceil(cam.shape[2] / 8) * 8)
            cam_padded_height = int(np.ceil(cam.shape[1] / 8) * 8)
            cam_padded = np.pad(cam,
                                ((0, 0), (0, cam_padded_height - image_height),
                                 (0, cam_padded_width - image_width)),
                                mode='constant')

            # Run images through model and get affinity matrix
            with torch.no_grad():
                aff_mat = model.event({
                    'name': 'infer_aff_net_dense',
                    'image': image_padded,
                })
                aff_mat = torch.pow(aff_mat, config.affinity_net_beta)

            # Normalise over dim 0, then square repeatedly
            trans_mat = aff_mat / torch.sum(aff_mat, dim=0, keepdim=True)
            for _ in range(config.affinity_net_log_t):
                trans_mat = torch.matmul(trans_mat, trans_mat)

            cam_pooled = F.avg_pool2d(torch.from_numpy(cam_padded), 8, 8)

            # Number of maps (classes + background) comes from the CAM
            # itself instead of a hard-coded 21.
            num_maps = cam_pooled.shape[0]
            cam_vec = cam_pooled.view(num_maps, -1)

            cam_rw = torch.matmul(cam_vec.cuda(), trans_mat)
            cam_rw = cam_rw.view(1, num_maps, image_height_pooled,
                                 image_width_pooled)

            # Back to padded resolution, then crop the padding away
            cam_rw = torch.nn.Upsample(
                (image_height_padded, image_width_padded),
                mode='bilinear')(cam_rw)
            cam_rw = cam_rw.cpu().data[0, :, :image_height, :image_width]

            label_rw = label_to_image(cam_rw)

            cv2.imwrite(os.path.join(labels_rw_path, image_name + '.png'),
                        label_rw * 255)

            count += 1
            print('Save cam : ', count, end='\r')

    print('')
Example #6
0
    def event(self, event):
        """Alternate classifier and adversary training on minibatch events.

        The event is first forwarded to the parent class. For a
        ``'minibatch'`` event the adversary produces per-class maps
        (``self.adversary.segment`` followed by a sigmoid); their maximum
        over dim 1 is the erase mask applied to the image. ``self.step`` is
        then incremented:

        * even steps train the classifier on either the original or the
          erased image (chosen at random);
        * odd steps train the adversary on a weighted sum of a constrain
          loss (mean map activation), a channel loss (BCE of pooled maps
          against the classification labels) and the classifier's response
          on the erased image, and log the losses to wandb.

        Every 16th step the erased image, the colour-coded maps and the
        normalised mask of the first sample are shown in OpenCV windows.
        """
        super().event(event)

        # if event['name'] == 'get_cam':
        #     image_cu = event['inputs']['image'].cuda(non_blocking=True)
        #     result_cu = self.segment(image_cu)
        #     return result_cu.detach().cpu().numpy()

        if event['name'] == 'minibatch':
            image_cu = event['inputs']['image'].cuda(non_blocking=True)
            label_classification_cu = event['labels']['classification'].cuda(
                non_blocking=True)

            # Run input through adversary
            adversary_result = self.adversary.segment(image_cu,
                                                      label_classification_cu)

            adversary_result = torch.sigmoid(
                adversary_result)  # TODO: try a linear normalization
            mask, _ = torch.max(adversary_result, dim=1, keepdim=True)
            erased = image_cu * (1 - mask)

            # Training controller
            self.step += 1
            if self.step % 2 == 0:
                self.step_count_c += 1

                # Train classifier
                if random.random() > 0.5:
                    classification = self.classifier(image_cu)
                else:
                    classification = self.classifier(erased)

                loss_c_bce = self.classifier.loss_bce(classification,
                                                      label_classification_cu)
                self.classifier.optimizer.zero_grad()
                loss_c_bce.backward()
                self.classifier.optimizer.step()

            else:
                self.step_count_a += 1

                # Train adversary
                # Channel loss
                adversary_classification = F.adaptive_avg_pool2d(
                    adversary_result, [1, 1])
                adversary_classification = torch.flatten(
                    adversary_classification, 1)
                loss_a_channel = self.adversary.loss_bce(
                    adversary_classification, label_classification_cu)

                # Classifier loss: mean classifier output over the entries
                # whose label is above 0.5, evaluated on the erased image
                classification = self.classifier(erased)
                loss_a_classifier = torch.mean(
                    classification[label_classification_cu > 0.5])

                # Constrain loss
                loss_a_constrain = torch.mean(adversary_result)
                loss_a = loss_a_constrain * 0.3 + loss_a_channel * 0.45 + loss_a_classifier * 0.45

                self.adversary.optimizer.zero_grad()
                loss_a.backward()
                self.adversary.optimizer.step()

                wandb.log({
                    'step':
                    self.step,
                    'loss':
                    loss_a.detach().cpu().numpy(),
                    'loss_a_channel':
                    loss_a_channel.detach().cpu().numpy(),
                    'loss_a_classifier':
                    loss_a_classifier.detach().cpu().numpy(),
                    'loss_a_constrain':
                    loss_a_constrain.detach().cpu().numpy(),
                    # 'accuracy': metrics.accuracy_score(label.flatten(), predi.flatten()),
                    # 'f1': metrics.f1_score(label.flatten(), predi.flatten())
                })

            if self.step % 16 == 0:
                # label = label_classification_cu.detach().cpu().numpy()
                # predi = adversary_classification.detach().cpu().numpy().flatten()

                # predi[predi > 0.5] = 1
                # predi[predi <= 0.5] = 0

                erased = erased[0].detach().cpu().numpy()
                erased = np.moveaxis(erased, 0, -1)

                # Prepend a synthetic background map before colour-coding
                predi_vis = adversary_result[0].clone().detach().cpu().numpy()
                predi_vis_bg = np.power(
                    1 - np.max(predi_vis, axis=0, keepdims=True), 4)
                predi_vis = np.concatenate((predi_vis_bg, predi_vis), axis=0)
                cv2.imshow('predi', label_to_image(predi_vis))

                # Min-max normalise the mask for display. The previous
                # expression lacked parentheses, so it subtracted min/max
                # from the mask instead of rescaling it.
                mask_vis = mask[0, 0].detach().cpu().numpy()
                mask_min = np.min(mask_vis)
                mask_vis = (mask_vis - mask_min) / (np.max(mask_vis) -
                                                    mask_min + 1e-5)
                cv2.imshow('erased', erased)
                cv2.imshow('mask_vis', mask_vis)
                cv2.waitKey(1)

            #     # Segmentation loss
            #     segmentation_gen_label = adversary_result['segmentation'].clone().detach()
            #     segmentation_gen_label[:, 1:] *= torch.unsqueeze(torch.unsqueeze(classification_label, -1), -1)
            #     segmentation_gen_label[:, 0] = (1 - adversary_result['mask'][:, 0].detach())
            #     segmentation_gen_label_np = torch.softmax(segmentation_gen_label, dim=1).cpu().numpy()
            #     image_np = image.clone().detach().cpu().numpy()
            #     image_np = np.moveaxis(image_np, 1, -1)
            #     crf = CRF()
            #     result = crf.process(image_np[0], segmentation_gen_label_np[0])
            #     cv2.imshow('newlabel', label_to_image(result))
            #     loss_a_seg = self.adversary.loss_bce(adversary_result['segmentation'], segmentation_gen_label)

            #     # Constrain loss
            #     # loss_a_mask = torch.mean(adversary_result['segmentation']) * 0.01

            #     # Classifier loss
            #     # c_spot = self.classifier(adversary_result['spot'])
            #     # loss_a_spot = self.classifier.loss_bce(c_spot, classification_label)

            #     c_erase = self.classifier(adversary_result['erase'])
            #     loss_a_erase = torch.mean(c_erase[classification_label > 0.5]) * 0.1

            #     # Discrimination loss
            #     # discrimination = self.discriminator(adversary_result['mask'])
            #     # discrimination_label = torch.full((adversary_result['mask'].shape[0], 1), fill_value=0.9, device=self.device)
            #     # loss_a_disc = self.discriminator.loss_bce(discrimination, discrimination_label)

            #     # Get adversary final loss
            #     loss_a_final = loss_a_channel + loss_a_seg + loss_a_erase # + loss_a_mask # + loss_a_seg # + loss_a_disc loss_a_mask

            #     loss_a_final.backward()
            #     self.adversary.optimizer.step()

            #     wandb.log({
            #         "step_count_a": self.step_count_a,
            #         "loss_a_final": loss_a_final,
            #         # "loss_a_mask": loss_a_mask,
            #         # "loss_a_erase": loss_a_erase,
            #         # "loss_a_spot": loss_a_spot,
            #         "loss_a_channel": loss_a_channel,
            #         "score_iou": score_iou,
            #         "score_a_f1": f1(adversary_result['classification'].clone().detach().cpu().numpy(), classification_label.clone().detach().cpu().numpy()),
            #     })

            #     # Visualize adversary progress
            #     if self.step_count_a % 4 == 0:
            #         image = self.demo_inputs['image'].clone()
            #         label = self.dmeo_labels['classification'].clone()
            #         image = move_to(image, self.device)
            #         label = move_to(label, self.device)

            #         adversary_result = self.adversary(image, label)

            #         for typez in ['vis_output', 'vis_mask', 'vis_erase']:
            #             output = adversary_result[typez]
            #             for i, o in enumerate(output):
            #                 cv2.imwrite(artifact_manager.getDir() + f'/{typez}_{i}_{self.step_count_vis}.png', o * 255)

            #         self.step_count_vis += 1

            # self.discriminator.optimizer.zero_grad()

            # for i in range(adversary_result['segmentation'].shape[1]):
            #     cv2.imshow(str(i), adversary_result['segmentation'][0][i].clone().detach().cpu().numpy())

            # cv2.imshow('image', np.moveaxis(image[0].clone().detach().cpu().numpy(), 0, -1))
            # cv2.imshow('label', label_to_image(segmentation_label[0].clone().detach().cpu().numpy()))
            # cv2.imshow('output', adversary_result['vis_output'][0])
            # cv2.imshow('mask', adversary_result['vis_mask'][0])
            # cv2.imshow('erase', adversary_result['vis_erase'][0])

            # cv2.waitKey(1)

        # self.discriminator = Discriminator()
        # self.segmentation_loader = VOCSegmentation('train', dataset='voco')
        # pair = next(iter(DataLoader(VOCSegmentation('val', dataset='voco'), batch_size=32, shuffle=True, num_workers=0)))
        # self.demo_inputs = pair[0]
        # self.dmeo_labels = pair[1]


# class Discriminator(torch.nn.Module):
#     def __init__(self):
#         super().__init__()
#         self.discriminator = torch.nn.Sequential(
#             torch.nn.Conv2d(1, 16, kernel_size=4, stride=2),
#             torch.nn.LeakyReLU(negative_slope=0.11),
#             torch.nn.Conv2d(16, 32, kernel_size=4, stride=2),
#             torch.nn.LeakyReLU(negative_slope=0.11),
#             torch.nn.Conv2d(32, 64, kernel_size=4, stride=2),
#             torch.nn.LeakyReLU(negative_slope=0.11),
#             torch.nn.Conv2d(64, 128, kernel_size=4, stride=2),
#             torch.nn.LeakyReLU(negative_slope=0.11),
#             torch.nn.Conv2d(128, 1, 1, padding=0),
#             torch.nn.Sigmoid(),
#             torch.nn.AdaptiveAvgPool2d(output_size=(1, 1)),
#             torch.nn.Flatten(1, 3),
#         )
#         self.loss_bce = torch.nn.BCELoss()
#         # self.optimizer = torch.optim.SGD(self.parameters(), lr=0.02, momentum=0.7)
#         self.optimizer = torch.optim.Adam(self.parameters(), lr=0.0002)

#     def forward(self, image):
#         return self.discriminator(image)

# def gaus_kernel(shape=(3,3),sigma=10):
#     """
#     2D gaussian mask - should give the same result as MATLAB's
#     fspecial('gaussian',[shape],[sigma])
#     """
#     m,n = [(ss-1.)/2. for ss in shape]
#     y,x = np.ogrid[-m:m+1,-n:n+1]
#     h = np.exp( -(x*x + y*y) / (2.*sigma*sigma) )
#     h[ h < np.finfo(h.dtype).eps*h.max() ] = 0
#     sumh = h.sum()
#     if sumh != 0:
#         h /= sumh
#     return h

# class Blobber(torch.nn.Module):
#     def __init__(self):
#         super().__init__()
#         kernel_size = 27

#         kernel = gaus_kernel(shape=(kernel_size, kernel_size))
#         conv = torch.nn.Conv2d(1, 1, kernel_size, padding=(kernel_size-1)//2)
#         conv.bias.data.fill_(0)
#         conv.weight.data.copy_(torch.tensor(kernel))

#         self.blob_conv = conv

#     def blur(self, input):
#         s = input.clone()
#         for b in range(input.shape[0]):
#             for c in range(input.shape[1]):
#                 s[b, c] = self.blob_conv(input[b:b+1, c:c+1])[0]

#         return s

#     def forward(self, input):
#         s = self.blur(input)
#         return s

# wandb.log({
#     "step_count_c": self.step_count_c,
#     "loss_c_bce": loss_c_bce,
#     "score_iou": score_iou,
#     "score_c_f1": f1(classification.clone().detach().cpu().numpy(), classification_label.clone().detach().cpu().numpy()),
#     # "loss_d_bce": loss_d_bce,
#     # "score_d_f1": f1(discrimination.clone().detach().cpu().numpy(), discrimination_label.clone().detach().cpu().numpy())
# })

# Train discriminator
# discrimination_input = adversary_result['mask'].clone()
# discrimination_label = np.full((adversary_result['mask'].shape[0], 1), 0.1)

# for i in range(0, adversary_result['mask'].shape[0]):
#     if random.random() < 0.5:
#         continue
#     else:
#         _, label_dict, _ = self.segmentation_loader.__getitem__(random.randint(0, self.segmentation_loader.__len__() -1))
#         seg = label_dict['segmentation']
#         seg = np.max(seg[1:], axis=0)
#         discrimination_input[i] = torch.tensor(seg, device=self.device, dtype=torch.float).unsqueeze(0)
#         discrimination_label[i] = 0.9

# discrimination_label = torch.tensor(discrimination_label, device=self.device, dtype=torch.float)

# cv2.imshow('disc_mask', discrimination_input[0, 0].clone().detach().cpu().numpy())
# cv2.waitKey(1)

# discrimination = self.discriminator(discrimination_input)
# loss_d_bce = self.discriminator.loss_bce(discrimination, discrimination_label)
# loss_d_bce.backward()
# self.discriminator.optimizer.step()

# def forward(self, image, classification_label):
#     result = self.segment(image, classification_label)

#     # Generate erase mask
#     segmentation = result["segmentation"]
#     mask, _ = torch.max(segmentation[:, 1:], dim=1, keepdim=True)

#     # Generate spot and erase images
#     spot = image * mask
#     erase = image * (1 - mask)

#     result['spot'] = spot
#     result['erase'] = erase
#     result['mask'] = mask

#     segmentation_np = segmentation.clone().detach().cpu().numpy()
#     label_image_np = np.zeros((segmentation_np.shape[0], 256, 256, 3))

#     for s in range(0, segmentation_np.shape[0]):
#         label_image_np[s] = label_to_image(segmentation_np[s])

#     result['vis_output'] = label_image_np
#     result['vis_mask'] = np.moveaxis(mask.clone().detach().cpu().numpy(), 1, -1)
#     result['vis_erase'] = np.moveaxis(erase.clone().detach().cpu().numpy(), 1, -1)
#     result['vis_spot'] = np.moveaxis(spot.clone().detach().cpu().numpy(), 1, -1)

#     return result
    def event(self, event):
        """Train classifier and adversary in alternating three-step phases.

        The event is first forwarded to the parent class. For a
        ``'minibatch'`` event in the ``'train'`` phase the adversary is run
        on the batch and an IoU score against the segmentation labels is
        computed from its arg-max one-hot output (index 0 of dim 1 is
        excluded). ``self.step`` names the active phase (``'classifier'`` or
        ``'adversary'``); once ``self.step_count`` reaches 3 the phase
        switches and the counter is reset.

        * Classifier phase: BCE on either the original or the erased image,
          chosen at random.
        * Adversary phase: sum of erase loss, mask mean and channel BCE;
          every 10th adversary step the demo batch is rendered to PNG files
          in the artifact directory.

        Gradients of both optimisers are cleared after every minibatch and
        the first sample is shown in OpenCV windows.
        """
        super().event(event)

        if event['name'] == 'minibatch' and event['phase'] == 'train':
            image = event['inputs']['image']
            label_c = event['labels']['classification']
            label_s = event['labels']['segmentation']

            # Run input through adversary
            adversary_result = self.adversary(image, label_c)
            iou_label = label_s.clone().detach().cpu().numpy()
            iou_predi = adversary_result['segmentation'].clone().detach().cpu(
            ).numpy()
            # One-hot encode the arg-max along dim 1 (ties mark every maximum)
            max_indices = iou_predi.max(axis=1, keepdims=True) == iou_predi

            iou_predi = np.zeros(iou_predi.shape)
            iou_predi[max_indices] = 1
            score_iou = iou(iou_label[:, 1:], iou_predi[:, 1:])

            # Training controller
            self.step_count += 1
            if self.step == 'classifier':
                self.step_count_c += 1
                if self.step_count == 3:
                    self.step = 'adversary'
                    self.step_count = 0

                # Train classifier
                if random.random() > 0.5:
                    classification = self.classifier(image)
                else:
                    classification = self.classifier(adversary_result['erase'])
                loss_c_bce = self.classifier.loss_bce(classification, label_c)
                loss_c_bce.backward()
                self.classifier.optimizer.step()

                wandb.log({
                    "step_count_c": self.step_count_c,
                    "loss_c_bce": loss_c_bce,
                    "score_iou": score_iou
                })

            elif self.step == 'adversary':
                self.step_count_a += 1
                if self.step_count == 3:
                    self.step = 'classifier'
                    self.step_count = 0

                # Channel loss
                loss_a_channel = self.adversary.loss_bce(
                    adversary_result['classification'], label_c)

                # Constrain loss
                loss_a_mask = torch.mean(adversary_result['mask'])

                # Classifier loss
                c_spot = self.classifier(adversary_result['spot'])
                c_erase = self.classifier(adversary_result['erase'])
                # loss_a_spot is logged only; it is not part of the final loss
                loss_a_spot = self.classifier.loss_bce(c_spot, label_c)
                loss_a_erase = torch.mean(c_erase[label_c > 0.5])

                # Get adversary final loss
                loss_a_final = loss_a_erase + loss_a_mask + loss_a_channel

                loss_a_final.backward()
                self.adversary.optimizer.step()

                wandb.log({
                    "step_count_a": self.step_count_a,
                    "loss_a_final": loss_a_final,
                    "loss_a_mask": loss_a_mask,
                    # Previously logged loss_a_channel under this key
                    "loss_a_erase": loss_a_erase,
                    "loss_a_spot": loss_a_spot,
                    "loss_a_channel": loss_a_channel,
                    "score_iou": score_iou
                })

                # Visualize adversary progress
                if self.step_count_a % 10 == 0:
                    # NOTE(review): this rebinds `image` and
                    # `adversary_result`, so the imshow calls below show the
                    # demo batch next to the training batch's label. Confirm
                    # whether that is intended.
                    image = self.demo_inputs['image'].clone()
                    label = self.dmeo_labels['classification'].clone()
                    image = move_to(image, self.device)
                    label = move_to(label, self.device)

                    adversary_result = self.adversary(image, label)

                    for typez in [
                            'vis_output', 'vis_mask', 'vis_erase', 'vis_spot'
                    ]:
                        output = adversary_result[typez]
                        for i, o in enumerate(output):
                            cv2.imwrite(
                                artifact_manager.getDir() +
                                f'/{typez}_{i}_{self.step_count_vis}.png',
                                o * 255)

                    self.step_count_vis += 1

            # Clear gradients
            self.classifier.optimizer.zero_grad()
            self.adversary.optimizer.zero_grad()

            cv2.imshow(
                'image',
                np.moveaxis(image[0].clone().detach().cpu().numpy(), 0, -1))
            cv2.imshow(
                'label',
                label_to_image(label_s[0].clone().detach().cpu().numpy()))
            cv2.imshow('output', adversary_result['vis_output'][0])
            cv2.imshow('mask', adversary_result['vis_mask'][0])
            cv2.imshow('erase', adversary_result['vis_erase'][0])
            cv2.imshow('spot', adversary_result['vis_spot'][0])

            cv2.waitKey(1)
    def event(self, event):
        """Serve inference requests and report metrics on training batches.

        * ``'get_semseg'``: returns ``self.forward`` applied to the input
          image moved to the GPU.
        * ``'minibatch'`` in the ``'train'`` phase: computes the loss (the
          optimiser step is commented out below) and, on every even
          ``event['batch']``, shows each sample's content area and logs
          batch-averaged accuracy, average precision and macro IoU for the
          prediction, plus macro IoU for the pseudo labels, to wandb.
        * ``'epoch_end'``: prints a newline and calls ``self.save()``.
        """
        if event['name'] == 'get_semseg':
            image_cu = event['inputs']['image'].cuda(non_blocking=True)
            return self.forward(image_cu)

        if event['name'] == 'minibatch' and event['phase'] == 'train':
            image_cu = event['inputs']['image'].cuda(non_blocking=True)
            label_cu = event['labels']['segmentation'].cuda(non_blocking=True)
            label_cu = torch.argmax(label_cu, 1).long()

            # cv2.imshow('amaxed', label_cu[0].detach().float().cpu().numpy() / 21)

            segmentation_result = self.forward(image_cu)
            loss = self.loss_cce(segmentation_result, label_cu)

            # self.optimizer.zero_grad()
            # loss.backward()
            # self.optimizer.step()

            if event['batch'] % 2 == 0:
                datapacket = event['data']
                image = event['inputs']['image'].detach().numpy()
                label = event['labels']['segmentation'].detach().numpy()
                pseudo = event['labels']['pseudo'].detach().numpy()
                predi = segmentation_result.detach().cpu().numpy()
                batch_size = image.shape[0]

                metric_keys = ('acc', 'mapr', 'miou_macro', 'p_miou_macro')
                log = {'loss': loss.detach().cpu().numpy()}
                for key in metric_keys:
                    log[key] = 0

                for i in range(batch_size):
                    # Restrict every array to the sample's content area.
                    content_width = datapacket['content_width'][i].detach(
                    ).numpy()
                    content_height = datapacket['content_height'][i].detach(
                    ).numpy()

                    content_image = image[i, :, :content_height,
                                          :content_width]
                    content_label = label[i, :, :content_height,
                                          :content_width]
                    content_predi = predi[i, :, :content_height,
                                          :content_width]
                    content_pseudo = pseudo[i, :, :content_height,
                                            :content_width]

                    cv2.imshow('content_pseudo',
                               label_to_image(content_pseudo))
                    cv2.imshow('content_label', label_to_image(content_label))
                    cv2.imshow('content_predi', label_to_image(content_predi))
                    cv2.imshow('content_image',
                               np.moveaxis(content_image, 0, -1))
                    cv2.waitKey(1)

                    # Flattened class-index maps shared by the metrics.
                    label_classes = np.argmax(content_label, 0).flatten()
                    predi_classes = np.argmax(content_predi, 0).flatten()
                    pseudo_classes = np.argmax(content_pseudo, 0).flatten()

                    log['acc'] += metrics.accuracy_score(
                        label_classes, predi_classes)
                    log['mapr'] += metrics.average_precision_score(
                        content_label[1:].flatten(),
                        content_predi[1:].flatten())
                    log['miou_macro'] += metrics.jaccard_score(
                        label_classes, predi_classes, average='macro')
                    log['p_miou_macro'] += metrics.jaccard_score(
                        label_classes, pseudo_classes, average='macro')

                # Turn the per-sample sums into batch means.
                for key in metric_keys:
                    log[key] /= batch_size

                wandb.log(log)

        if event['name'] == 'epoch_end':
            print('')
            self.save()
Example #9
0
def save_cams(config: Config):
    """Run the classifier over the training split and write its CAMs to disk.

    For every image the function asks the model for its class activation
    maps (via the ``'get_cam'`` event), zeroes the maps of classes that are
    not in the image's classification label, rescales the maps back to the
    original image size, normalizes each map by its own maximum and writes
    two PNG files into the artifact directory:

    * ``cam/<image_name>.png``        - all class maps stacked vertically
    * ``labels_cam/<image_name>.png`` - ``label_to_image`` rendering of the
      maps with a derived background channel prepended

    Both destination directories are deleted and recreated on every call.

    Args:
        config: Run configuration. The fields read here are
            ``classifier_name``, ``classifier_dataset_root``,
            ``classifier_image_size``, ``cams_produce_batch_size``,
            ``cams_save_gt_labels`` and ``cams_bg_alpha``.
    """
    config_json = config.toDictionary()
    print('save_cams')
    print(config_json)
    import shutil
    import cv2
    import os
    import numpy as np
    from torch.utils.data.dataloader import DataLoader
    from data.loader_segmentation import Segmentation
    from artifacts.artifact_manager import artifact_manager

    # Set up model
    model = get_model(config.classifier_name)
    model.load()
    model.eval()
    model.to(model.device)

    # Set up data loader
    dataloader = DataLoader(Segmentation(
        config.classifier_dataset_root,
        source='train',
        augmentation='val',
        image_size=config.classifier_image_size,
        requested_labels=['classification', 'segmentation']),
                            batch_size=config.cams_produce_batch_size,
                            shuffle=False,
                            num_workers=4,
                            prefetch_factor=4)

    # Clear and create destination directories
    cam_path = os.path.join(artifact_manager.getDir(), 'cam')
    label_cam_path = os.path.join(artifact_manager.getDir(), 'labels_cam')
    for out_dir in (cam_path, label_cam_path):
        if os.path.exists(out_dir):
            shutil.rmtree(out_dir)
        os.makedirs(out_dir)

    input_size = (config.classifier_image_size, config.classifier_image_size)

    for batch_no, batch in enumerate(dataloader):
        inputs_in = batch[0]
        labels_in = batch[1]
        datapacket_in = batch[2]

        # Run images through model and get raw cams
        with torch.no_grad():
            cams = model.event({
                'name': 'get_cam',
                'inputs': inputs_in,
                'labels': labels_in,
                'batch': batch_no + 1
            })

        # Save out cams
        for cam_no, cam in enumerate(cams):
            # Save out ground truth labels for testing the rest of the system:
            # the model output is replaced by a downsampled, blurred copy of
            # the ground-truth segmentation (background channel dropped).
            if config.cams_save_gt_labels:
                cam = labels_in['segmentation'][cam_no][1:]
                cam = F.adaptive_avg_pool2d(cam, [32, 32]).numpy()

                for i in range(0, cam.shape[0]):
                    cam[i] = cv2.blur(cam[i], (3, 3))
                    cam[i] = cv2.blur(cam[i], (3, 3))

            # Disregard false positives: keep only the maps of classes that
            # are present in the classification label. The mask is built as
            # a new array because ``.numpy()`` shares memory with the batch
            # tensor, and thresholding in place would overwrite the labels.
            class_label = labels_in['classification'][cam_no].numpy()
            gt_mask = (class_label > 0.5).astype(class_label.dtype)
            gt_mask = gt_mask[..., np.newaxis, np.newaxis]
            cam *= gt_mask

            # Scale CAM to match input size (cv2 expects channels last)
            cam = np.moveaxis(cam, 0, -1)
            cam = cv2.resize(cam, input_size, interpolation=cv2.INTER_LINEAR)
            cam = np.moveaxis(cam, -1, 0)

            # - Cut CAM from input size and upscale to original image size.
            # The sizes are converted to plain ints: cv2.resize requires
            # integer scalars in ``dsize`` and newer OpenCV bindings reject
            # 0-d numpy arrays.
            width = int(datapacket_in['width'][cam_no])
            height = int(datapacket_in['height'][cam_no])
            content_width = int(datapacket_in['content_width'][cam_no])
            content_height = int(datapacket_in['content_height'][cam_no])
            cam = cam[:, 0:content_height, 0:content_width]
            cam = np.moveaxis(cam, 0, -1)
            cam = cv2.resize(cam, (width, height),
                             interpolation=cv2.INTER_LINEAR)
            cam = np.moveaxis(cam, -1, 0)

            # Normalize each cam map to between 0 and 1; the epsilon keeps
            # all-zero maps from dividing by zero.
            cam_max = np.max(cam, (1, 2), keepdims=True)
            cam_norm = cam / (cam_max + 1e-5)

            # Background is whatever no class claims, sharpened by alpha
            cam_bg = (
                1 -
                np.max(cam_norm, axis=0, keepdims=True))**config.cams_bg_alpha
            cam_with_bg = np.concatenate((cam_bg, cam_norm), axis=0)
            label_cam = label_to_image(cam_with_bg)

            # Collapse cam from 3d into long 2d (class maps stacked along
            # the vertical axis) and clamp both outputs to [0, 1]
            cam_norm = np.reshape(
                cam_norm,
                (cam_norm.shape[0] * cam_norm.shape[1], cam_norm.shape[2]))
            cam_norm = np.clip(cam_norm, 0, 1)
            label_cam = np.clip(label_cam, 0, 1)

            # Write image
            img_no = datapacket_in['image_name'][cam_no]
            cv2.imwrite(
                os.path.join(cam_path, img_no) + '.png', cam_norm * 255)
            cv2.imwrite(
                os.path.join(label_cam_path, img_no) + '.png', label_cam * 255)
            print('Save cam : ', img_no, end='\r')
    print('')
def generate():
    """Generate the 'voco' dataset from the COCO train2017 annotations.

    Creates ``generated/voco`` from scratch (any existing copy is removed)
    containing:

    * ``train.txt`` / ``val.txt`` - image ids, written category by category.
      Within each category every 10th image (index 0, 10, 20, ...) goes to
      the validation list and the rest to the training list. The training
      list of each category is then padded up to ``class_count_target``
      entries by randomly repeating its own ids.
    * ``images/<img_id>.jpg``  - the source image
    * ``labels/<img_id>.png``  - ``label_to_image`` rendering of a 21-channel
      mask stack (channel 0 left empty, channels 1.. filled from the COCO
      annotations whose category appears in ``bridge.coco_ids``)

    The random generator is seeded so the oversampling is reproducible.

    ``class_count_target`` is not defined in this function; it is expected
    to be available at module level (TODO confirm).

    Raises:
        FileNotFoundError: If a source image cannot be read.
    """
    random.seed(2219677)
    # Set up bridge between coco and voc
    annFile = 'source/coco/annotations/instances_train2017.json'
    coco = COCO(annFile)
    bridge = VOC_COCO_Bridge(coco)

    # Get image id list
    img_ids, cats_ids = bridge.get_images()

    # Prepare destination folder
    dataset_name = 'voco'
    if path.isdir(f'generated/{dataset_name}'):
        shutil.rmtree(f'generated/{dataset_name}')

    os.makedirs(f'generated/{dataset_name}')
    os.makedirs(f'generated/{dataset_name}/images')
    os.makedirs(f'generated/{dataset_name}/labels')

    # Write balanced sampling id lists. The context manager guarantees both
    # files are flushed and closed even if a later step raises.
    with open(f'generated/{dataset_name}/train.txt', 'w') as train_file, \
            open(f'generated/{dataset_name}/val.txt', 'w') as val_file:
        for cat_id, cat_img_ids in cats_ids.items():
            train_ids = []
            val_ids = []
            for img_inx, img_id in enumerate(cat_img_ids):
                if img_inx % 10 == 0:
                    val_ids.append(img_id)
                else:
                    train_ids.append(img_id)

            # Oversample up to the target count, drawing only from the
            # original (un-padded) ids. A category without any training
            # image has nothing to draw from, so it is left empty instead
            # of crashing in random.randint(0, -1).
            class_count_current = len(train_ids)
            if class_count_current > 0:
                for _ in range(class_count_current, class_count_target):
                    pick = random.randint(0, class_count_current - 1)
                    train_ids.append(train_ids[pick])

            print('train', cat_id, len(train_ids), class_count_current)
            for img_id in train_ids:
                train_file.write(f'{img_id}\n')

            for img_id in val_ids:
                val_file.write(f'{img_id}\n')

    # Write images
    for img_index, img_id in enumerate(img_ids):
        img_filename = coco.loadImgs(ids=[img_id])[0]['file_name']
        img_path = f'source/coco/train2017/{img_filename}'
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals failure by returning None, not by raising
            raise FileNotFoundError(f'Could not read image: {img_path}')

        ans_ids = coco.getAnnIds(imgIds=[img_id])
        ans = coco.loadAnns(ans_ids)

        # One channel per class; channel 0 is never written here
        label = np.zeros((21, img.shape[0], img.shape[1]))
        for an in ans:
            an_cat_id = an['category_id']
            if an_cat_id in bridge.coco_ids:
                mask = coco.annToMask(an)
                voc_index = bridge.coco_ids.index(an_cat_id)
                label[voc_index + 1] += mask
        label_rgb = label_to_image(label)

        cv2.imwrite(f'generated/{dataset_name}/images/{img_id}.jpg', img)
        cv2.imwrite(f'generated/{dataset_name}/labels/{img_id}.png',
                    label_rgb * 255.0)
        print(img_index)