Example #1
import os
from PIL import Image

def save_images(image, mask, output_path, image_file, palette, original_size):
    # Saves the colorized mask plus an [input | mask | blend] strip.
    # Pad the palette with zeros up to the 256 RGB triplets PIL expects.
    zero_pad = 256 * 3 - len(palette)
    palette.extend([0] * zero_pad)

    w, h = image.size

    if original_size:
        w, h = original_size

    image_file = os.path.splitext(os.path.basename(image_file))[0]
    colorized_mask = colorize_mask(mask, palette)

    if original_size and image.size != original_size:
        image = image.resize(size=original_size, resample=Image.BILINEAR)
    if original_size and colorized_mask.size != original_size:
        colorized_mask = colorized_mask.resize(size=original_size,
                                               resample=Image.NEAREST)

    blend = Image.blend(image, colorized_mask.convert('RGB'), 0.5)

    colorized_mask.save(os.path.join(output_path, image_file + '.png'))
    output_im = Image.new('RGB', (w * 3, h))
    output_im.paste(image, (0, 0))
    output_im.paste(colorized_mask, (w * 1, 0))
    output_im.paste(blend, (w * 2, 0))
    output_im.save(os.path.join(output_path, image_file + '_colorized.png'))
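Every variant in this file relies on a colorize_mask helper that is not shown. A minimal sketch, assuming mask is a 2-D array of class indices and palette is the flat [R, G, B, ...] list padded above (an illustration, not necessarily the repository's exact code):

import numpy as np
from PIL import Image

def colorize_mask(mask, palette):
    # Wrap the class-index mask in a palette-mode ("P") image so each
    # index is rendered with its RGB triplet from `palette`.
    new_mask = Image.fromarray(np.asarray(mask, dtype=np.uint8), mode='P')
    new_mask.putpalette(palette)
    return new_mask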
Example #2
import os
import torch

def save_images(output, output_path, name, palette):
    # Collapse the raw logits (1 x C x H x W) to a class-index mask and save
    # it as a colorized PNG. Softmax is monotonic per pixel, so taking the
    # argmax over the class dimension gives the same labels without it.
    mask = output.detach().squeeze(0).argmax(0).cpu().numpy()
    h, w = mask.shape  # numpy shape is (rows, cols); unused below
    colorized_mask = colorize_mask(mask, palette)
    colorized_mask.save(os.path.join(output_path, name + '.png'))
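A minimal usage sketch; model, img_tensor, and palette are assumed to come from the surrounding script, and the model is assumed to return 1 x C x H x W logits:

with torch.no_grad():
    logits = model(img_tensor)  # hypothetical normalized 1 x 3 x H x W input
save_images(logits, 'outputs', 'frame_0001', palette)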
Example #3
import os

def save_images(image, mask, output_path, image_file, palette):
    # Saves the original image and the colorized mask as two separate files.
    w, h = image.size  # unused here; kept from the shared signature
    image_file = os.path.splitext(os.path.basename(image_file))[0]
    colorized_mask = colorize_mask(mask, palette)
    colorized_mask.save(os.path.join(output_path, image_file + '_mask.png'))
    image.save(os.path.join(output_path, image_file + '_orig.png'))
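A minimal call sketch for this variant; img, pred_mask, and the paths are hypothetical:

# img: PIL.Image, pred_mask: 2-D numpy array of class indices.
# Writes outputs/0001_mask.png and outputs/0001_orig.png.
save_images(img, pred_mask, 'outputs', 'data/val/0001.jpg', palette)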
Example #4
import json
import os

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

# Project-local helpers assumed importable from the surrounding repo:
# parse_arguments, testDataset, get_voc_pallete, models, colorize_mask,
# multi_scale_predict.

def main():
    args = parse_arguments()

    # CONFIG
    assert args.config
    with open(args.config) as f:
        config = json.load(f)
    scales = [0.5, 0.75, 1.0, 1.25, 1.5]

    # DATA
    testdataset = testDataset(args.images)
    loader = DataLoader(testdataset,
                        batch_size=1,
                        shuffle=False,
                        num_workers=1)
    num_classes = config['num_classes']
    palette = get_voc_pallete(num_classes)

    # MODEL
    config['model']['supervised'] = True
    config['model']['semi'] = False
    model = models.CCT(num_classes=num_classes,
                       conf=config['model'],
                       testing=True)
    checkpoint = torch.load(args.model)
    model = torch.nn.DataParallel(model)
    try:
        model.load_state_dict(checkpoint['state_dict'], strict=True)
    except Exception as e:
        print(f'Some modules are missing: {e}')
        model.load_state_dict(checkpoint['state_dict'], strict=False)
    model.eval()
    model.cuda()

    os.makedirs(args.save, exist_ok=True)

    # LOOP OVER THE DATA
    tbar = tqdm(loader, ncols=100)

    for data in tbar:
        image, image_id = data
        image = image.cuda()

        # PREDICT
        with torch.no_grad():
            output = multi_scale_predict(model, image, scales, num_classes)
        prediction = np.asarray(np.argmax(output, axis=0), dtype=np.uint8)

        # SAVE RESULTS
        prediction_im = colorize_mask(prediction, palette)
        prediction_im.save(os.path.join(args.save, image_id[0] + '.png'))
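multi_scale_predict is not shown above. A minimal sketch of the multi-scale averaging pattern it implies, assuming the model returns 1 x C x H x W logits and scales are relative factors as in main() (an illustration, not the repository's exact implementation):

import torch.nn.functional as F

def multi_scale_predict(model, image, scales, num_classes):
    # Average softmax predictions over rescaled copies of the input,
    # resampling each prediction back to the original resolution.
    h, w = image.shape[2:]
    total = torch.zeros((num_classes, h, w), device=image.device)
    for scale in scales:
        scaled = F.interpolate(image, scale_factor=scale, mode='bilinear',
                               align_corners=False)
        probs = F.softmax(model(scaled), dim=1)
        probs = F.interpolate(probs, size=(h, w), mode='bilinear',
                              align_corners=False)
        total += probs.squeeze(0)
    return (total / len(scales)).cpu().numpy()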
Example #5
    def _add_img_tb(self, val_visual, wrt_mode):
        # Restore RGB tensors to images, colorize index masks, then stack
        # everything into one TensorBoard grid (one row per sample).
        val_img = []
        palette = self.val_loader.dataset.palette
        for imgs in val_visual:
            imgs = [self.restore_transform(i)
                    if (isinstance(i, torch.Tensor) and len(i.shape) == 3)
                    else colorize_mask(i, palette) for i in imgs]
            imgs = [i.convert('RGB') for i in imgs]
            imgs = [self.viz_transform(i) for i in imgs]
            val_img.extend(imgs)
        val_img = torch.stack(val_img, 0)
        val_img = make_grid(val_img.cpu(),
                            nrow=val_img.size(0) // len(val_visual),
                            padding=5)
        self.writer.add_image(f'{wrt_mode}/inputs_targets_predictions',
                              val_img, self.wrt_step)
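Design note: nrow is the number of images per grid row, so dividing the total image count by the number of samples keeps each sample's [input, target(s), prediction(s)] tuple on its own row even when callers pass tuples of different lengths.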
Example #6
import os

import torch
from torchvision import transforms

def save_image(image, epoch, id, palette):
    # Expects either a 3 x H x W RGB tensor (saved after de-normalization)
    # or a 2-D class-index mask (saved as a colorized PNG).
    with torch.no_grad():
        if image.shape[0] == 3:
            restore_transform = transforms.Compose([
                DeNormalize(IMG_MEAN),
                transforms.ToPILImage()])

            image = restore_transform(image)
            # image = PIL.Image.fromarray(np.array(image)[:, :, ::-1])  # BGR->RGB
            image.save(os.path.join('../visualiseImages/',
                                    str(epoch) + id + '.png'))
        else:
            mask = image.numpy()
            colorized_mask = colorize_mask(mask, palette)
            colorized_mask.save(os.path.join('../visualiseImages/',
                                             str(epoch) + id + '.png'))
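DeNormalize and IMG_MEAN come from the surrounding project. A minimal sketch of the usual inverse-normalization transform, assuming IMG_MEAN is the per-channel mean subtracted during preprocessing (the exact semantics are an assumption):

class DeNormalize(object):
    # Hypothetical sketch: add the per-channel mean back so the tensor is
    # roughly in the range transforms.ToPILImage() expects.
    def __init__(self, mean):
        self.mean = mean

    def __call__(self, tensor):
        for t, m in zip(tensor, self.mean):
            t.add_(m)
        return tensor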
Example #7
import os

import cv2
import numpy as np
from PIL import Image

def save_images(image, mask, output_path, image_file, palette, original_size, output=None):
    # Saves the colorized mask and an [input | mask | blend] strip; when an
    # `output` dict is passed, also records per-class pixel counts and paths.
    zero_pad = 256 * 3 - len(palette)
    palette.extend([0] * zero_pad)

    w, h = image.size

    if original_size:
        w, h = original_size

    if output is not None:
        # Per-class pixel statistics on the mask at the original resolution
        # (four classes are hardcoded here).
        resize_mask = cv2.resize(mask, dsize=original_size,
                                 interpolation=cv2.INTER_NEAREST)
        for c in range(4):
            output[f'pc_{c}'] = int(np.count_nonzero(resize_mask == c))
        output['pc_total'] = sum(output[f'pc_{c}'] for c in range(4))

    image_file = os.path.splitext(os.path.basename(image_file))[0]
    colorized_mask = colorize_mask(mask, palette)

    if original_size and image.size != original_size:
        image = image.resize(size=original_size, resample=Image.BILINEAR)
    if original_size and colorized_mask.size != original_size:
        colorized_mask = colorized_mask.resize(size=original_size, resample=Image.NEAREST)

    blend = Image.blend(image, colorized_mask.convert('RGB'), 0.5)

    mask_path = os.path.join(output_path, image_file + '.png')
    colorized_mask.save(mask_path)

    output_im = Image.new('RGB', (w * 3, h))
    output_im.paste(image, (0, 0))
    output_im.paste(colorized_mask, (w * 1, 0))
    output_im.paste(blend, (w * 2, 0))
    blend_path = os.path.join(output_path, image_file + '_colorized.png')
    output_im.save(blend_path)

    if output is not None:
        output['mask'] = mask_path
        output['blend'] = blend_path
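A minimal usage sketch of the output dict contract above; img, pred_mask, and the paths are hypothetical:

# img: PIL.Image, pred_mask: 2-D numpy array of class indices.
stats = {}
save_images(img, pred_mask, 'outputs', 'scans/case01.png', palette,
            original_size=img.size, output=stats)
print(stats['pc_total'], stats['mask'], stats['blend'])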
Example #8
import os

import numpy as np
from PIL import Image

def save_images(image, mask, output_path, image_file, palette):
    # Saves the colorized mask and an [input | mask] strip.
    w, h = image.size
    image_file = os.path.splitext(os.path.basename(image_file))[0]
    colorized_mask = colorize_mask(mask, palette)
    pmask = np.array(colorized_mask)
    # Coordinates of all foreground (non-zero class) pixels, vectorized in
    # place of a per-pixel Python loop; not used further in this variant.
    loc_y, loc_x = np.nonzero(pmask)
    colorized_mask.save(os.path.join(output_path, image_file + '.png'))
    output_im = Image.new('RGB', (w * 2, h))
    output_im.paste(image, (0, 0))
    output_im.paste(colorized_mask, (w, 0))
    output_im.save(os.path.join(output_path, image_file + '_colorized.png'))
Example #9
    def _valid_epoch(self, epoch):
        if self.val_loader is None:
            self.logger.warning(
                'No data loader was passed for the validation step; '
                'no validation is performed.')
            return {}
        self.logger.info('\n###### EVALUATION ######')

        self.model.eval()
        self.wrt_mode = 'val'

        self._reset_metrics()
        tbar = tqdm(self.val_loader, ncols=130)
        with torch.no_grad():
            val_visual = []
            for batch_idx, (data, target) in enumerate(tbar):
                # data, target = data.to(self.device), target.to(self.device)
                # LOSS
                output = self.model(data)
                loss = self.loss(output, target)
                if isinstance(self.loss, torch.nn.DataParallel):
                    loss = loss.mean()
                self.total_loss.update(loss.item())

                seg_metrics = eval_metrics(output, target, self.num_classes)
                self._update_seg_metrics(*seg_metrics)

                # LIST OF IMAGES TO VISUALIZE (up to 15)
                if len(val_visual) < 15:
                    target_np = target.data.cpu().numpy()
                    output_np = output.data.max(1)[1].cpu().numpy()
                    val_visual.append(
                        [data[0].data.cpu(), target_np[0], output_np[0]])

                # PRINT INFO
                pixAcc, mIoU, _ = self._get_seg_metrics().values()
                tbar.set_description(
                    'EVAL ({}) | Loss: {:.3f}, PixelAcc: {:.2f}, Mean IoU: {:.2f} |'
                    .format(epoch, self.total_loss.average, pixAcc, mIoU))

            # WRITING & VISUALIZING THE MASKS
            # Set the global step for this epoch before any logging.
            self.wrt_step = epoch * len(self.val_loader)
            val_img = []
            palette = self.train_loader.dataset.palette
            for d, t, o in val_visual:
                d = self.restore_transform(d)
                t, o = colorize_mask(t, palette), colorize_mask(o, palette)
                d, t, o = d.convert('RGB'), t.convert('RGB'), o.convert('RGB')
                [d, t, o] = [self.viz_transform(x) for x in [d, t, o]]
                val_img.extend([d, t, o])
            val_img = torch.stack(val_img, 0)
            val_img = make_grid(val_img.cpu(), nrow=3, padding=5)
            self.writer.add_image(
                '{}/inputs_targets_predictions'.format(self.wrt_mode), val_img,
                self.wrt_step)

            # METRICS TO TENSORBOARD
            self.writer.add_scalar('{}/loss'.format(self.wrt_mode),
                                   self.total_loss.average, self.wrt_step)
            seg_metrics = self._get_seg_metrics()
            for k, v in list(seg_metrics.items())[:-1]:
                self.writer.add_scalar('{}/{}'.format(self.wrt_mode, k), v,
                                       self.wrt_step)

            log = {'val_loss': self.total_loss.average, **seg_metrics}

        return log
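eval_metrics and the _update_seg_metrics / _get_seg_metrics bookkeeping are project helpers not shown here. A minimal sketch of the batch-level quantities they imply, assuming output is N x C x H x W logits and target is N x H x W labels (an illustration, not the project's exact code):

import torch

def eval_metrics(output, target, num_classes):
    # Hypothetical sketch: batch pixel-accuracy counts plus per-class
    # intersection/union for the running mean IoU.
    pred = output.argmax(dim=1)
    valid = target >= 0  # assumes negative labels mark ignored pixels
    correct = ((pred == target) & valid).sum().item()
    labeled = valid.sum().item()
    inter = torch.zeros(num_classes)
    union = torch.zeros(num_classes)
    for c in range(num_classes):
        p = (pred == c) & valid
        t = (target == c) & valid
        inter[c] = (p & t).sum()
        union[c] = (p | t).sum()
    return correct, labeled, inter, union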
Example #10
    def _valid_epoch(self, epoch):
        if self.val_loader is None:
            self.logger.warning(
                'No data loader was passed for the validation step; '
                'no validation is performed.')
            return {}
        self.logger.info('\n###### EVALUATION ######')

        self.model.eval()
        self.wrt_mode = 'val'

        self._reset_metrics()
        tbar = tqdm(self.val_loader, ncols=130)
        with torch.no_grad():
            val_visual = []
            kernel_visual = []
            for batch_idx, (data, target, distance) in enumerate(tbar):
                output, distance_loss = self.model(data, distance)
                kernel, coarse_estimation, learned_sdf = self.model.module.kernel_visualization(
                    data)

                seg_loss = self.loss(output, target)
                loss = 5 * seg_loss + distance_loss
                self.total_loss.update(loss.item())

                seg_metrics = eval_metrics(output, target, self.num_classes)
                self._update_seg_metrics(*seg_metrics)

                # LIST OF IMAGES TO VISUALIZE (up to 15)
                if len(val_visual) < 15:
                    target_np = target.data.cpu().numpy()
                    kernel_np = kernel[:, 1, :, :].data.cpu().numpy()
                    output_np = output.data.max(1)[1].cpu().numpy()

                    learned_sdf_np = learned_sdf.data.cpu().numpy()  # B x K x W x H
                    target_sdf_np = distance.data.cpu().numpy()  # B x K x W x H

                    coarse_estimation_np = coarse_estimation[:, 1, :, :].data.cpu().numpy()

                    # RGB, segmentation label, segmentation prediction
                    val_visual.append(
                        [data[0].data.cpu(), target_np[0], output_np[0]])
                    # RGB, kernel, coarse prediction, learned field, target field
                    kernel_visual.append([
                        data[0].data.cpu(), kernel_np[0],
                        coarse_estimation_np[0], learned_sdf_np[0],
                        target_sdf_np[0]
                    ])
                # PRINT INFO
                pixAcc, mIoU, _ = self._get_seg_metrics().values()
                tbar.set_description(
                    'EVAL ({}) | Seg_Loss: {:.3f}, Dis_Loss: {:.3f}, PixelAcc: {:.2f}, Mean IoU: {:.2f} |'
                    .format(epoch, seg_loss, distance_loss, pixAcc, mIoU))

            # WRITING & VISUALIZING THE MASKS
            # Set the global step for this epoch before any logging.
            self.wrt_step = epoch * len(self.val_loader)
            val_img = []
            palette = self.train_loader.dataset.palette
            for rgb, label, seg in val_visual:
                rgb = self.restore_transform(rgb)
                label = colorize_mask(label, palette)
                seg = colorize_mask(seg, palette)
                rgb, label, seg = (x.convert('RGB') for x in (rgb, label, seg))
                rgb, label, seg = (self.viz_transform(x) for x in (rgb, label, seg))
                val_img.extend([rgb, label, seg])
            kernel_img = []

            # RGB, kernel, coarse prediction, learned field, target field
            for rgb, kernel, pro, learned_sdf, sdf in kernel_visual:
                rgb = self.restore_transform(rgb)
                kernel = self.kernel_transform(kernel)
                pro = self.kernel_transform(pro)
                learned_sdf, sdf = self.learned_transform(learned_sdf, sdf)
                learned_sdf, sdf = self.to_PIL(learned_sdf), self.to_PIL(sdf)

                imgs = [x.convert('RGB')
                        for x in (rgb, kernel, pro, learned_sdf, sdf)]
                imgs = [self.viz_transform(x) for x in imgs]
                kernel_img.extend(imgs)

            val_img = torch.stack(val_img, 0)
            val_img = make_grid(val_img.cpu(), nrow=3, padding=5)

            kernel_img = torch.stack(kernel_img, 0)
            kernel_img = make_grid(kernel_img.cpu(), nrow=5, padding=5)
            self.writer.add_image(
                f'{self.wrt_mode}/inputs_targets_predictions', val_img,
                self.wrt_step)
            self.writer.add_image(f'{self.wrt_mode}/kernel_predictions',
                                  kernel_img, self.wrt_step)
            # METRICS TO TENSORBOARD
            self.writer.add_scalar(f'{self.wrt_mode}/loss', loss.item(),
                                   self.wrt_step)
            self.writer.add_scalar(f'{self.wrt_mode}/seg_loss', 5 * seg_loss,
                                   self.wrt_step)
            self.writer.add_scalar(f'{self.wrt_mode}/dis_loss', distance_loss,
                                   self.wrt_step)
            seg_metrics = self._get_seg_metrics()
            for k, v in list(seg_metrics.items())[:-1]:
                self.writer.add_scalar(f'{self.wrt_mode}/{k}', v,
                                       self.wrt_step)

            log = {'val_loss': self.total_loss.average, **seg_metrics}

        return log
Example #11
    def _valid_epoch(self, epoch):
        if self.val_loader is None:
            self.logger.warning(
                'No data loader was passed for the validation step; '
                'no validation is performed.')
            return {}
        self.logger.info('\n###### EVALUATION ######')

        self.model.eval()
        self.wrt_mode = 'val'

        self._reset_metrics()
        tbar = tqdm(self.val_loader, ncols=130)
        with torch.no_grad():
            val_visual = []
            cls_total_pix_correct = np.zeros(self.num_classes)
            cls_total_pix_labeled = np.zeros(self.num_classes)
            for batch_idx, (data, target) in enumerate(tbar):
                #data, target = data.to(self.device), target.to(self.device)
                # LOSS (several architecture families return auxiliary
                # outputs that are folded into the loss with fixed weights)
                output = self.model(data)
                if self.config['arch']['type'][:2] == 'IC':
                    assert output[0].size()[2:] == target.size()[1:]
                    assert output[0].size()[1] == self.num_classes
                    loss = self.loss(output, target)
                    output = output[0]
                elif self.config['arch']['type'][-3:] == 'OCR':
                    assert output[0].size()[2:] == target.size()[1:]
                    assert output[0].size()[1] == self.num_classes
                    loss = self.loss(output[0], target)
                    loss += self.loss(output[1], target) * 0.4
                    output = output[0]
                elif 'Nearest' in self.config['arch']['type']:
                    assert output[0].size()[2:] == target.size()[1:]
                    assert output[0].size()[1] == self.num_classes
                    loss = self.loss(output[0], target)
                    loss += self.loss(output[1], target) * 0.4
                    output = output[0]
                elif self.config['arch']['type'][:3] == 'Enc':
                    assert output[0].size()[2:] == target.size()[1:]
                    assert output[0].size()[1] == self.num_classes
                    loss = self.loss(output, target)
                    output = output[0]
                elif self.config['arch']['type'][:5] == 'DANet':
                    assert output[0].size()[2:] == target.size()[1:]
                    assert output[0].size()[1] == self.num_classes
                    loss = self.loss(output[0], target)
                    loss += self.loss(output[1], target) * 0.2
                    loss += self.loss(output[2], target) * 0.2
                    output = output[0]
                else:
                    assert output.size()[2:] == target.size()[1:]
                    assert output.size()[1] == self.num_classes
                    loss = self.loss(output, target)

                if isinstance(self.loss, torch.nn.DataParallel):
                    loss = loss.mean()
                self.total_loss.update(loss.item())

                seg_metrics = eval_metrics(output, target, self.num_classes)
                self._update_seg_metrics(*seg_metrics)

                for i in range(self.num_classes):
                    cls_pix_correct, cls_pix_labeled = batch_class_pixel_accuracy(
                        output, target, i)
                    cls_total_pix_correct[i] += cls_pix_correct
                    cls_total_pix_labeled[i] += cls_pix_labeled
                # LIST OF IMAGES TO VISUALIZE (up to 15)
                if len(val_visual) < 15:
                    target_np = target.data.cpu().numpy()
                    output_np = output.data.max(1)[1].cpu().numpy()
                    val_visual.append(
                        [data[0].data.cpu(), target_np[0], output_np[0]])

                # PRINT INFO
                pixAcc, mIoU, _ = self._get_seg_metrics().values()
                # Guard against classes with no labeled pixels yet.
                cls_pix_acc = np.round(
                    cls_total_pix_correct / np.maximum(cls_total_pix_labeled, 1), 3)
                tbar.set_description(
                    'EVAL ({}) | Loss: {:.3f}, PixelAcc: {:.2f}, Mean IoU: {:.2f}, cls_pix_acc: {} |'
                    .format(epoch, self.total_loss.average, pixAcc, mIoU,
                            str(cls_pix_acc)))

            # WRITING & VISUALIZING THE MASKS
            # Set the global step for this epoch before any logging.
            self.wrt_step = epoch * len(self.val_loader)
            val_img = []
            palette = self.train_loader.dataset.palette
            for d, t, o in val_visual:
                d = self.restore_transform(d)
                t, o = colorize_mask(t, palette), colorize_mask(o, palette)
                d, t, o = d.convert('RGB'), t.convert('RGB'), o.convert('RGB')
                [d, t, o] = [self.viz_transform(x) for x in [d, t, o]]
                val_img.extend([d, t, o])
            val_img = torch.stack(val_img, 0)
            val_img = make_grid(val_img.cpu(), nrow=3, padding=5)
            self.writer.add_image(
                f'{self.wrt_mode}/inputs_targets_predictions', val_img,
                self.wrt_step)

            # METRICS TO TENSORBOARD
            self.writer.add_scalar(f'{self.wrt_mode}/loss',
                                   self.total_loss.average, self.wrt_step)
            seg_metrics = self._get_seg_metrics()
            for k, v in list(seg_metrics.items())[:-1]:
                self.writer.add_scalar(f'{self.wrt_mode}/{k}', v,
                                       self.wrt_step)

            log = {'val_loss': self.total_loss.average, **seg_metrics}

        return log
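batch_class_pixel_accuracy is another project helper not shown. A minimal sketch consistent with how its two return values are accumulated above (an illustration, not the exact implementation):

import torch

def batch_class_pixel_accuracy(output, target, cls):
    # Hypothetical sketch: pixels of class `cls` predicted correctly vs.
    # pixels labeled `cls` in this batch.
    pred = output.argmax(dim=1)
    labeled = (target == cls)
    correct = (pred == cls) & labeled
    return correct.sum().item(), labeled.sum().item()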