def heatmaps_to_landmarks(self, hms):
    """Convert a batch of landmark heatmaps to (x, y) landmark coordinates.

    Two input layouts are supported:
      * > 3 channels: one heatmap per landmark; the landmark position is the
        argmax of its heatmap.
      * exactly 3 channels: a single RGB-coded heatmap where each landmark is
        identified by a reference color from ``utils.nn.lmcolors``; a score
        plane is built per color channel and their product is maximized.

    Coordinates are finally rescaled from heatmap resolution to
    ``self.input_size``.

    :param hms: batch of heatmaps, shape (N, C, H, W). Assumed to be torch
        tensors or arrays convertible via ``to_numpy`` — values in [0, 1] for
        the 3-channel path (they are scaled by 255 before color matching);
        TODO confirm against callers.
    :return: float array of shape (N, num_landmarks, 2) with (x, y) positions
        in input-image coordinates.
    """
    lms = np.zeros((len(hms), self.num_landmarks, 2), dtype=int)
    if hms.shape[1] > 3:
        # One heatmap per landmark: take the argmax location of each map.
        for i in range(len(hms)):
            heatmaps = to_numpy(hms[i])
            for l in range(len(heatmaps)):
                hm = heatmaps[self.landmark_id_to_heatmap_id(l)]
                # unravel_index returns (row, col); [::-1] flips to (x, y)
                lms[i, l, :] = np.unravel_index(np.argmax(hm, axis=None),
                                                hm.shape)[::-1]
    elif hms.shape[1] == 3:
        hms = to_numpy(hms)

        def get_score_plane(h, lm_id, cn):
            # Keep only pixels whose channel value is close to the landmark's
            # reference color (within [-2, +5]); everything else is zeroed.
            # NOTE: mutates h[cn] in place.
            v = utils.nn.lmcolors[lm_id, cn]
            hcn = h[cn]
            hcn[hcn < v - 2] = 0
            hcn[hcn > v + 5] = 0
            return hcn

        hms *= 255
        for i in range(len(hms)):
            hm = hms[i]
            for l in landmarks.config.LANDMARKS:
                # Product of the three per-channel score planes: high only
                # where all channels match the landmark's color.
                lm_score_map = get_score_plane(hm, l, 0) * get_score_plane(
                    hm, l, 1) * get_score_plane(hm, l, 2)
                lms[i, l, :] = np.unravel_index(
                    np.argmax(lm_score_map, axis=None),
                    lm_score_map.shape)[::-1]
    # Rescale from heatmap resolution to network input resolution.
    lm_scale = lmcfg.HEATMAP_SIZE / self.input_size
    return lms / lm_scale
def _predict_center_crop(net, image, crop_size=544, gpu=True):
    """Segment only the centered square crop of an image.

    The network is run on a single ``crop_size`` x ``crop_size`` crop taken
    from the image center; the predicted probabilities are pasted back into a
    zero-initialized full-size probability map.

    :param net: segmentation network with a ``forward`` method.
    :param image: HWC image array.
    :param crop_size: side length of the centered crop.
    :param gpu: if True, move the input tensor to CUDA.
    :return: tuple ``(image, probs)`` where ``probs`` is an (H, W) numpy map,
        non-zero only inside the central crop.
    """
    h, w, c = image.shape
    image_probs = torch.zeros((h, w))
    # Top-left corner of the centered crop.
    left = (w - crop_size) // 2
    top = (h - crop_size) // 2
    center_crop = image[top:top + crop_size, left:left + crop_size]
    net_input = _crop_to_tensor(image=center_crop)['image']
    if gpu:
        net_input = net_input.cuda()
    with torch.no_grad():
        t = time.time()
        crop_probs = net.forward(net_input.unsqueeze(0))
        print(f'time forward: {int(1000 * (time.time() - t))}ms')
    show = False
    if show:
        # Debug view: input crop next to its predicted probability map.
        disp_crop = vis.to_disp_image(net_input, denorm=True)
        fig, ax = plt.subplots(1, 2, sharex=True, sharey=True)
        ax[0].imshow(disp_crop)
        ax[1].imshow(to_numpy(crop_probs[0, 0]),
                     cmap=plt.cm.viridis, vmin=0, vmax=1)
        plt.tight_layout()
        plt.show()
    image_probs[top:top + crop_size, left:left + crop_size] = \
        crop_probs.squeeze().squeeze()
    return image, to_numpy(image_probs)
def predict_sequential():
    """Tile the padded image into overlapping crops and predict one at a time.

    Relies on enclosing-scope variables: ``h_pad``/``w_pad`` (padded size),
    ``npx``/``npy`` (tile counts), ``inner_size`` (stride), ``s`` (crop size),
    ``d`` (border margin discarded from each prediction), ``image_pad`` and
    ``net``. Only the inner region of each crop prediction is written into
    the output map, so crop borders never contribute.

    :return: (h_pad, w_pad) numpy probability map.
    """
    image_probs = torch.zeros((h_pad, w_pad))
    for ix in range(npx):
        x = ix * inner_size
        for iy in range(npy):
            y = iy * inner_size
            tile = image_pad[y:y + s, x:x + s]
            net_input = _crop_to_tensor(image=tile)['image'].cuda()
            with torch.no_grad():
                tile_probs = net.forward(net_input.unsqueeze(0))
            show = False
            if show:
                # Debug view: tile next to its predicted probability map.
                disp_crop = vis.to_disp_image(net_input, denorm=True)
                fig, ax = plt.subplots(1, 2, sharex=True, sharey=True)
                ax[0].imshow(disp_crop)
                ax[1].imshow(to_numpy(tile_probs[0, 0]),
                             cmap=plt.cm.viridis, vmin=0, vmax=1)
                plt.tight_layout()
                plt.show()
            # Keep only the interior of the prediction (drop a d-wide border).
            inner = tile_probs.squeeze().squeeze()[d:d + inner_size,
                                                   d:d + inner_size]
            image_probs[y + d:y + d + inner_size,
                        x + d:x + d + inner_size] = inner
    return to_numpy(image_probs)
def show_segmentation_results(orig_image, recon, preds, gt_mask=None,
                              foreground_mask=None, threshold=0.5):
    """Show results for one image.

    Displays a 2x3 grid (original, GT mask, diff map / reconstruction,
    probabilities, thresholded prediction) and also writes the probability,
    diff, and original images to ``./outputs/results/``.

    NOTE(review): ``modelname`` and ``idx`` are read from an enclosing/module
    scope — presumably set by the evaluation loop; verify against callers.

    :param orig_image: HWC image, written out as uint8 RGB.
    :param recon: reconstructed image (squeezed before display).
    :param preds: predicted probabilities; tensor or array, squeezed to 2-D.
    :param gt_mask: optional ground-truth mask; zeros are used when absent.
    :param foreground_mask: optional evaluation mask; ones when absent.
    :param threshold: probability cutoff for the binary prediction panel.
    """
    fig, ax = plt.subplots(2, 3, sharex=True, sharey=True)
    if torch.is_tensor(preds):
        preds = preds.squeeze().squeeze()
    else:
        preds = preds.squeeze()
    if foreground_mask is None:
        foreground_mask = np.ones_like(preds).astype(np.uint8)
    if gt_mask is not None:
        diff_map, _ = difference_map(gt_mask, preds, foreground_mask)
    else:
        diff_map = np.zeros_like(preds)
    if gt_mask is None:
        gt_mask = np.zeros_like(preds).astype(np.uint8)
    pred_mask = to_numpy((preds > threshold).squeeze())
    gt_mask = to_numpy(gt_mask)
    # Write probability map, diff map, and original image to disk.
    imgfname = f'./outputs/results/{modelname}/hrf_{idx + 1:02d}_probs.png'
    io_utils.makedirs(imgfname)
    cv2.imwrite(imgfname, (preds * 255).astype(np.uint8))
    imgfname = f'./outputs/results/{modelname}/hrf_{idx + 1:02d}_diff.png'
    cv2.imwrite(imgfname,
                cv2.cvtColor((diff_map).astype(np.uint8), cv2.COLOR_RGB2BGR))
    imgfname = f'./outputs/results/{modelname}/hrf_{idx + 1:02d}_orig.png'
    cv2.imwrite(imgfname,
                cv2.cvtColor((orig_image).astype(np.uint8), cv2.COLOR_RGB2BGR))
    ax[0, 0].imshow(orig_image)
    ax[0, 1].imshow(gt_mask)
    ax[0, 2].imshow(diff_map.astype(np.uint8))
    ax[1, 0].imshow(vis.to_disp_image(recon.squeeze(), denorm=True))
    ax[1, 1].imshow(preds, vmax=1)
    ax[1, 2].imshow(pred_mask.astype(np.uint8))
    plt.tight_layout()
def batch_predict():
    """Tile the padded image and predict all crops in a single forward pass.

    Batched counterpart of the sequential tiling predictor: every crop is
    collected first, stacked into one batch, and run through ``net`` once.
    Uses enclosing-scope variables ``h_pad``/``w_pad``, ``npx``/``npy``,
    ``inner_size``, ``s``, ``d``, ``image_pad`` and ``net``.

    :return: (h_pad, w_pad) numpy probability map assembled from the inner
        regions of all crop predictions.
    """
    image_probs = np.zeros((h_pad, w_pad))
    # Crop origins in the same (ix-outer, iy-inner) order as the crops below.
    origins = [(ix * inner_size, iy * inner_size)
               for ix in range(npx) for iy in range(npy)]
    crops = [_crop_to_tensor(image=image_pad[y:y + s, x:x + s])['image']
             for x, y in origins]
    with torch.no_grad():
        crop_probs = net.forward(torch.stack(crops).cuda())
    crop_probs = to_numpy(crop_probs)
    # Paste the interior of each prediction back at its origin.
    for crop_id, (x, y) in enumerate(origins):
        image_probs[y + d:y + d + inner_size, x + d:x + d + inner_size] = \
            crop_probs[crop_id, 0, d:d + inner_size, d:d + inner_size]
    return image_probs
def overlay_vessels_heatmap(imgs, pred_vessel_hm):
    """Overlay each predicted vessel heatmap (channel 0) onto its image.

    :param imgs: sequence of display images, one per heatmap.
    :param pred_vessel_hm: batch of heatmaps, shape (N, C, H, W); channel 0
        is used for the overlay.
    :return: list of overlay images, fully opaque heatmaps (alpha 1.0).
    """
    heatmaps = to_numpy(pred_vessel_hm)
    return [vis.overlay_heatmap(imgs[k], heatmaps[k, 0], 1.0)
            for k in range(len(heatmaps))]
def calc_landmark_recon_error(X, X_recon, lms, return_maps=False,
                              reduction='mean'):
    """Mean absolute reconstruction error measured only around landmarks.

    A binary mask of disks (radius = 5% of image size) is drawn at every
    landmark; the per-pixel L1 error between ``X`` and ``X_recon`` (scaled to
    [0, 255]) is averaged over the masked area only.

    :param X: input images, shape (N, C, H, W).
    :param X_recon: reconstructed images, same shape as ``X``.
    :param lms: per-image landmark coordinates, (N, L, 2) as (x, y).
    :param return_maps: if True, also return the masked error maps.
    :param reduction: 'mean' for a single scalar over the batch, 'none' for
        one error value per image.
    :return: scalar or (N,) array, optionally with the masked error maps.
    """
    assert len(X.shape) == 4
    assert reduction in ['mean', 'none']
    X = to_numpy(X)
    X_recon = to_numpy(X_recon)
    input_size = X.shape[-1]
    disk_radius = input_size * 0.05
    mask = np.zeros((X.shape[0], X.shape[2], X.shape[3]), dtype=np.float32)
    # Stamp a filled disk at each landmark of each image.
    for img_id in range(len(mask)):
        for lm in lms[img_id]:
            cv2.circle(mask[img_id], (int(lm[0]), int(lm[1])),
                       radius=int(disk_radius), color=1, thickness=-1)
    # Channel-averaged absolute error, scaled to pixel units.
    err_maps = np.abs(X - X_recon).mean(axis=1) * 255.0
    masked_err_maps = err_maps * mask
    debug = False
    if debug:
        fig, ax = plt.subplots(1, 3)
        ax[0].imshow(
            vis.to_disp_image(
                (X * mask[:, np.newaxis, :, :].repeat(3, axis=1))[0],
                denorm=True))
        ax[1].imshow(
            vis.to_disp_image(
                (X_recon * mask[:, np.newaxis, :, :].repeat(3, axis=1))[0],
                denorm=True))
        ax[2].imshow(masked_err_maps[0])
        plt.show()
    # Normalize by masked area times 3 color channels.
    if reduction == 'mean':
        err = masked_err_maps.sum() / (mask.sum() * 3)
    else:
        per_image_area = mask.reshape(len(mask), -1).sum(axis=1) * 3
        err = masked_err_maps.sum(axis=2).sum(axis=1) / per_image_area
    if return_maps:
        return err, masked_err_maps
    return err
def smooth_heatmaps(hms):
    """Box-blur every heatmap channel with a 9x9 kernel (zero-padded border).

    :param hms: heatmaps of shape (N, L, H, W); converted via ``to_numpy``,
        so a numpy input is smoothed in place.
    :return: the smoothed numpy heatmap array.
    """
    assert (len(hms.shape) == 4)
    hms = to_numpy(hms)
    num_maps, num_channels = hms.shape[0], hms.shape[1]
    for img_idx in range(num_maps):
        for ch in range(num_channels):
            hms[img_idx, ch] = cv2.blur(hms[img_idx, ch], (9, 9),
                                        borderType=cv2.BORDER_CONSTANT)
    return hms
def add_error_to_images(images, errors, loc='bl', size=0.65, vmin=0.,
                        vmax=30.0, thickness=1, format_string='{:.1f}',
                        colors=None):
    """Print one error value onto each image at a fixed location.

    :param images: batch of images; converted to display images first.
    :param errors: one scalar per image.
    :param loc: location code understood by ``get_pos_in_image`` (e.g. 'bl').
    :param size: font scale; also passed to the position helper.
    :param vmin, vmax: color-map range used when ``colors`` is not given.
    :param thickness: text stroke thickness.
    :param format_string: format applied to each error value.
    :param colors: optional per-image text colors; auto-generated from the
        jet colormap when None (scaled to 0-255 for uint8 images).
    :return: list of annotated display images.
    """
    new_images = to_disp_images(images)
    if colors is None:
        colors = color_map(to_numpy(errors), cmap=plt.cm.jet,
                           vmin=vmin, vmax=vmax)
        if images[0].dtype == np.uint8:
            colors *= 255
    for disp, err, color in zip(new_images, errors, colors):
        anchor = get_pos_in_image(loc, size, disp.shape)
        cv2.putText(disp, format_string.format(err), anchor,
                    cv2.FONT_HERSHEY_DUPLEX, size, color, thickness,
                    cv2.LINE_AA)
    return new_images
def add_confidences(disp_X_recon, lmids, loc):
    """Annotate images with the mean confidence of the selected landmarks.

    Reads ``lm_confs`` from the enclosing scope; color encodes 1 - confidence
    on the jet colormap (range 0.0-0.4).

    :param disp_X_recon: display images to annotate.
    :param lmids: landmark indices whose confidences are averaged.
    :param loc: text location code for ``vis.add_error_to_images``.
    :return: annotated images.
    """
    mean_confs = lm_confs[:, lmids].mean(axis=1)
    conf_colors = vis.color_map(to_numpy(1 - mean_confs), cmap=plt.cm.jet,
                                vmin=0.0, vmax=0.4)
    return vis.add_error_to_images(disp_X_recon, mean_confs, loc=loc,
                                   format_string='{:>4.2f}',
                                   colors=conf_colors)
def draw_z(z_vecs):
    """Render latent vectors as vertical color strips arranged in a grid.

    Each vector becomes a (len, 10, 3) strip: the vector is resized to strip
    shape with nearest-neighbor interpolation and color-mapped over [-1, 1].
    Strips are concatenated via ``make_grid`` and transposed so levels run
    horizontally.

    :param z_vecs: iterable of latent vectors (converted via ``to_numpy``).
    :return: HWC image of all strips.
    """
    strip_width = 10
    row_scale = 1
    strips = []
    for level, vec in enumerate(to_numpy(z_vecs)):
        height = int(row_scale * len(vec))
        # NOTE: per-level vmin is computed but color_map below uses a fixed
        # [-1, 1] range — kept as-is to preserve behavior.
        vmin = 0 if level == 0 else -1
        resized = cv2.resize(vec.reshape(-1, 1),
                             dsize=(strip_width, height),
                             interpolation=cv2.INTER_NEAREST)
        strip = np.zeros((height, strip_width, 3))
        strip[:height, :] = color_map(resized, vmin=-1.0, vmax=1.0)
        strips.append(strip)
    grid = make_grid(strips, nCols=len(z_vecs), padsize=1, padval=0)
    return grid.transpose((1, 0, 2))
def loss_struct(X, X_recon, torch_ssim, calc_error_maps=False,
                reduction='mean'):
    """Structural reconstruction loss: per-image 1 - SSIM.

    :param X: input images, (N, C, H, W).
    :param X_recon: reconstructed images, same shape.
    :param torch_ssim: SSIM module; called per image pair. After each call its
        ``cs_map`` attribute is read for the contrast/structure error map —
        presumably set as a side effect of the forward pass; verify against
        the SSIM implementation.
    :param calc_error_maps: if True, also collect 1 - cs_map per image.
    :param reduction: passed to ``__reduce`` ('mean' or similar).
    :return: tuple (loss, error_maps) — error_maps is a stacked numpy array
        when requested, else None.
    """
    cs_error_maps = []
    nimgs = len(X)
    # Differentiable accumulator; indexed in-place writes keep the graph.
    errs = torch.zeros(nimgs, requires_grad=True).cuda()
    for i in range(nimgs):
        errs[i] = 1.0 - torch_ssim(X[i].unsqueeze(0), X_recon[i].unsqueeze(0))
        if calc_error_maps:
            cs_error_maps.append(1.0 - to_numpy(torch_ssim.cs_map))
    loss = __reduce(errs, reduction)
    if calc_error_maps:
        return loss, np.vstack(cs_error_maps)
    else:
        return loss, None
def heatmaps_to_landmarks(hms, target_size):
    """Extract (x, y) landmark positions from per-landmark heatmaps.

    Each landmark position is the argmax of its heatmap; coordinates are then
    rescaled from heatmap resolution to ``target_size``. Inputs with three or
    fewer channels yield all-zero landmarks (no color-coded path here).

    :param hms: heatmaps of shape (N, L, H, W).
    :param target_size: output coordinate space (side length in pixels).
    :return: float array (N, L, 2) of (x, y) positions at target resolution.
    """
    def landmark_id_to_heatmap_id(lm_id):
        # Identity mapping kept as an explicit lookup table.
        return {lm: i for i, lm in enumerate(range(num_landmarks))}[lm_id]

    assert len(hms.shape) == 4
    num_images, num_landmarks = hms.shape[0], hms.shape[1]
    heatmap_size = hms.shape[-1]
    lms = np.zeros((num_images, num_landmarks, 2), dtype=int)
    if hms.shape[1] > 3:
        for img_idx in range(len(hms)):
            heatmaps = to_numpy(hms[img_idx])
            for lm_id in range(len(heatmaps)):
                hm = heatmaps[landmark_id_to_heatmap_id(lm_id)]
                # unravel_index gives (row, col); store as (x, y).
                row, col = np.unravel_index(np.argmax(hm, axis=None), hm.shape)
                lms[img_idx, lm_id, :] = (col, row)
    lm_scale = heatmap_size / target_size
    return lms / lm_scale
def visualize_images(X, X_lm_hm, landmarks=None, show_recon=True,
                     show_landmarks=True, show_heatmaps=False,
                     draw_wireframe=False, smoothing_level=2,
                     heatmap_opacity=0.8, f=1):
    """Build display images with optional heatmap and landmark overlays.

    :param X: image batch to display (denormalized when shown).
    :param X_lm_hm: landmark heatmaps, smoothed up to twice depending on
        ``smoothing_level``; may be None.
    :param landmarks: landmark coordinates to draw (cyan), if enabled.
    :param show_recon: if False, draw overlays on black canvases at full
        opacity instead of on the images.
    :param show_heatmaps: overlay the single-channel heatmaps on the images.
    :param draw_wireframe: connect landmarks with wireframe lines.
    :param smoothing_level: 0 = none, 1 = one blur pass, >=2 = two passes.
    :param heatmap_opacity: overlay alpha (forced to 1 on black canvases).
    :param f: resize factor for the heatmap overlays.
    :return: list of display images.
    """
    if show_recon:
        disp_X = vis.to_disp_images(X, denorm=True)
    else:
        disp_X = vis.to_disp_images(torch.zeros_like(X), denorm=False)
        heatmap_opacity = 1
    if X_lm_hm is not None:
        # Apply at most two smoothing passes.
        for _ in range(min(smoothing_level, 2)):
            X_lm_hm = smooth_heatmaps(X_lm_hm)
    if show_heatmaps:
        pred_heatmaps = to_single_channel_heatmap(to_numpy(X_lm_hm))
        pred_heatmaps = [
            cv2.resize(hm, None, fx=f, fy=f, interpolation=cv2.INTER_CUBIC)
            for hm in pred_heatmaps
        ]
        disp_X = [
            vis.overlay_heatmap(disp_X[k], pred_heatmaps[k], heatmap_opacity)
            for k in range(len(pred_heatmaps))
        ]
    if show_landmarks and landmarks is not None:
        disp_X = vis.add_landmarks_to_images(disp_X, landmarks,
                                             color=(0, 255, 255),
                                             draw_wireframe=draw_wireframe)
    return disp_X
def draw_results(X_resized, X_recon, levels_z=None, landmarks=None,
                 landmarks_pred=None, cs_errs=None, ncols=15, fx=0.5, fy=0.5,
                 additional_status_text=''):
    """Assemble a result sheet: inputs, reconstructions, and a status bar.

    Builds two image grids (inputs and reconstructions, optionally annotated
    with landmarks and error values), optionally a latent-vector strip, and a
    bottom status bar with aggregate errors; all stacked vertically.

    :param X_resized: input images (tensor batch).
    :param X_recon: reconstructed images, same batch size.
    :param levels_z: latent codes, only used when VIS_HIDDEN is enabled.
    :param landmarks: GT landmarks; note they are discarded below because
        ``clean_images`` is hard-coded to True.
    :param landmarks_pred: predicted landmarks drawn in red.
    :param cs_errs: per-image contrast/structure errors, may be None.
    :param ncols: max grid columns.
    :param fx, fy: resize factors for the assembled grids.
    :param additional_status_text: appended to the status bar.
    :return: single stacked result image (numpy).
    """
    clean_images = True
    if clean_images:
        landmarks = None
    nimgs = len(X_resized)
    ncols = min(ncols, nimgs)
    img_size = X_recon.shape[-1]
    disp_X = vis.to_disp_images(X_resized, denorm=True)
    disp_X_recon = vis.to_disp_images(X_recon, denorm=True)
    # reconstruction error in pixels
    l1_dists = 255.0 * to_numpy(
        (X_resized - X_recon).abs().reshape(len(disp_X), -1).mean(dim=1))
    # SSIM errors
    ssim = np.zeros(nimgs)
    for i in range(nimgs):
        ssim[i] = compare_ssim(disp_X[i],
                               disp_X_recon[i],
                               data_range=1.0,
                               multichannel=True)
    landmarks = to_numpy(landmarks)
    cs_errs = to_numpy(cs_errs)
    text_size = img_size / 256
    text_thickness = 2
    #
    # Visualise resized input images and reconstructed images
    #
    if landmarks is not None:
        disp_X = vis.add_landmarks_to_images(
            disp_X, landmarks, draw_wireframe=False,
            landmarks_to_draw=lmcfg.LANDMARKS_19)
        disp_X_recon = vis.add_landmarks_to_images(
            disp_X_recon, landmarks, draw_wireframe=False,
            landmarks_to_draw=lmcfg.LANDMARKS_19)
    if landmarks_pred is not None:
        disp_X = vis.add_landmarks_to_images(disp_X, landmarks_pred,
                                             color=(1, 0, 0))
        disp_X_recon = vis.add_landmarks_to_images(disp_X_recon,
                                                   landmarks_pred,
                                                   color=(1, 0, 0))
    # Per-image error annotations (suppressed while clean_images is True).
    if not clean_images:
        disp_X_recon = vis.add_error_to_images(disp_X_recon, l1_dists,
                                               format_string='{:.1f}',
                                               size=text_size,
                                               thickness=text_thickness)
        disp_X_recon = vis.add_error_to_images(disp_X_recon, 1 - ssim,
                                               loc='bl-1',
                                               format_string='{:>4.2f}',
                                               vmax=0.8, vmin=0.2,
                                               size=text_size,
                                               thickness=text_thickness)
        if cs_errs is not None:
            disp_X_recon = vis.add_error_to_images(disp_X_recon, cs_errs,
                                                   loc='bl-2',
                                                   format_string='{:>4.2f}',
                                                   vmax=0.0, vmin=0.4,
                                                   size=text_size,
                                                   thickness=text_thickness)
    # landmark errors (best-effort: any failure leaves lm_errs at zeros)
    lm_errs = np.zeros(1)
    if landmarks is not None:
        try:
            from landmarks import lmutils
            lm_errs = lmutils.calc_landmark_nme_per_img(
                landmarks, landmarks_pred)
            disp_X_recon = vis.add_error_to_images(disp_X_recon, lm_errs,
                                                   loc='br',
                                                   format_string='{:>5.2f}',
                                                   vmax=15,
                                                   size=img_size / 256,
                                                   thickness=2)
        except:
            pass
    img_input = vis.make_grid(disp_X, nCols=ncols, normalize=False)
    img_recon = vis.make_grid(disp_X_recon, nCols=ncols, normalize=False)
    img_input = cv2.resize(img_input, None, fx=fx, fy=fy,
                           interpolation=cv2.INTER_CUBIC)
    img_recon = cv2.resize(img_recon, None, fx=fx, fy=fy,
                           interpolation=cv2.INTER_CUBIC)
    img_stack = [img_input, img_recon]
    #
    # Visualise hidden layers
    #
    VIS_HIDDEN = False
    if VIS_HIDDEN:
        img_z = vis.draw_z_vecs(levels_z, size=(img_size, 30), ncols=ncols)
        img_z = cv2.resize(img_z,
                           dsize=(img_input.shape[1], img_z.shape[0]),
                           interpolation=cv2.INTER_NEAREST)
        img_stack.append(img_z)
    cs_errs_mean = np.mean(cs_errs) if cs_errs is not None else np.nan
    # NOTE(review): the 'ssim' slot receives cs_errs_mean and the following
    # slot receives 1 - ssim.mean() — the argument order looks swapped
    # relative to the labels; confirm intended mapping before relying on it.
    status_bar_text = ("l1 recon err: {:.2f}px "
                       "ssim: {:.3f}({:.3f}) "
                       "lms err: {:2} {}").format(l1_dists.mean(),
                                                  cs_errs_mean,
                                                  1 - ssim.mean(),
                                                  lm_errs.mean(),
                                                  additional_status_text)
    img_status_bar = vis.draw_status_bar(status_bar_text,
                                         status_bar_width=img_input.shape[1],
                                         status_bar_height=30,
                                         dtype=img_input.dtype)
    img_stack.append(img_status_bar)
    return np.vstack(img_stack)
def visualize_batch(self, batch, X_recon, ssim_maps, nimgs=8, ds=None,
                    wait=0):
    """Display a batch of inputs and reconstructions with error annotations.

    Temporarily switches the D/Q/P sub-networks to eval mode (restored at the
    end), builds an input-image row and a reconstruction row annotated with
    L1, GAN-score and SSIM errors, and shows the result in an OpenCV window.
    When ``ssim_maps`` is given, a separate window with color-mapped SSIM
    error maps is shown as well.

    :param batch: batch object with ``images``, optional ``target_images``.
    :param X_recon: reconstructed images for the batch.
    :param ssim_maps: optional per-image SSIM error maps.
    :param nimgs: max number of images to display.
    :param ds: dataset; its class name is appended to the window title.
    :param wait: ``cv2.waitKey`` delay in ms (0 blocks until a key press).
    """
    nimgs = min(nimgs, len(batch))
    # Remember training flags so they can be restored after visualization.
    train_state_D = self.saae.D.training
    train_state_Q = self.saae.Q.training
    train_state_P = self.saae.P.training
    self.saae.D.eval()
    self.saae.Q.eval()
    self.saae.P.eval()
    loc_err_gan = 'tr'
    text_size_errors = 0.65
    input_images = vis.to_disp_images(batch.images[:nimgs], denorm=True)
    target_images = batch.target_images if batch.target_images is not None \
        else batch.images
    disp_images = vis.to_disp_images(target_images[:nimgs], denorm=True)
    # draw GAN score of the input images (1 - D output)
    if self.args.with_gan:
        with torch.no_grad():
            err_gan_inputs = self.saae.D(batch.images[:nimgs])
        disp_images = vis.add_error_to_images(disp_images,
                                              errors=1 - err_gan_inputs,
                                              loc=loc_err_gan,
                                              format_string='{:>5.2f}',
                                              vmax=1.0)
    rows = [vis.make_grid(disp_images, nCols=nimgs, normalize=False)]
    recon_images = vis.to_disp_images(X_recon[:nimgs], denorm=True)
    disp_X_recon = recon_images.copy()
    print_stats = True
    if print_stats:
        # Per-image L1 reconstruction error in pixel units.
        X_recon_errs = 255.0 * torch.abs(batch.images - X_recon).reshape(
            len(batch.images), -1).mean(dim=1)
        disp_X_recon = vis.add_error_to_images(disp_X_recon[:nimgs],
                                               errors=X_recon_errs,
                                               size=text_size_errors,
                                               format_string='{:>4.1f}')
        # GAN score of the reconstructions.
        if self.args.with_gan:
            with torch.no_grad():
                err_gan = self.saae.D(X_recon[:nimgs])
            disp_X_recon = vis.add_error_to_images(
                disp_X_recon, errors=1 - err_gan, loc=loc_err_gan,
                format_string='{:>5.2f}', vmax=1.0)
        # SSIM error (1 - SSIM) per image.
        ssim = np.zeros(nimgs)
        for i in range(nimgs):
            data_range = 255.0 if input_images[0].dtype == np.uint8 else 1.0
            ssim[i] = compare_ssim(input_images[i],
                                   recon_images[i],
                                   data_range=data_range,
                                   multichannel=True)
        disp_X_recon = vis.add_error_to_images(disp_X_recon, 1 - ssim,
                                               loc='bl-1',
                                               size=text_size_errors,
                                               format_string='{:>4.2f}',
                                               vmin=0.2, vmax=0.8)
        if ssim_maps is not None:
            disp_X_recon = vis.add_error_to_images(
                disp_X_recon,
                ssim_maps.reshape(len(ssim_maps), -1).mean(axis=1),
                size=text_size_errors, loc='bl-2', format_string='{:>4.2f}',
                vmin=0.0, vmax=0.4)
    rows.append(vis.make_grid(disp_X_recon, nCols=nimgs))
    # Show color-mapped SSIM error maps in a separate window.
    if ssim_maps is not None:
        disp_ssim_maps = to_numpy(
            nn.denormalized(ssim_maps)[:nimgs].transpose(0, 2, 3, 1))
        for i in range(len(disp_ssim_maps)):
            disp_ssim_maps[i] = vis.color_map(
                disp_ssim_maps[i].mean(axis=2), vmin=0.0, vmax=2.0)
        grid_ssim_maps = vis.make_grid(disp_ssim_maps, nCols=nimgs)
        cv2.imshow('ssim errors',
                   cv2.cvtColor(grid_ssim_maps, cv2.COLOR_RGB2BGR))
    # Restore the original train/eval state of the sub-networks.
    self.saae.D.train(train_state_D)
    self.saae.Q.train(train_state_Q)
    self.saae.P.train(train_state_P)
    f = 1
    disp_rows = vis.make_grid(rows, nCols=1, normalize=False, fx=f, fy=f)
    wnd_title = 'recon errors '
    if ds is not None:
        wnd_title += ds.__class__.__name__
    cv2.imshow(wnd_title, cv2.cvtColor(disp_rows, cv2.COLOR_RGB2BGR))
    cv2.waitKey(wait)
def z_vecs(self):
    """Return the current latent code as a one-element list of numpy arrays."""
    z_np = to_numpy(self.z)
    return [z_np]
def detect_landmarks(self, X):
    """Run a forward pass and extract smoothed heatmaps plus landmark coords.

    :param X: input image batch.
    :return: tuple (X_recon, lm_preds, X_lm_hm) — reconstruction, numpy
        landmark coordinates, and the smoothed landmark heatmaps.
    """
    X_recon = self.forward(X)
    heatmaps = self.LMH(self.P)
    heatmaps = landmarks.lmutils.smooth_heatmaps(heatmaps)
    lm_preds = to_numpy(self.heatmaps_to_landmarks(heatmaps))
    return X_recon, lm_preds, heatmaps
def calculate_metrics(preds, gt_vessels, fov_masks=None, full_eval=False,
                      verbose=False):
    """Compute vessel-segmentation metrics from soft predictions.

    Always computes the area under the precision-recall curve ('PR'). With
    ``full_eval=True`` it additionally reports the best-F1 threshold, ROC
    AUC, and accuracy/sensitivity/specificity/F1 at both the Otsu threshold
    and a fixed threshold of 0.5.

    Bug fix: previously, ``verbose=True`` together with ``full_eval=False``
    raised a NameError because the printed values (best_f1, roc, se, ...) are
    only computed inside the ``full_eval`` branch. Those prints are now
    guarded accordingly.

    :param preds: predicted probability maps (array-like or tensors).
    :param gt_vessels: ground-truth masks; pixels >= 1 count as vessel.
    :param fov_masks: optional field-of-view masks restricting evaluation.
    :param full_eval: enable the extended metric set.
    :param verbose: print a metric summary.
    :return: dict of metric name -> value.
    """
    assert len(preds) == len(gt_vessels)
    if not isinstance(preds, np.ndarray):
        preds = to_numpy(preds)
    if not isinstance(gt_vessels, np.ndarray):
        gt_vessels = to_numpy(gt_vessels)
    assert isinstance(gt_vessels, np.ndarray)
    # Promote single images to batches of one.
    if len(gt_vessels.shape) == 2:
        gt_vessels, preds = gt_vessels[np.newaxis], preds[np.newaxis]
    if fov_masks is not None:
        if not isinstance(fov_masks, np.ndarray):
            fov_masks = np.array(fov_masks)
        if len(fov_masks.shape) == 2:
            fov_masks = fov_masks[np.newaxis]
        gt_vessels_in_mask, pred_vessels_in_mask = pixel_values_in_mask(
            gt_vessels, preds, fov_masks)
    else:
        gt_vessels_in_mask, pred_vessels_in_mask = gt_vessels, preds
    y_true = to_numpy(gt_vessels_in_mask).ravel() >= 1
    y_score = to_numpy(pred_vessels_in_mask).ravel()
    precision, recall, thresholds = precision_recall_curve(y_true, y_score)
    # Reverse so the arrays are increasing (avoids a negative AUC from
    # np.trapz); clearer than the previous np.fliplr([...])[0] idiom.
    precision = precision[::-1]
    recall = recall[::-1]
    thresholds = thresholds[::-1]
    AUC_prec_rec = np.trapz(precision, recall)
    average_precision = AUC_prec_rec
    results = {}
    results['PR'] = average_precision
    if full_eval:
        # Best F1 over the PR curve.
        best_f1, best_f1_th = best_f1_threshold(precision, recall, thresholds)
        results['F1'] = best_f1
        results['F1_th'] = best_f1_th
        # ROC AUC.
        fpr, tpr, _ = roc_curve(y_true, y_score)
        roc = auc(fpr, tpr)
        results['ROC'] = roc
        # Binary metrics at the Otsu threshold.
        otsu_threshold = filters.threshold_otsu(pred_vessels_in_mask)
        y_pred_bin = pred_vessels_in_mask >= otsu_threshold
        acc, se, sp, f1 = misc_measures_evaluation(y_true, y_pred_bin)
        results['otsu_th'] = otsu_threshold
        results['otsu_SE'] = se
        results['otsu_SP'] = sp
        results['otsu_ACC'] = acc
        results['otsu_F1'] = f1
        # Binary metrics at a fixed threshold of 0.5.
        fixed_threshold = 0.5
        y_pred_bin = pred_vessels_in_mask >= fixed_threshold
        acc, se, sp, f1 = misc_measures_evaluation(y_true, y_pred_bin)
        results['th_SE'] = se
        results['th_SP'] = sp
        results['th_ACC'] = acc
        results['th_F1'] = f1
    if verbose:
        # Only the extended metrics exist when full_eval was requested.
        if full_eval:
            print(f"F1 score : {best_f1:.4f} (th={best_f1_th:.3f})")
            print(f"F1 score : {f1:.4f} (th={fixed_threshold:.3f})")
            print(
                f"SE/SP/ACC: {se:.4f}, {sp:.4f}, {acc:.4f} (th={fixed_threshold:.3f})"
            )
        print('AUC PR: {0:0.4f}'.format(average_precision))
        if full_eval:
            print('AUC ROC: {0:0.4f}'.format(roc))
    return results
def visualize_batch_CVPR(images, landmarks, X_recon, X_lm_hm, lm_preds,
                         show_recon=True, lm_heatmaps=None, ds=None, wait=0,
                         horizontal=False, f=1.0, radius=2,
                         draw_wireframes=False):
    """Build a multi-row figure for paper-style landmark visualization.

    Rows: inputs, reconstructions, reconstructions with heatmap overlay,
    reconstructions with predicted landmarks, (optionally) reconstructions
    with GT landmarks, and inputs with predicted landmarks. Shown via OpenCV.

    NOTE(review): the heatmap-overlay row indexes ``pred_heatmaps`` which is
    only assigned when ``X_lm_hm`` is not None — calling with
    ``X_lm_hm=None`` would fail there; confirm callers always pass heatmaps.

    :param images: input image batch (at most 10 are shown).
    :param landmarks: optional GT landmarks (green).
    :param X_recon: reconstructed images.
    :param X_lm_hm: predicted landmark heatmaps.
    :param lm_preds: predicted landmark coordinates (cyan).
    :param show_recon: if False, use white canvases instead of recons.
    :param lm_heatmaps: optional GT heatmaps for the side-by-side window.
    :param ds: dataset; class name is appended to the window title.
    :param wait: ``cv2.waitKey`` delay in ms.
    :param horizontal: lay rows out side by side (requires a single image).
    :param f: display resize factor (landmarks are scaled to match).
    :param radius: landmark dot radius.
    :param draw_wireframes: draw wireframe connections on the input row.
    """
    gt_color = (0, 255, 0)
    pred_color = (0, 255, 255)
    nimgs = min(10, len(images))
    images = nn.atleast4d(images)[:nimgs]
    num_landmarks = lm_preds.shape[1]
    # show landmark heatmaps
    pred_heatmaps = None
    if X_lm_hm is not None:
        pred_heatmaps = to_single_channel_heatmap(to_numpy(X_lm_hm[:nimgs]))
        pred_heatmaps = [
            cv2.resize(im, None, fx=f, fy=f,
                       interpolation=cv2.INTER_NEAREST)
            for im in pred_heatmaps
        ]
        gt_heatmaps = None
        if lm_heatmaps is not None:
            gt_heatmaps = to_single_channel_heatmap(
                to_numpy(lm_heatmaps[:nimgs]))
            gt_heatmaps = np.array([
                cv2.resize(im, None, fx=f, fy=f,
                           interpolation=cv2.INTER_NEAREST)
                for im in gt_heatmaps
            ])
        show_landmark_heatmaps(pred_heatmaps, gt_heatmaps, nimgs, f=1)
        # Per-landmark confidence = max heatmap activation (currently only
        # computed, not consumed below).
        lm_confs = to_numpy(X_lm_hm).reshape(X_lm_hm.shape[0],
                                             X_lm_hm.shape[1],
                                             -1).max(axis=2)
    # resize images for display and scale landmarks accordingly
    lm_preds = lm_preds[:nimgs] * f
    rows = []
    disp_images = vis.to_disp_images(images[:nimgs], denorm=True)
    disp_images = [
        cv2.resize(im, None, fx=f, fy=f, interpolation=cv2.INTER_NEAREST)
        for im in disp_images
    ]
    rows.append(vis.make_grid(disp_images, nCols=nimgs, normalize=False))
    heatmap_opacity = 1.0
    if show_recon:
        recon_images = vis.to_disp_images(X_recon[:nimgs], denorm=True)
    else:
        # White canvases stand in for the reconstructions.
        recon_images = vis.to_disp_images(torch.ones_like(X_recon[:nimgs]),
                                          denorm=False)
        heatmap_opacity = 1
    disp_X_recon = [
        cv2.resize(im, None, fx=f, fy=f, interpolation=cv2.INTER_NEAREST)
        for im in recon_images.copy()
    ]
    rows.append(vis.make_grid(disp_X_recon, nCols=nimgs))
    # overlay landmarks on images
    disp_X_recon_hm = [
        cv2.resize(im, None, fx=f, fy=f, interpolation=cv2.INTER_NEAREST)
        for im in recon_images.copy()
    ]
    disp_X_recon_hm = [
        vis.overlay_heatmap(disp_X_recon_hm[i], pred_heatmaps[i],
                            heatmap_opacity)
        for i in range(len(pred_heatmaps))
    ]
    rows.append(vis.make_grid(disp_X_recon_hm, nCols=nimgs))
    # reconstructions with prediction
    disp_X_recon_pred = [
        cv2.resize(im, None, fx=f, fy=f, interpolation=cv2.INTER_NEAREST)
        for im in recon_images.copy()
    ]
    disp_X_recon_pred = vis.add_landmarks_to_images(disp_X_recon_pred,
                                                    lm_preds,
                                                    color=pred_color,
                                                    radius=radius)
    rows.append(vis.make_grid(disp_X_recon_pred, nCols=nimgs))
    # reconstructions with ground truth (if gt available)
    if landmarks is not None:
        lm_gt = nn.atleast3d(to_numpy(landmarks))[:nimgs] * f
        disp_X_recon_gt = [
            cv2.resize(im, None, fx=f, fy=f,
                       interpolation=cv2.INTER_NEAREST)
            for im in recon_images.copy()
        ]
        disp_X_recon_gt = vis.add_landmarks_to_images(disp_X_recon_gt,
                                                      lm_gt,
                                                      color=gt_color,
                                                      radius=radius)
        rows.append(vis.make_grid(disp_X_recon_gt, nCols=nimgs))
    # input images with prediction (and ground truth)
    disp_images_pred = vis.to_disp_images(images[:nimgs], denorm=True)
    disp_images_pred = [
        cv2.resize(im, None, fx=f, fy=f, interpolation=cv2.INTER_NEAREST)
        for im in disp_images_pred
    ]
    disp_images_pred = vis.add_landmarks_to_images(
        disp_images_pred, lm_preds, color=pred_color, radius=radius,
        draw_wireframe=draw_wireframes)
    rows.append(vis.make_grid(disp_images_pred, nCols=nimgs))
    if horizontal:
        assert (nimgs == 1)
        disp_rows = vis.make_grid(rows, nCols=len(rows))
    else:
        disp_rows = vis.make_grid(rows, nCols=1)
    wnd_title = 'recon errors '
    if ds is not None:
        wnd_title += ds.__class__.__name__
    cv2.imshow(wnd_title, cv2.cvtColor(disp_rows, cv2.COLOR_RGB2BGR))
    cv2.waitKey(wait)
# NOTE(review): fragment of an evaluation routine — the enclosing scope
# (defining `dataset`, `net`, `args`, `scales`, `dsname`) is not visible
# here, and the leading `raise ValueError` looks like a deliberate guard or
# leftover debug stop; confirm against the full file.
raise ValueError
net.eval()
results_probs = []
gt_masks = []
fov_masks = []
results = []
t_tot = 0
# Iterate over every image of the dataset and time each segmentation.
for idx in range(len(dataset))[:]:
    data = dataset[idx]
    image_id = data['fname']
    full_image = to_numpy(data['image'])
    # Mask stored as 0/255; convert to 0/1.
    gt_mask = to_numpy(data['mask']) // 255
    fov_mask = dataset.fov_masks[image_id]
    print(f'\n---- Testing image {idx+1}: {image_id} ---- ')
    t = time.perf_counter()
    recon, probs = predict_vessels.segment_image(
        net,
        full_image,
        patch_size=args.patch_size,
        scales=scales[dsname])  #, gpu=args.gpu)
    t_image = time.perf_counter() - t
# Smoke test: load AffectNet with ground-truth-landmark cropping, print the
# inter-ocular distances per batch, and display images with GT landmarks.
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
dirs = config.get_dataset_paths('affectnet')
train = True
ds = AffectNet(root=dirs[0],
               image_size=256,
               cache_root=dirs[1],
               train=train,
               use_cache=False,
               transform=ds_utils.build_transform(deterministic=not train,
                                                  daug=0),
               crop_source='lm_ground_truth')
dl = td.DataLoader(ds, batch_size=10, shuffle=False, num_workers=0)
for data in dl:
    batch = Batch(data, gpu=False)
    gt = to_numpy(batch.landmarks)
    # 68-point scheme: 39/42 are inner eye corners, 36/45 outer corners.
    ocular_dists_inner = np.sqrt(np.sum((gt[:, 42] - gt[:, 39])**2, axis=1))
    ocular_dists_outer = np.sqrt(np.sum((gt[:, 45] - gt[:, 36])**2, axis=1))
    ocular_dists = np.vstack(
        (ocular_dists_inner, ocular_dists_outer)).mean(axis=0)
    print(ocular_dists)
    images = vis.to_disp_images(batch.images, denorm=True)
    imgs = vis.add_landmarks_to_images(images, batch.landmarks.numpy())
    vis.vis_square(imgs, nCols=10, fx=1.0, fy=1.0, normalize=False)
def visualize_batch(images, landmarks, X_recon, X_lm_hm, lm_preds_max,
                    lm_heatmaps=None, target_images=None, lm_preds_cnn=None,
                    ds=None, wait=0, ssim_maps=None, landmarks_to_draw=None,
                    ocular_norm='outer', horizontal=False, f=1.0,
                    overlay_heatmaps_input=False,
                    overlay_heatmaps_recon=False, clean=False,
                    landmarks_only_outline=range(17),
                    landmarks_no_outline=range(17, 68)):
    """Show inputs and reconstructions annotated with landmark predictions.

    Builds two grid rows — input images with GT and predicted landmarks, and
    reconstructions with error annotations — and displays them in an OpenCV
    window, plus an optional SSIM-error-map window.

    :param images: input batch (at most 10 shown); size must be 128 or 256.
    :param landmarks: GT landmarks or None (zeros are used when absent).
    :param X_recon: reconstructed images.
    :param X_lm_hm: predicted landmark heatmaps (may be None).
    :param lm_preds_max: landmark predictions from heatmap argmax.
    :param lm_heatmaps: optional GT heatmaps for the side-by-side window.
    :param target_images: displayed instead of ``images`` when given.
    :param lm_preds_cnn: optional CNN landmark predictions (only rescaled).
    :param ds: dataset; class name appended to the window title.
    :param wait: ``cv2.waitKey`` delay in ms.
    :param ssim_maps: optional per-image SSIM error maps.
    :param landmarks_to_draw: subset of landmark ids to draw.
    :param ocular_norm: normalization used for the NME ('outer'/...).
    :param horizontal: place the two rows side by side (single image only).
    :param f: display resize factor (landmarks scaled to match).
    :param overlay_heatmaps_input / overlay_heatmaps_recon: overlay the
        predicted heatmaps on the respective row.
    :param clean: suppress all numeric error annotations.
    :param landmarks_only_outline / landmarks_no_outline: landmark id groups
        used for the separate NME readouts (68-point scheme defaults).
    """
    gt_color = (0, 255, 0)
    pred_color = (0, 0, 255)
    image_size = images.shape[3]
    assert image_size in [128, 256]
    nimgs = min(10, len(images))
    images = nn.atleast4d(images)[:nimgs]
    num_landmarks = lm_preds_max.shape[1]
    if landmarks_to_draw is None:
        landmarks_to_draw = range(num_landmarks)
    nme_per_lm = None
    if landmarks is None:
        # No GT available: use zero landmarks as placeholders.
        lm_gt = np.zeros((nimgs, num_landmarks, 2))
    else:
        lm_gt = nn.atleast3d(to_numpy(landmarks))[:nimgs]
        nme_per_lm = calc_landmark_nme(lm_gt,
                                       lm_preds_max[:nimgs],
                                       ocular_norm=ocular_norm,
                                       image_size=image_size)
        lm_ssim_errs = 1 - calc_landmark_ssim_score(images, X_recon[:nimgs],
                                                    lm_gt)
    lm_confs = None
    # show landmark heatmaps
    pred_heatmaps = None
    if X_lm_hm is not None:
        pred_heatmaps = to_single_channel_heatmap(to_numpy(X_lm_hm[:nimgs]))
        pred_heatmaps = [
            cv2.resize(im, None, fx=f, fy=f,
                       interpolation=cv2.INTER_NEAREST)
            for im in pred_heatmaps
        ]
        gt_heatmaps = None
        if lm_heatmaps is not None:
            gt_heatmaps = to_single_channel_heatmap(
                to_numpy(lm_heatmaps[:nimgs]))
            gt_heatmaps = np.array([
                cv2.resize(im, None, fx=f, fy=f,
                           interpolation=cv2.INTER_NEAREST)
                for im in gt_heatmaps
            ])
        show_landmark_heatmaps(pred_heatmaps, gt_heatmaps, nimgs, f=1)
        # Per-landmark confidence = max heatmap activation.
        lm_confs = to_numpy(X_lm_hm).reshape(X_lm_hm.shape[0],
                                             X_lm_hm.shape[1],
                                             -1).max(axis=2)
    # resize images for display and scale landmarks accordingly
    lm_preds_max = lm_preds_max[:nimgs] * f
    if lm_preds_cnn is not None:
        lm_preds_cnn = lm_preds_cnn[:nimgs] * f
    lm_gt *= f
    input_images = vis.to_disp_images(images[:nimgs], denorm=True)
    if target_images is not None:
        disp_images = vis.to_disp_images(target_images[:nimgs], denorm=True)
    else:
        disp_images = vis.to_disp_images(images[:nimgs], denorm=True)
    disp_images = [
        cv2.resize(im, None, fx=f, fy=f, interpolation=cv2.INTER_NEAREST)
        for im in disp_images
    ]
    recon_images = vis.to_disp_images(X_recon[:nimgs], denorm=True)
    disp_X_recon = [
        cv2.resize(im, None, fx=f, fy=f, interpolation=cv2.INTER_NEAREST)
        for im in recon_images.copy()
    ]
    # overlay landmarks on input images
    if pred_heatmaps is not None and overlay_heatmaps_input:
        disp_images = [
            vis.overlay_heatmap(disp_images[i], pred_heatmaps[i])
            for i in range(len(pred_heatmaps))
        ]
    if pred_heatmaps is not None and overlay_heatmaps_recon:
        disp_X_recon = [
            vis.overlay_heatmap(disp_X_recon[i], pred_heatmaps[i])
            for i in range(len(pred_heatmaps))
        ]
    #
    # Show input images
    #
    disp_images = vis.add_landmarks_to_images(disp_images, lm_gt[:nimgs],
                                              color=gt_color)
    disp_images = vis.add_landmarks_to_images(disp_images,
                                              lm_preds_max[:nimgs],
                                              lm_errs=nme_per_lm,
                                              color=pred_color,
                                              draw_wireframe=False,
                                              gt_landmarks=lm_gt,
                                              draw_gt_offsets=True)
    #
    # Show reconstructions
    #
    X_recon_errs = 255.0 * torch.abs(images - X_recon[:nimgs]).reshape(
        len(images), -1).mean(dim=1)
    if not clean:
        disp_X_recon = vis.add_error_to_images(disp_X_recon[:nimgs],
                                               errors=X_recon_errs,
                                               format_string='{:>4.1f}')
    # NME readouts for three landmark groups (inner, outline, all).
    if not clean:
        lm_errs_max = calc_landmark_nme_per_img(lm_gt,
                                                lm_preds_max,
                                                ocular_norm,
                                                landmarks_no_outline,
                                                image_size=image_size)
        lm_errs_max_outline = calc_landmark_nme_per_img(
            lm_gt, lm_preds_max, ocular_norm, landmarks_only_outline,
            image_size=image_size)
        lm_errs_max_all = calc_landmark_nme_per_img(
            lm_gt, lm_preds_max, ocular_norm,
            list(landmarks_only_outline) + list(landmarks_no_outline),
            image_size=image_size)
        disp_X_recon = vis.add_error_to_images(disp_X_recon, lm_errs_max,
                                               loc='br-2',
                                               format_string='{:>5.2f}',
                                               vmax=15)
        disp_X_recon = vis.add_error_to_images(disp_X_recon,
                                               lm_errs_max_outline,
                                               loc='br-1',
                                               format_string='{:>5.2f}',
                                               vmax=15)
        disp_X_recon = vis.add_error_to_images(disp_X_recon, lm_errs_max_all,
                                               loc='br',
                                               format_string='{:>5.2f}',
                                               vmax=15)
    disp_X_recon = vis.add_landmarks_to_images(disp_X_recon, lm_gt,
                                               color=gt_color,
                                               draw_wireframe=True)
    disp_X_recon = vis.add_landmarks_to_images(disp_X_recon,
                                               lm_preds_max[:nimgs],
                                               color=pred_color,
                                               draw_wireframe=True,
                                               gt_landmarks=lm_gt,
                                               draw_gt_offsets=True,
                                               lm_errs=nme_per_lm,
                                               draw_dots=True,
                                               radius=2)

    def add_confidences(disp_X_recon, lmids, loc):
        # Annotate mean landmark confidences (color = 1 - confidence).
        means = lm_confs[:, lmids].mean(axis=1)
        colors = vis.color_map(to_numpy(1 - means), cmap=plt.cm.jet,
                               vmin=0.0, vmax=0.4)
        return vis.add_error_to_images(disp_X_recon, means, loc=loc,
                                       format_string='{:>4.2f}',
                                       colors=colors)

    # print ssim errors
    ssim = np.zeros(nimgs)
    for i in range(nimgs):
        ssim[i] = compare_ssim(input_images[i],
                               recon_images[i],
                               data_range=1.0,
                               multichannel=True)
    if not clean:
        disp_X_recon = vis.add_error_to_images(disp_X_recon, 1 - ssim,
                                               loc='bl-1',
                                               format_string='{:>4.2f}',
                                               vmax=0.8,
                                               vmin=0.2)
    # print ssim torch errors
    if ssim_maps is not None and not clean:
        disp_X_recon = vis.add_error_to_images(
            disp_X_recon,
            ssim_maps.reshape(len(ssim_maps), -1).mean(axis=1),
            loc='bl-2', format_string='{:>4.2f}', vmin=0.0, vmax=0.4)
    rows = [vis.make_grid(disp_images, nCols=nimgs, normalize=False)]
    rows.append(vis.make_grid(disp_X_recon, nCols=nimgs))
    if ssim_maps is not None:
        disp_ssim_maps = to_numpy(
            nn.denormalized(ssim_maps)[:nimgs].transpose(0, 2, 3, 1))
        for i in range(len(disp_ssim_maps)):
            disp_ssim_maps[i] = vis.color_map(disp_ssim_maps[i].mean(axis=2),
                                              vmin=0.0, vmax=2.0)
        grid_ssim_maps = vis.make_grid(disp_ssim_maps, nCols=nimgs, fx=f,
                                       fy=f)
        cv2.imshow('ssim errors',
                   cv2.cvtColor(grid_ssim_maps, cv2.COLOR_RGB2BGR))
    if horizontal:
        assert (nimgs == 1)
        disp_rows = vis.make_grid(rows, nCols=2)
    else:
        disp_rows = vis.make_grid(rows, nCols=1)
    wnd_title = 'Predicted Landmarks '
    if ds is not None:
        wnd_title += ds.__class__.__name__
    cv2.imshow(wnd_title, cv2.cvtColor(disp_rows, cv2.COLOR_RGB2BGR))
    cv2.waitKey(wait)
def visualize_vessels(images, X_recon, vessel_hm, pred_vessel_hm=None,
                      ds=None, wait=0, horizontal=False, f=1.0,
                      overlay_heatmaps_input=True,
                      overlay_heatmaps_recon=True, scores=None, nimgs=5):
    """Show input images and reconstructions with vessel-heatmap overlays.

    Rows: plain inputs, inputs with GT vessel heatmap, reconstructions with
    predicted heatmap (optionally score-annotated), and plain
    reconstructions. Displayed in an OpenCV window.

    :param images: input batch (at most ``nimgs`` shown).
    :param X_recon: reconstructed images.
    :param vessel_hm: GT vessel heatmaps, channel 0 used for the overlay.
    :param pred_vessel_hm: predicted vessel heatmaps, channel 0 overlaid.
    :param ds: dataset; class name appended to the window title.
    :param wait: ``cv2.waitKey`` delay in ms.
    :param horizontal: lay rows out side by side (requires a single image).
    :param f: display resize factor.
    :param overlay_heatmaps_input / overlay_heatmaps_recon: toggles for the
        overlay rows.
    :param scores: optional per-image scores printed on the overlay row.
    :param nimgs: max number of images to display.
    """
    nimgs = min(nimgs, len(images))
    images = images[:nimgs]
    rows = []
    input_images = vis.to_disp_images(images[:nimgs], denorm=True)
    disp_images = vis.to_disp_images(images[:nimgs], denorm=True)
    disp_images = [
        cv2.resize(im, None, fx=f, fy=f, interpolation=cv2.INTER_NEAREST)
        for im in disp_images
    ]
    rows.append(vis.make_grid(disp_images, nCols=nimgs, normalize=False))
    recon_images = vis.to_disp_images(X_recon[:nimgs], denorm=True)
    disp_X_recon = [
        cv2.resize(im, None, fx=f, fy=f, interpolation=cv2.INTER_NEAREST)
        for im in recon_images.copy()
    ]
    # Input row with the GT vessel heatmap blended in at 50%.
    if vessel_hm is not None and overlay_heatmaps_input:
        vessel_hm = to_numpy(vessel_hm[:nimgs])
        disp_images = [
            vis.overlay_heatmap(disp_images[i], vessel_hm[i, 0], 0.5)
            for i in range(len(vessel_hm))
        ]
        rows.append(vis.make_grid(disp_images, nCols=nimgs,
                                  normalize=False))
    # Reconstruction row with the predicted heatmap at full opacity.
    if pred_vessel_hm is not None and overlay_heatmaps_recon:
        pred_vessel_hm = to_numpy(pred_vessel_hm[:nimgs])
        disp_X_recon_overlay = [
            vis.overlay_heatmap(disp_X_recon[i], pred_vessel_hm[i, 0], 1.0)
            for i in range(len(pred_vessel_hm))
        ]
        if scores is not None:
            disp_X_recon_overlay = vis.add_error_to_images(
                disp_X_recon_overlay, scores, loc='tr',
                format_string='{:.3f}')
        rows.append(vis.make_grid(disp_X_recon_overlay, nCols=nimgs))
    rows.append(vis.make_grid(disp_X_recon, nCols=nimgs))
    if horizontal:
        assert (nimgs == 1)
        disp_rows = vis.make_grid(rows, nCols=4)
    else:
        disp_rows = vis.make_grid(rows, nCols=1)
    wnd_title = 'Predicted vessels '
    if ds is not None:
        wnd_title += ds.__class__.__name__
    cv2.imshow(wnd_title, cv2.cvtColor(disp_rows, cv2.COLOR_RGB2BGR))
    cv2.waitKey(wait)
def add_landmarks_to_images(images, landmarks, color=None, radius=2,
                            gt_landmarks=None, lm_errs=None, lm_confs=None,
                            lm_rec_errs=None, draw_dots=True,
                            draw_wireframe=False, draw_gt_offsets=False,
                            landmarks_to_draw=None, offset_line_color=None):
    """Draw landmark annotations onto a batch of images and return new copies.

    Depending on the flags this draws: per-landmark dots (colored by ``lm_errs``
    when no fixed ``color`` is given), confidence circles from ``lm_confs``,
    offset lines from predictions to ``gt_landmarks``, and/or a face wireframe
    for 68- or 98-point layouts. Images are first converted via
    ``to_disp_images``; drawing happens in place on those copies.

    Args:
        images: batch of images (converted with ``to_disp_images``).
        landmarks: predicted landmarks, shape (batch, n_landmarks, 2).
        color: fixed dot/wireframe color; None selects per-error coloring.
        radius: dot radius in pixels.
        gt_landmarks: ground-truth landmarks for offset lines (optional).
        lm_errs: per-landmark errors, shape (batch, n_landmarks) (optional).
        lm_confs: per-landmark confidences used for confidence circles.
        lm_rec_errs: unused here except by commented-out code.
        draw_dots: draw one filled circle per landmark.
        draw_wireframe: connect landmarks with polylines (68/98-point only).
        draw_gt_offsets: draw prediction→ground-truth lines.
        landmarks_to_draw: optional subset of landmark ids to render.
        offset_line_color: intended color of offset lines; see NOTE below.

    Returns:
        List of annotated image copies.
    """
    def draw_wireframe_lines(img, lms):
        # Wireframe for the 68-point (iBUG/300-W style) landmark layout.
        pts = lms.reshape((-1,1,2)).astype(np.int32)
        cv2.polylines(img, [pts[:17]], isClosed=False, color=color, lineType=cv2.LINE_AA)  # head outline
        cv2.polylines(img, [pts[17:22]], isClosed=False, color=color, lineType=cv2.LINE_AA)  # left eyebrow
        cv2.polylines(img, [pts[22:27]], isClosed=False, color=color, lineType=cv2.LINE_AA)  # right eyebrow
        cv2.polylines(img, [pts[27:31]], isClosed=False, color=color, lineType=cv2.LINE_AA)  # nose vert
        cv2.polylines(img, [pts[31:36]], isClosed=False, color=color, lineType=cv2.LINE_AA)  # nose hor
        cv2.polylines(img, [pts[36:42]], isClosed=True, color=color, lineType=cv2.LINE_AA)  # left eye
        cv2.polylines(img, [pts[42:48]], isClosed=True, color=color, lineType=cv2.LINE_AA)  # right eye
        cv2.polylines(img, [pts[48:60]], isClosed=True, color=color, lineType=cv2.LINE_AA)  # outer mouth
        cv2.polylines(img, [pts[60:68]], isClosed=True, color=color, lineType=cv2.LINE_AA)  # inner mouth

    def draw_wireframe_lines_98(img, lms):
        # Wireframe for the 98-point (WFLW style) landmark layout.
        pts = lms.reshape((-1,1,2)).astype(np.int32)
        cv2.polylines(img, [pts[:33]], isClosed=False, color=color, lineType=cv2.LINE_AA)  # head outline
        cv2.polylines(img, [pts[33:42]], isClosed=True, color=color, lineType=cv2.LINE_AA)  # left eyebrow
        # cv2.polylines(img, [pts[38:42]], isClosed=False, color=color, lineType=cv2.LINE_AA)  # right eyebrow
        cv2.polylines(img, [pts[42:51]], isClosed=True, color=color, lineType=cv2.LINE_AA)  # nose vert
        cv2.polylines(img, [pts[51:55]], isClosed=False, color=color, lineType=cv2.LINE_AA)  # nose hor
        cv2.polylines(img, [pts[55:60]], isClosed=False, color=color, lineType=cv2.LINE_AA)  # left eye
        cv2.polylines(img, [pts[60:68]], isClosed=True, color=color, lineType=cv2.LINE_AA)  # right eye
        cv2.polylines(img, [pts[68:76]], isClosed=True, color=color, lineType=cv2.LINE_AA)  # outer mouth
        cv2.polylines(img, [pts[76:88]], isClosed=True, color=color, lineType=cv2.LINE_AA)  # inner mouth
        cv2.polylines(img, [pts[88:96]], isClosed=True, color=color, lineType=cv2.LINE_AA)  # inner mouth

    def draw_offset_lines(img, lms, gt_lms, errs):
        # Draw one line per landmark from prediction to ground truth,
        # colored by error magnitude when lm_errs is available.
        if gt_lms.sum() == 0:
            # All-zero ground truth means "no annotation" — nothing to draw.
            return
        if lm_errs is None:
            # NOTE(review): this unconditionally overrides the
            # `offset_line_color` parameter (the `is None` guard is commented
            # out), so a caller-supplied color is ignored — confirm intended.
            # if offset_line_color is None:
            offset_line_color = (1,1,1)
            colors = [offset_line_color] * len(lms)
        else:
            colors = color_map(errs, cmap=plt.cm.jet, vmin=0, vmax=15.0)
        if img.dtype == np.uint8:
            # NOTE(review): when `colors` is the Python list built above,
            # `*= 255` repeats the list 255x instead of scaling the values to
            # uint8 range; it only scales when `colors` is an ndarray from
            # color_map. Indexing still works, but uint8 lines stay near-black.
            colors *= 255
        for i, (p1, p2) in enumerate(zip(lms, gt_lms)):
            if landmarks_to_draw is None or i in landmarks_to_draw:
                if p1.min() > 0:  # skip landmarks at/outside the image origin
                    cv2.line(img, tuple(p1.astype(int)), tuple(p2.astype(int)),
                             colors[i], thickness=1, lineType=cv2.LINE_AA)

    new_images = to_disp_images(images)
    landmarks = to_numpy(landmarks)
    gt_landmarks = to_numpy(gt_landmarks)
    lm_errs = to_numpy(lm_errs)
    # presumably images are square — only shape[0] is used as "image size"
    img_size = new_images[0].shape[0]
    default_color = (255,255,255)

    # Pass 1: offset lines prediction -> ground truth (drawn under the dots).
    if gt_landmarks is not None and draw_gt_offsets:
        for img_id in range(len(new_images)):
            if gt_landmarks[img_id].sum() == 0:
                continue  # no annotation for this image
            dists = None
            if lm_errs is not None:
                dists = lm_errs[img_id]
            draw_offset_lines(new_images[img_id], landmarks[img_id],
                              gt_landmarks[img_id], dists)

    # Pass 2: dots, confidence circles and wireframes.
    for img_id, (disp, lm) in enumerate(zip(new_images, landmarks)):
        if len(lm) in [68, 21, 19, 98, 8, 5, 38]:  # known landmark layouts
            if draw_dots:
                for lm_id in range(0,len(lm)):
                    # The subset filter only applies to the 68-point layout.
                    if landmarks_to_draw is None or lm_id in landmarks_to_draw or len(lm) != 68:
                        lm_color = color
                        if lm_color is None:
                            # No fixed color: map per-landmark error to a jet
                            # colormap, or fall back to white.
                            if lm_errs is not None:
                                lm_color = color_map(lm_errs[img_id, lm_id], cmap=plt.cm.jet, vmin=0, vmax=1.0)
                            else:
                                lm_color = default_color
                        # if lm_errs is not None and lm_errs[img_id, lm_id] > 40.0:
                        #     lm_color = (1,0,0)
                        # Clip the dot center into the image (assumes square).
                        cv2.circle(disp, tuple(lm[lm_id].astype(int).clip(0, disp.shape[0]-1)),
                                   radius=radius, color=lm_color, thickness=-1, lineType=cv2.LINE_AA)
                        if lm_confs is not None:
                            # Lower confidence -> larger circle (min radius 2).
                            max_radius = img_size * 0.05
                            try:
                                conf_radius = max(2, int((1-lm_confs[img_id, lm_id]) * max_radius))
                            except ValueError:
                                conf_radius = 2
                            # if lm_confs[img_id, lm_id] > 0.4:
                            cirle_color = (0,0,255)
                            # if lm_confs[img_id, lm_id] < is_good_landmark(lm_confs, lm_rec_errs):
                            # if not is_good_landmark(lm_confs[img_id, lm_id], lm_rec_errs[img_id, lm_id]):
                            # NOTE(review): this reads lm_errs without a None
                            # check — passing lm_confs without lm_errs raises.
                            if lm_errs[img_id, lm_id] > 10.0:
                                cirle_color = (255,0,0)
                            cv2.circle(disp, tuple(lm[lm_id].astype(int)), conf_radius,
                                       cirle_color, 1, lineType=cv2.LINE_AA)
            # Draw outline if we actually have 68 valid landmarks.
            # Landmarks can be zeros for UMD landmark format (21 points).
            if draw_wireframe:
                nlms = (np.count_nonzero(lm.sum(axis=1)))
                if nlms == 68:
                    draw_wireframe_lines(disp, lm)
                elif nlms == 98:
                    draw_wireframe_lines_98(disp, lm)
        else:
            # Unknown layout: draw each landmark in a distinct palette color.
            # colors = ['tab:gray', 'tab:orange', 'tab:brown', 'tab:pink', 'tab:cyan', 'tab:olive', 'tab:red', 'tab:blue']
            # colors_rgb = list(map(plt_colors.to_rgb, colors))
            colors = sns.color_palette("Set1", n_colors=14)
            for i in range(0,len(lm)):
                cv2.circle(disp, tuple(lm[i].astype(int)), radius=radius,
                           color=colors[i], thickness=2, lineType=cv2.LINE_AA)
    return new_images
def get_landmark_confs(X_lm_hm):
    """Return per-landmark confidences: each heatmap's peak value, clipped to [0, 1].

    Args:
        X_lm_hm: landmark heatmaps of shape (batch, n_landmarks, H, W).

    Returns:
        ndarray of shape (batch, n_landmarks) with values in [0, 1].
    """
    heatmaps = to_numpy(X_lm_hm)
    n_images, n_landmarks = X_lm_hm.shape[0], X_lm_hm.shape[1]
    # Flatten each heatmap and take its maximum as the confidence score.
    peaks = heatmaps.reshape(n_images, n_landmarks, -1).max(axis=2)
    return np.clip(peaks, a_min=0, a_max=1)
def reformat(lms):
    """Normalize landmarks to a 3-D array of shape (batch, n_landmarks, 2).

    A single unbatched set of shape (n_landmarks, 2) gains a leading batch
    dimension of 1; already-batched input is returned unchanged (as numpy).
    """
    arr = to_numpy(lms)
    if len(arr.shape) == 2:
        # Single sample: add the batch axis.
        arr = arr.reshape((1, -1, 2))
    return arr
def _run_batch(self, data, eval=False, ds=None):
    """Run one training/evaluation step of the landmark head on a mini-batch.

    Encodes the batch, reconstructs it, predicts landmark heatmaps, and (in
    training mode) backpropagates the landmark loss and steps the optimizers.
    Per-iteration statistics are appended to ``self.epoch_stats``; on printout
    iterations the stats are printed and the batch is visualized.

    Args:
        data: raw dataloader output, wrapped into a ``Batch``.
        eval: evaluation mode — disables landmark-head training and gradients.
        ds: dataset forwarded to the visualization for window titling.
    """
    time_dataloading = time.time() - self.iter_starttime
    time_proc_start = time.time()
    iter_stats = {'time_dataloading': time_dataloading}

    batch = Batch(data, eval=eval)

    self.saae.zero_grad()
    self.saae.eval()

    # Encode target images when available, otherwise the inputs themselves.
    input_images = batch.target_images if batch.target_images is not None else batch.images
    # Encoder gradients only flow when the encoder is being trained.
    with torch.set_grad_enabled(self.args.train_encoder):
        z_sample = self.saae.Q(input_images)

    iter_stats.update({'z_recon_mean': z_sample.mean().item()})

    #######################
    # Reconstruction phase
    #######################
    with torch.set_grad_enabled(self.args.train_encoder and not eval):
        X_recon = self.saae.P(z_sample)

    # calculate reconstruction error for debugging and reporting
    with torch.no_grad():
        iter_stats['loss_recon'] = aae_training.loss_recon(batch.images, X_recon)

    #######################
    # Landmark predictions
    #######################
    train_lmhead = not eval
    lm_preds_max = None
    with torch.set_grad_enabled(train_lmhead):
        self.saae.LMH.train(train_lmhead)
        # NOTE(review): LMH is called with the decoder module itself
        # (self.saae.P), not its output — presumably LMH reads intermediate
        # activations from P; confirm against the LMH implementation.
        X_lm_hm = self.saae.LMH(self.saae.P)
        if batch.lm_heatmaps is not None:
            # MSE heatmap loss; 100 * 3 is an empirical scaling factor.
            loss_lms = F.mse_loss(batch.lm_heatmaps, X_lm_hm) * 100 * 3
            iter_stats.update({'loss_lms': loss_lms.item()})

        if eval or self._is_printout_iter(eval):
            # expensive, so only calculate when every N iterations
            # X_lm_hm = lmutils.decode_heatmap_blob(X_lm_hm)
            X_lm_hm = lmutils.smooth_heatmaps(X_lm_hm)
            lm_preds_max = self.saae.heatmaps_to_landmarks(X_lm_hm)

        if eval or self._is_printout_iter(eval):
            # Landmark accuracy (normalized mean error) for reporting only.
            lm_gt = to_numpy(batch.landmarks)
            nmes = lmutils.calc_landmark_nme(lm_gt, lm_preds_max,
                                             ocular_norm=self.args.ocular_norm,
                                             image_size=self.args.input_size)
            # nccs = lmutils.calc_landmark_ncc(batch.images, X_recon, lm_gt)
            iter_stats.update({'nmes': nmes})

    if train_lmhead:
        # if self.args.train_encoder:
        #     loss_lms = loss_lms * 80.0
        # NOTE(review): assumes batch.lm_heatmaps was not None in training
        # mode, otherwise loss_lms is unbound here.
        loss_lms.backward()
        self.optimizer_lm_head.step()
        if self.args.train_encoder:
            self.optimizer_E.step()
            # self.optimizer_G.step()

    # statistics
    iter_stats.update({'epoch': self.epoch,
                       'timestamp': time.time(),
                       'iter_time': time.time() - self.iter_starttime,
                       'time_processing': time.time() - time_proc_start,
                       'iter': self.iter_in_epoch,
                       'total_iter': self.total_iter,
                       'batch_size': len(batch)})
    self.iter_starttime = time.time()
    self.epoch_stats.append(iter_stats)

    # print stats every N mini-batches
    if self._is_printout_iter(eval):
        self._print_iter_stats(self.epoch_stats[-self._print_interval(eval):])
        lmvis.visualize_batch(batch.images, batch.landmarks, X_recon, X_lm_hm,
                              lm_preds_max,
                              lm_heatmaps=batch.lm_heatmaps,
                              target_images=batch.target_images,
                              ds=ds,
                              ocular_norm=self.args.ocular_norm,
                              clean=False,
                              overlay_heatmaps_input=False,
                              overlay_heatmaps_recon=False,
                              landmarks_only_outline=self.landmarks_only_outline,
                              landmarks_no_outline=self.landmarks_no_outline,
                              f=1.0,
                              wait=self.wait)