Ejemplo n.º 1
0
def make_grid(batch_img, batch_mask, img_denormalize_fn, batch_gt_mask=None):
    """Build a uint8 visualization grid with one column per batch sample.

    Row layout (top to bottom), each row ``h`` pixels tall:

        img1  | img2  | img3  | img4  | ...
        i+m1  | i+m2  | i+m3  | i+m4  | ...
        mask1 | mask2 | mask3 | mask4 | ...
        i+M1  | i+M2  | i+M3  | i+M4  | ...   (only with ``batch_gt_mask``)
        Mask1 | Mask2 | Mask3 | Mask4 | ...   (only with ``batch_gt_mask``)

    ``i+m`` / ``i+M`` are the image blended with the predicted / ground-truth
    mask via ``render_datapoint(..., blend_alpha=0.4)``; ``maskN`` / ``MaskN``
    are the masks rendered by ``render_mask``.

    Args:
        batch_img (torch.Tensor): batch of images; ``shape[2:]`` is read as
            ``(h, w)``.
        batch_mask (torch.Tensor): batch of predicted masks, same length as
            ``batch_img``.
        img_denormalize_fn (Callable): applied to each single image before
            it is converted with ``tensor_to_rgb``.
        batch_gt_mask (torch.Tensor, optional): batch of ground-truth masks,
            same length as ``batch_mask``.

    Returns:
        np.ndarray: ``uint8`` array of shape ``(h * rows, w * b, 3)`` where
        ``rows`` is 3, or 5 when ``batch_gt_mask`` is given.
    """
    assert isinstance(batch_img, torch.Tensor) and isinstance(
        batch_mask, torch.Tensor)
    assert len(batch_img) == len(batch_mask)

    has_gt = batch_gt_mask is not None
    if has_gt:
        assert isinstance(batch_gt_mask, torch.Tensor)
        assert len(batch_mask) == len(batch_gt_mask)

    n_samples = batch_img.shape[0]
    height, width = batch_img.shape[2:]

    n_rows = 5 if has_gt else 3
    grid = np.zeros((height * n_rows, width * n_samples, 3), dtype="uint8")

    def _put(row, col, tile):
        # Write one tile into cell (row, col) of the grid.
        grid[row * height:(row + 1) * height,
             col * width:(col + 1) * width, :] = tile

    for col in range(n_samples):
        rgb = tensor_to_rgb(img_denormalize_fn(batch_img[col]))
        pred = render_mask(batch_mask[col].cpu().numpy())

        _put(0, col, rgb)
        _put(1, col, render_datapoint(rgb, pred, blend_alpha=0.4))
        _put(2, col, pred)

        if has_gt:
            gt = render_mask(batch_gt_mask[col].cpu().numpy())
            _put(3, col, render_datapoint(rgb, gt, blend_alpha=0.4))
            _put(4, col, gt)

    return grid
Ejemplo n.º 2
0
def write_prediction_on_image2(mask_predicted, im, filepath):
    """Save a vertical strip: image / image blended with mask / colored mask.

    The three tiles are stacked top to bottom on one RGB canvas, which is
    written to ``filepath``.

    Args:
        mask_predicted (np.ndarray): 2-D array of class indices, shape
            ``(H, W)``; cast to ``uint8`` and colorized with ``vocpallete``.
        im (np.ndarray): 3-D image array passed through ``render_x``;
            presumably it renders to the same ``(H, W)`` as the mask — TODO
            confirm, tiles are laid out using the mask's size.
        filepath: destination handed to ``PIL.Image.Image.save``.
    """
    assert isinstance(mask_predicted, np.ndarray) and mask_predicted.ndim == 2, \
        "{} and {}".format(type(mask_predicted), mask_predicted.shape if isinstance(mask_predicted, np.ndarray) else None)

    assert isinstance(im, np.ndarray) and im.ndim == 3, \
        "{} and {}".format(type(im), im.shape if isinstance(im, np.ndarray) else None)

    # Normalize for rendering
    x = render_x(im)

    im = Image.fromarray(x).convert('RGB')

    pil_pred = Image.fromarray(mask_predicted.astype('uint8'))
    pil_pred.putpalette(vocpallete)
    pil_pred = pil_pred.convert('RGB')
    res_pred = render_datapoint(im, pil_pred)

    # numpy shapes are (height, width) but PIL sizes and paste offsets are
    # (width, height); mixing them up breaks the layout for non-square masks.
    tile_h, tile_w = mask_predicted.shape

    # Single column of tiles, stacked top to bottom.
    tiles = [im, res_pred, pil_pred]
    canvas = Image.new('RGB', (tile_w, len(tiles) * tile_h))
    for row, tile in enumerate(tiles):
        canvas.paste(tile, (0, row * tile_h))

    canvas.save(filepath)
Ejemplo n.º 3
0
    def wrapper(engine, writer, state_attr):
        """Log prediction and ground-truth overlays for a random batch subset.

        Reads ``x``, ``y_pred`` and ``y`` from ``engine.state.output``, takes
        the index of the max of ``y_pred`` over dim 1, keeps the first
        ``n_images`` items of a random permutation, and writes two HWC images
        tagged "predictions" and "ground-truth" to ``writer``. The step is
        ``getattr(state, state_attr)`` of ``another_engine`` when it is set,
        otherwise of ``engine``. ``n_images``, ``n_rows``, ``single_img_size``
        and ``another_engine`` come from the enclosing scope.
        """
        batch = engine.state.output
        inputs = batch['x']
        logits = batch['y_pred']
        targets = batch['y']
        # torch.max over a dim returns (values, indices); keep the indices.
        pred_labels = torch.max(logits, dim=1)[1]

        chosen = torch.randperm(inputs.shape[0])[:n_images]
        background = render_x(inputs[chosen, ...], nrow=n_rows)
        pred_grid = render_y(pred_labels[chosen, ...], nrow=n_rows)
        target_grid = render_y(targets[chosen, ...], nrow=n_rows)

        overlays = [
            np.asarray(render_datapoint(background, grid, output_size=single_img_size))
            for grid in (pred_grid, target_grid)
        ]

        source = engine if another_engine is None else another_engine
        step = getattr(source.state, state_attr)

        for tag, image in zip(("predictions", "ground-truth"), overlays):
            writer.add_image(tag=tag, img_tensor=image, global_step=step, dataformats='HWC')
Ejemplo n.º 4
0
def make_grid(
    batch_img: torch.Tensor,
    batch_preds: torch.Tensor,
    img_denormalize_fn: Callable,
    batch_gt: Optional[torch.Tensor] = None,
):
    """Create a one-row grid of images annotated with their class labels.

    Layout (a single row, one ``w``-wide cell per sample):

        i+l1+gt1  | i+l2+gt2  | i+l3+gt3  | i+l4+gt4  | ...

    where each cell is the denormalized image rendered by
    ``render_datapoint`` with the text ``"p=<pred>"``, extended by
    ``" | gt=<gt>"`` when ground-truth labels are given.

    Args:
        batch_img (torch.Tensor): batch of images; ``shape[2:]`` is read as
            ``(h, w)``.
        batch_preds (torch.Tensor): 1-D tensor of predicted labels, one per
            image.
        img_denormalize_fn (Callable): applied to each single image before
            it is converted with ``tensor_to_numpy``.
        batch_gt (torch.Tensor, optional): 1-D tensor of ground-truth labels,
            same length as ``batch_preds``.

    Returns:
        np.ndarray: ``uint8`` array of shape ``(h, w * b, 3)``.
    """
    assert isinstance(batch_img, torch.Tensor) and isinstance(batch_preds, torch.Tensor)
    assert len(batch_img) == len(batch_preds), "{} vs {}".format(
        len(batch_img), len(batch_preds)
    )
    assert batch_preds.ndim == 1, "{}".format(batch_preds.ndim)

    if batch_gt is not None:
        assert isinstance(batch_gt, torch.Tensor)
        assert len(batch_preds) == len(batch_gt)
        assert batch_gt.ndim == 1, "{}".format(batch_gt.ndim)

    b = batch_img.shape[0]
    h, w = batch_img.shape[2:]

    # A single row of b cells, each h x w.
    out_image = np.zeros((h, w * b, 3), dtype="uint8")

    for i in range(b):
        img = tensor_to_numpy(img_denormalize_fn(batch_img[i]))

        target = "p={}".format(batch_preds[i].cpu().item())
        if batch_gt is not None:
            target += " | gt={}".format(batch_gt[i].cpu().item())

        out_image[0:h, i * w : (i + 1) * w, :] = render_datapoint(
            img, target, text_size=12
        )

    return out_image
Ejemplo n.º 5
0
def write_prediction_on_image(y_pred, x, filepath, palette=default_palette):
    """Overlay a palette-colored label map on an image and save it as PNG.

    Args:
        y_pred (np.ndarray): 2-D array of class indices. It is cast to
            ``uint8`` before colorizing, so presumably labels fit in 0..255 —
            TODO confirm; larger values wrap around.
        x (np.ndarray): 3-D image array blended with the colored mask by
            ``render_datapoint`` (``blend_alpha=0.5``).
        filepath (str): output path without extension; ``".png"`` is
            appended.
        palette: flat palette sequence given to ``Image.putpalette``.
    """
    assert isinstance(y_pred, np.ndarray) and y_pred.ndim == 2, \
        "{} and {}".format(type(y_pred), y_pred.shape if isinstance(y_pred, np.ndarray) else None)

    assert isinstance(x, np.ndarray) and x.ndim == 3, \
        "{} and {}".format(type(x), x.shape if isinstance(x, np.ndarray) else None)

    # putpalette only accepts "L"/"P"-style images, and fromarray yields mode
    # "L" only for uint8 data (int32 -> "I", bool -> "1", int64 unsupported),
    # so cast explicitly; argmax outputs are typically int64.
    pil_pred = Image.fromarray(y_pred.astype('uint8'))
    pil_pred.putpalette(palette)
    res = render_datapoint(x, pil_pred.convert('RGB'), blend_alpha=0.5)

    filepath = filepath + ".png"
    res.save(filepath)
Ejemplo n.º 6
0
def visualize(image_name, image_dir, mask_dir, pre_dir, save_name, dice=None):
    """Plot image / ground-truth overlay / prediction overlay rows and save.

    For each name, one figure row has three panels:
    - the original image, titled with its name and ``width*height``;
    - the ground-truth mask blended with the image;
    - the predicted mask blended with the image, titled with its dice score.

    Panels are drawn, saved to ``save_name + '.png'`` and shown only when
    ``dice`` is truthy.

    Args:
        image_name (list[str]): image names. If the first one contains
            'mask', every name is cut at its first '_'.
        image_dir (str): directory searched by ``find_imagenames`` for images.
        mask_dir (str): directory searched for ground-truth masks.
        pre_dir (str): directory searched for predicted masks.
        save_name (str): output path without the '.png' extension.
        dice (dict, optional): maps an image file stem (file name up to its
            first '.') to its dice score.
    """
    print("The {} visualization".format(save_name))
    fontsize = 8
    if 'mask' in image_name[0]:
        image_name = [i.split('_')[0] for i in image_name]
    length = len(image_name)
    images_path = find_imagenames(image_dir, image_name)
    masks_path = find_imagenames(mask_dir, image_name)
    pres_path = find_imagenames(pre_dir, image_name)

    # squeeze=False keeps ``ax`` 2-D even for a single row, so the
    # ``ax[i, j]`` indexing below also works when length == 1.
    _, ax = plt.subplots(length, 3, figsize=(10, 40), squeeze=False)
    if dice:
        for i in range(length):
            name = images_path[i].split('/')[-1].split('.')[0]
            original_image = cv2.imread(images_path[i])
            original_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)
            dice_score = dice[name]
            # Array shape is (height, width, channels); title shows W*H.
            h, w = original_image.shape[:2]
            ax[i, 0].imshow(original_image)
            ax[i, 0].set_title('Original image ' + name + ' ' + str(w) + '*' +
                               str(h),
                               fontsize=fontsize)
            # Masks are read by cv2 (BGR channel order) and used unconverted.
            original_mask = cv2.imread(masks_path[i])
            rimg = render_datapoint(original_mask,
                                    original_image,
                                    blend_alpha=0.8)
            ax[i, 1].imshow(rimg)
            ax[i, 1].set_title('Original mask', fontsize=fontsize)
            pre_mask = cv2.imread(pres_path[i])
            rimg2 = render_datapoint(pre_mask, original_image, blend_alpha=0.8)
            ax[i, 2].imshow(rimg2)
            ax[i, 2].set_title('pred mask' + ' dice:' + str(round(dice_score, 3)),
                               fontsize=fontsize)
        plt.savefig(save_name + '.png')
        plt.show()