def make_grid(batch_img, batch_mask, img_denormalize_fn, batch_gt_mask=None):
    """Render a batch of images and masks into a single uint8 overview image.

    Every sample occupies one column; rows are stacked top to bottom as

        img1  | img2  | img3  | img4  | ...
        i+m1  | i+m2  | i+m3  | i+m4  | ...
        mask1 | mask2 | mask3 | mask4 | ...
        i+M1  | i+M2  | i+M3  | i+M4  | ...
        Mask1 | Mask2 | Mask3 | Mask4 | ...

    where ``i+m`` is the image blended with a mask (alpha=0.4), ``maskN`` is
    the predicted mask and ``MaskN`` is the ground-truth mask. The last two
    rows are only present when ``batch_gt_mask`` is given.

    Args:
        batch_img (torch.Tensor): batch of images; the tile height and width
            are read from ``batch_img.shape[2:]``.
        batch_mask (torch.Tensor): batch of predicted masks, one per image.
        img_denormalize_fn (Callable): applied to each image before it is
            converted with ``tensor_to_rgb``.
        batch_gt_mask (torch.Tensor, optional): batch of ground-truth masks.

    Returns:
        np.ndarray: uint8 array of shape ``(h * rows, w * batch, 3)`` with
        ``rows`` equal to 3, or 5 when ground truth is provided.

    Raises:
        AssertionError: if an input is not a tensor or batch lengths differ.
    """
    assert isinstance(batch_img, torch.Tensor) and isinstance(
        batch_mask, torch.Tensor)
    assert len(batch_img) == len(batch_mask)

    with_gt = batch_gt_mask is not None
    if with_gt:
        assert isinstance(batch_gt_mask, torch.Tensor)
        assert len(batch_mask) == len(batch_gt_mask)

    n_samples = batch_img.shape[0]
    h, w = batch_img.shape[2:]
    n_rows = 5 if with_gt else 3
    out_image = np.zeros((h * n_rows, w * n_samples, 3), dtype="uint8")

    def _place(row, col, tile):
        # Write one (h, w, 3) tile into cell (row, col) of the grid.
        out_image[row * h:(row + 1) * h, col * w:(col + 1) * w, :] = tile

    for col in range(n_samples):
        raw_img = batch_img[col]
        raw_mask = batch_mask[col]
        img = tensor_to_rgb(img_denormalize_fn(raw_img))
        mask = render_mask(raw_mask.cpu().numpy())

        _place(0, col, img)
        _place(1, col, render_datapoint(img, mask, blend_alpha=0.4))
        _place(2, col, mask)

        if not with_gt:
            continue

        gt_mask = render_mask(batch_gt_mask[col].cpu().numpy())
        _place(3, col, render_datapoint(img, gt_mask, blend_alpha=0.4))
        _place(4, col, gt_mask)

    return out_image
def write_prediction_on_image2(mask_predicted, im, filepath):
    """Save the image, its blend with the predicted mask, and the mask in one file.

    The three tiles are stacked vertically (one column, three rows) on a
    single RGB canvas, which is then written to ``filepath``.

    Args:
        mask_predicted (np.ndarray): 2-D array of predicted labels. It is cast
            to ``uint8`` and colorized with the module-level ``vocpallete``.
        im (np.ndarray): 3-D image array; it is passed through ``render_x``
            before being converted to a PIL image.
        filepath: destination handed to ``Image.save``.

    Raises:
        AssertionError: if either array has the wrong type or rank.
    """
    assert isinstance(mask_predicted, np.ndarray) and mask_predicted.ndim == 2, \
        "{} and {}".format(
            type(mask_predicted),
            mask_predicted.shape if isinstance(mask_predicted, np.ndarray) else None)
    assert isinstance(im, np.ndarray) and im.ndim == 3, \
        "{} and {}".format(type(im), im.shape if isinstance(im, np.ndarray) else None)

    # Normalize for rendering
    x = render_x(im)

    # Build the three RGB tiles: image, blended image, colorized mask
    im = Image.fromarray(x).convert('RGB')
    pil_pred = Image.fromarray(mask_predicted.astype('uint8'))
    pil_pred.putpalette(vocpallete)
    pil_pred = pil_pred.convert('RGB')
    res_pred = render_datapoint(im, pil_pred)

    # NumPy shapes are (rows, cols) = (height, width), whereas PIL sizes and
    # paste offsets are (x, y) = (width, height). Using the shape unswapped
    # only produces a correct canvas for square inputs.
    tile_h, tile_w = mask_predicted.shape

    # ``tiles`` is a list of columns; each column lists its tiles top to bottom.
    tiles = [[im, res_pred, pil_pred]]
    nb_cols = len(tiles)
    nb_rows = len(tiles[0])

    cvs = Image.new('RGB', (nb_cols * tile_w, nb_rows * tile_h))
    for i_col, column in enumerate(tiles):
        for i_row, tile in enumerate(column):
            cvs.paste(tile, (i_col * tile_w, i_row * tile_h))
    cvs.save(filepath)
def wrapper(engine, writer, state_attr):
    """Log blended prediction and ground-truth images for a random batch subset.

    Reads ``engine.state.output`` (a mapping with keys ``'x'``, ``'y_pred'``
    and ``'y'``), reduces ``y_pred`` to the index of its maximum along dim 1,
    picks up to ``n_images`` random samples, and sends two images to
    ``writer.add_image`` under the tags ``"predictions"`` and
    ``"ground-truth"``.

    ``n_images``, ``n_rows``, ``single_img_size`` and ``another_engine`` are
    free variables resolved from an enclosing scope that is not part of this
    block.

    Args:
        engine: object exposing ``state.output`` as described above.
        writer: object exposing ``add_image(tag, img_tensor, global_step,
            dataformats)``.
        state_attr (str): name of the attribute read from the state object to
            obtain the global step.
    """
    batch = engine.state.output
    inputs = batch['x']
    scores = batch['y_pred']
    targets = batch['y']

    # torch.max over a dim returns (values, indices); only the indices are used
    pred_labels = torch.max(scores, dim=1)[1]

    # Random subset of at most n_images samples from the batch
    picked = torch.randperm(inputs.shape[0])[:n_images]

    bg_imgs = render_x(inputs[picked, ...], nrow=n_rows)
    pred_imgs = render_y(pred_labels[picked, ...], nrow=n_rows)
    gt_imgs = render_y(targets[picked, ...], nrow=n_rows)

    pred_overlay = np.asarray(
        render_datapoint(bg_imgs, pred_imgs, output_size=single_img_size))
    gt_overlay = np.asarray(
        render_datapoint(bg_imgs, gt_imgs, output_size=single_img_size))

    # The step counter may live on a different engine than the one that fired
    if another_engine is None:
        state = engine.state
    else:
        state = another_engine.state
    global_step = getattr(state, state_attr)

    for tag, image in (("predictions", pred_overlay), ("ground-truth", gt_overlay)):
        writer.add_image(tag=tag, img_tensor=image,
                         global_step=global_step, dataformats='HWC')
def make_grid(
    batch_img: torch.Tensor,
    batch_preds: torch.Tensor,
    img_denormalize_fn: Callable,
    batch_gt: Optional[torch.Tensor] = None,
):
    """Render a batch of images with their labels into a single-row grid.

    The output is laid out as

        i+l1+gt1 | i+l2+gt2 | i+l3+gt3 | i+l4+gt4 | ...

    where ``i+l+gt`` is the image rendered together with a caption holding
    the predicted label and, if available, the ground-truth label.

    Args:
        batch_img (torch.Tensor): batch of images; the tile height and width
            are read from ``batch_img.shape[2:]``.
        batch_preds (torch.Tensor): 1-D tensor of predicted labels.
        img_denormalize_fn (Callable): applied to each image before it is
            converted with ``tensor_to_numpy``.
        batch_gt (torch.Tensor, optional): 1-D tensor of ground-truth labels.

    Returns:
        np.ndarray: uint8 array of shape ``(h, w * batch, 3)``.

    Raises:
        AssertionError: if an input is not a tensor, batch lengths differ, or
            the label tensors are not 1-D.
    """
    assert isinstance(batch_img, torch.Tensor) and isinstance(batch_preds, torch.Tensor)
    assert len(batch_img) == len(batch_preds), "{} vs {}".format(
        len(batch_img), len(batch_preds)
    )
    assert batch_preds.ndim == 1, "{}".format(batch_preds.ndim)

    with_gt = batch_gt is not None
    if with_gt:
        assert isinstance(batch_gt, torch.Tensor)
        assert len(batch_preds) == len(batch_gt)
        assert batch_gt.ndim == 1, "{}".format(batch_gt.ndim)

    n_samples = batch_img.shape[0]
    h, w = batch_img.shape[2:]
    n_rows = 1
    out_image = np.zeros((h * n_rows, w * n_samples, 3), dtype="uint8")

    for col in range(n_samples):
        raw_img = batch_img[col]
        raw_pred = batch_preds[col]
        img = tensor_to_numpy(img_denormalize_fn(raw_img))

        # Caption reads "p=<pred>" or "p=<pred> | gt=<gt>"
        caption_parts = ["p={}".format(raw_pred.cpu().item())]
        if with_gt:
            caption_parts.append("gt={}".format(batch_gt[col].cpu().item()))
        target = " | ".join(caption_parts)

        left = col * w
        out_image[0:h, left : left + w, :] = render_datapoint(
            img, target, text_size=12
        )

    return out_image
def write_prediction_on_image(y_pred, x, filepath, palette=default_palette):
    """Blend a colorized prediction over an image and save it as a PNG.

    Args:
        y_pred (np.ndarray): 2-D array of predicted labels. It is cast to
            ``uint8`` before colorization, so labels are expected to fit in
            0..255.
        x (np.ndarray): 3-D image array handed to ``render_datapoint`` as the
            background.
        filepath (str): output path without extension; ``".png"`` is always
            appended.
        palette: palette passed to ``Image.putpalette``.

    Raises:
        AssertionError: if either array has the wrong type or rank.
    """
    assert isinstance(y_pred, np.ndarray) and y_pred.ndim == 2, \
        "{} and {}".format(type(y_pred), y_pred.shape if isinstance(y_pred, np.ndarray) else None)
    assert isinstance(x, np.ndarray) and x.ndim == 3, \
        "{} and {}".format(type(x), x.shape if isinstance(x, np.ndarray) else None)

    # putpalette only works on 8-bit images, so integer label maps of other
    # dtypes (e.g. int64 from argmax) must be cast first. This matches what
    # write_prediction_on_image2 does.
    pil_pred = Image.fromarray(y_pred.astype('uint8'))
    pil_pred.putpalette(palette)

    # NOTE(review): the result is assumed to be a PIL image since it is
    # saved with .save() -- confirm against render_datapoint.
    res = render_datapoint(x, pil_pred.convert('RGB'), blend_alpha=0.5)
    filepath = filepath + ".png"
    res.save(filepath)
def visualize(image_name, image_dir, mask_dir, pre_dir, save_name, dice=None):
    """Plot each image beside its ground-truth and predicted mask overlays.

    One figure row is drawn per entry of ``image_name`` with three columns:
    the original image, the ground-truth mask overlay and the predicted mask
    overlay. The figure is saved to ``save_name + '.png'`` and then shown.

    Args:
        image_name (list[str]): image identifiers. If the first entry contains
            ``'mask'``, every entry is cut at its first ``'_'``.
        image_dir: directory searched (via ``find_imagenames``) for images.
        mask_dir: directory searched for ground-truth masks.
        pre_dir: directory searched for predicted masks.
        save_name (str): output path without the ``.png`` extension.
        dice (dict, optional): maps an image's base name (file name up to its
            first ``'.'``) to a Dice score. When given, the score is appended
            to the title of the prediction panel; when omitted or empty, the
            panels are still drawn, just without a score.

    Raises:
        KeyError: if ``dice`` is given but lacks an entry for an image.
    """
    print("The {} visualization".format(save_name))
    fontsize = 8
    if 'mask' in image_name[0]:
        image_name = [i.split('_')[0] for i in image_name]
    length = len(image_name)
    images_path = find_imagenames(image_dir, image_name)
    masks_path = find_imagenames(mask_dir, image_name)
    pres_path = find_imagenames(pre_dir, image_name)

    # squeeze=False keeps ``ax`` two-dimensional even for a single row, so the
    # ax[i, j] indexing below also works when only one image is requested.
    _, ax = plt.subplots(length, 3, figsize=(10, 40), squeeze=False)

    for i in range(length):
        name = images_path[i].split('/')[-1].split('.')[0]
        original_image = cv2.imread(images_path[i])
        original_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)
        # Look the score up before plotting so a missing key fails early.
        dice_score = dice[name] if dice else None

        # shape[:2] is (rows, cols); the title keeps the original
        # "<shape[0]>*<shape[1]>" order.
        dim0, dim1 = original_image.shape[:2]
        ax[i, 0].imshow(original_image)
        ax[i, 0].set_title(
            'Original image ' + name + ' ' + str(dim0) + '*' + str(dim1),
            fontsize=fontsize)

        # NOTE(review): the mask is passed as the first argument and the image
        # as the second -- confirm this order against render_datapoint.
        original_mask = cv2.imread(masks_path[i])
        rimg = render_datapoint(original_mask, original_image, blend_alpha=0.8)
        ax[i, 1].imshow(rimg)
        ax[i, 1].set_title('Original mask', fontsize=fontsize)

        pre_mask = cv2.imread(pres_path[i])
        rimg2 = render_datapoint(pre_mask, original_image, blend_alpha=0.8)
        ax[i, 2].imshow(rimg2)
        pred_title = 'pred mask'
        if dice_score is not None:
            pred_title += ' dice:' + str(round(dice_score, 3))
        ax[i, 2].set_title(pred_title, fontsize=fontsize)

    plt.savefig(save_name + '.png')
    plt.show()