def show_landmarks(img_in, landmarks, bbox=None, gt=None, title='landmarks', pose=None, wait=10,
                   color=(1, 0, 0)):
    """Draw landmarks on a copy of ``img_in`` and show the result in an OpenCV window.

    Args:
        img_in: image array. It is not modified. If its maximum exceeds 1.01 it is
            treated as 0-255 data and rescaled to float32 in [0, 1].
        landmarks: iterable of points; ``pt[0]`` / ``pt[1]`` are used as x / y.
            Each point is drawn as a filled circle of radius 2 in ``color``.
        bbox: optional box. The first two values are the top-left corner and the
            remaining two the bottom-right corner. Drawn in (1, 1, 1).
        gt: optional ground-truth points, drawn as radius-1 circles in (1, 1, 0).
            When given, the per-image NME between ``gt`` and ``landmarks``
            (``ocular_norm='outer'``) is also written onto the image.
        title: window title.
        pose: optional head pose, passed to ``vis3d.draw_head_pose``.
        wait: milliseconds passed to ``cv2.waitKey``.
        color: colour of the predicted landmarks.
    """
    from landmarks.lmutils import calc_landmark_nme_per_img

    img = img_in.copy()
    if img.max() > 1.01:
        img = img.astype(np.float32) / 255.0
    elif img.dtype == np.float64:
        # cv2.cvtColor (used for display below) does not accept 64-bit float
        # images, so normalised float64 input must be narrowed to float32.
        img = img.astype(np.float32)

    if bbox is not None:
        tl = tuple(int(v) for v in bbox[:2])
        br = tuple(int(v) for v in bbox[2:])
        cv2.rectangle(img, tl, br, (1, 1, 1))

    if gt is not None:
        for lm in gt:
            cv2.circle(img, (int(lm[0]), int(lm[1])), 1, (1, 1, 0), -1, lineType=cv2.LINE_AA)

    for lm in landmarks:
        cv2.circle(img, (int(lm[0]), int(lm[1])), 2, color, -1, lineType=cv2.LINE_AA)

    if pose is not None:
        from utils import vis3d
        vis3d.draw_head_pose(img, pose, color=(1.0, 1.0, 1.0))

    if gt is not None:
        lm_err = calc_landmark_nme_per_img(gt, landmarks, ocular_norm='outer')
        img = add_error_to_images([img], lm_err)[0]

    cv2.imshow(title, cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    cv2.waitKey(wait)
def draw_results(X_resized, X_recon, levels_z=None, landmarks=None, landmarks_pred=None,
                 cs_errs=None, ncols=15, fx=0.5, fy=0.5, additional_status_text='',
                 clean_images=True, vis_hidden=False):
    """Render input images, their reconstructions and a status bar as one image.

    Args:
        X_resized: batch of input images. ``.abs()``, ``.reshape`` and
            ``.mean(dim=...)`` are used on ``X_resized - X_recon``.
        X_recon: batch of reconstructions, same layout as ``X_resized``.
            ``X_recon.shape[-1]`` is used as the image size.
        levels_z: latent codes; only used when ``vis_hidden`` is True.
        landmarks: ground-truth landmarks. Ignored (set to None) when
            ``clean_images`` is True.
        landmarks_pred: predicted landmarks, drawn in (1, 0, 0) on both rows.
        cs_errs: optional per-image errors. Shown on the images only when
            ``clean_images`` is False; their mean always goes into the status bar.
        ncols: maximum number of images per grid row (capped at the batch size).
        fx, fy: scale factors applied to the image grids.
        additional_status_text: text appended to the status bar.
        clean_images: when True (the previous hard-coded behaviour), no numeric
            overlays are drawn and ``landmarks`` is ignored.
        vis_hidden: when True, a visualisation of ``levels_z`` is appended
            (previously hard-coded off).

    Returns:
        ``np.vstack`` of: the input grid, the reconstruction grid, the optional
        latent visualisation, and the status bar.
    """
    if clean_images:
        # Clean mode: ground-truth landmarks are neither drawn nor evaluated.
        landmarks = None

    nimgs = len(X_resized)
    ncols = min(ncols, nimgs)
    img_size = X_recon.shape[-1]

    disp_X = vis.to_disp_images(X_resized, denorm=True)
    disp_X_recon = vis.to_disp_images(X_recon, denorm=True)

    # Mean absolute reconstruction error per image, scaled to 0-255 pixel units.
    l1_dists = 255.0 * to_numpy(
        (X_resized - X_recon).abs().reshape(len(disp_X), -1).mean(dim=1))

    # SSIM between the displayed input and reconstruction.
    ssim = np.zeros(nimgs)
    for i in range(nimgs):
        ssim[i] = compare_ssim(disp_X[i], disp_X_recon[i], data_range=1.0, multichannel=True)

    landmarks = to_numpy(landmarks)
    cs_errs = to_numpy(cs_errs)

    text_size = img_size / 256
    text_thickness = 2

    # Landmark overlays on the input and reconstruction images.
    if landmarks is not None:
        disp_X = vis.add_landmarks_to_images(disp_X, landmarks, draw_wireframe=False,
                                             landmarks_to_draw=lmcfg.LANDMARKS_19)
        disp_X_recon = vis.add_landmarks_to_images(disp_X_recon, landmarks, draw_wireframe=False,
                                                   landmarks_to_draw=lmcfg.LANDMARKS_19)
    if landmarks_pred is not None:
        disp_X = vis.add_landmarks_to_images(disp_X, landmarks_pred, color=(1, 0, 0))
        disp_X_recon = vis.add_landmarks_to_images(disp_X_recon, landmarks_pred, color=(1, 0, 0))

    # Numeric error overlays.
    if not clean_images:
        disp_X_recon = vis.add_error_to_images(disp_X_recon, l1_dists, format_string='{:.1f}',
                                               size=text_size, thickness=text_thickness)
        disp_X_recon = vis.add_error_to_images(disp_X_recon, 1 - ssim, loc='bl-1',
                                               format_string='{:>4.2f}', vmax=0.8, vmin=0.2,
                                               size=text_size, thickness=text_thickness)
        if cs_errs is not None:
            # NOTE(review): vmax < vmin here, presumably to invert the colour
            # scale for this error — confirm.
            disp_X_recon = vis.add_error_to_images(disp_X_recon, cs_errs, loc='bl-2',
                                                   format_string='{:>4.2f}', vmax=0.0, vmin=0.4,
                                                   size=text_size, thickness=text_thickness)

    # Landmark errors. This is best effort: a failure is reported, but it does
    # not abort the visualisation.
    lm_errs = np.zeros(1)
    if landmarks is not None:
        try:
            from landmarks import lmutils
            lm_errs = lmutils.calc_landmark_nme_per_img(landmarks, landmarks_pred)
            disp_X_recon = vis.add_error_to_images(disp_X_recon, lm_errs, loc='br',
                                                   format_string='{:>5.2f}', vmax=15,
                                                   size=text_size, thickness=text_thickness)
        except Exception as err:
            import warnings
            warnings.warn('draw_results: could not add landmark errors: {}'.format(err))

    img_input = vis.make_grid(disp_X, nCols=ncols, normalize=False)
    img_recon = vis.make_grid(disp_X_recon, nCols=ncols, normalize=False)
    img_input = cv2.resize(img_input, None, fx=fx, fy=fy, interpolation=cv2.INTER_CUBIC)
    img_recon = cv2.resize(img_recon, None, fx=fx, fy=fy, interpolation=cv2.INTER_CUBIC)
    img_stack = [img_input, img_recon]

    # Optional visualisation of the latent codes.
    if vis_hidden:
        img_z = vis.draw_z_vecs(levels_z, size=(img_size, 30), ncols=ncols)
        img_z = cv2.resize(img_z, dsize=(img_input.shape[1], img_z.shape[0]),
                           interpolation=cv2.INTER_NEAREST)
        img_stack.append(img_z)

    cs_errs_mean = np.mean(cs_errs) if cs_errs is not None else np.nan
    # '{:.2f}' for the landmark error: the previous '{:2}' was a width, not a precision.
    status_bar_text = ("l1 recon err: {:.2f}px "
                       "ssim: {:.3f}({:.3f}) "
                       "lms err: {:.2f} {}").format(l1_dists.mean(), cs_errs_mean,
                                                    1 - ssim.mean(), lm_errs.mean(),
                                                    additional_status_text)
    img_status_bar = vis.draw_status_bar(status_bar_text,
                                         status_bar_width=img_input.shape[1],
                                         status_bar_height=30,
                                         dtype=img_input.dtype)
    img_stack.append(img_status_bar)

    return np.vstack(img_stack)
def visualize_batch(images, landmarks, X_recon, X_lm_hm, lm_preds_max, lm_heatmaps=None,
                    target_images=None, lm_preds_cnn=None, ds=None, wait=0, ssim_maps=None,
                    landmarks_to_draw=None, ocular_norm='outer', horizontal=False, f=1.0,
                    overlay_heatmaps_input=False, overlay_heatmaps_recon=False, clean=False,
                    landmarks_only_outline=range(17), landmarks_no_outline=range(17, 68)):
    """Show inputs, reconstructions and landmark predictions for a batch in OpenCV windows.

    At most the first 10 images are shown.
    - Top row: the inputs (or ``target_images`` if given), with ground-truth
      landmarks and predictions.
    - Bottom row: the reconstructions with the same landmarks. Unless ``clean``
      is True, it also carries error overlays: L1 reconstruction error;
      landmark NME for the inner, outline and all landmark sets; 1 - SSIM; and
      the mean of ``ssim_maps``.

    Args:
        images: batch of input images; ``images.shape[3]`` must be 128 or 256.
        landmarks: ground-truth landmarks, or None. With None, all-zero
            landmarks are drawn and no per-landmark NME is computed.
            Not modified.
        X_recon: reconstructions of ``images``.
        X_lm_hm: predicted landmark heatmaps, or None.
        lm_preds_max: predicted landmark coordinates; ``shape[1]`` is taken as
            the number of landmarks.
        lm_heatmaps: optional ground-truth heatmaps, shown next to the predictions.
        target_images: optional images shown in the top row instead of ``images``.
        lm_preds_cnn: currently unused; kept for API compatibility.
        ds: optional dataset; its class name is appended to the window title.
        wait: milliseconds passed to ``cv2.waitKey``.
        ssim_maps: optional SSIM error maps, also shown in a separate window.
        landmarks_to_draw: currently unused; kept for API compatibility.
        ocular_norm: normalisation passed to the NME helpers.
        horizontal: place the two rows side by side (only valid for one image).
        f: display scale factor applied to images, heatmaps and landmarks.
        overlay_heatmaps_input, overlay_heatmaps_recon: blend the predicted
            heatmaps onto the input / reconstruction row.
        clean: suppress all numeric error overlays.
        landmarks_only_outline, landmarks_no_outline: landmark index sets used
            for the outline / inner NME overlays.
    """
    gt_color = (0, 255, 0)
    pred_color = (0, 0, 255)

    def _resize(im):
        # Nearest-neighbour resampling for display scaling by f.
        return cv2.resize(im, None, fx=f, fy=f, interpolation=cv2.INTER_NEAREST)

    # Promote to a 4D batch *before* reading shape[3] / len(): both assume batch layout.
    images = nn.atleast4d(images)
    image_size = images.shape[3]
    assert image_size in [128, 256]
    nimgs = min(10, len(images))
    images = images[:nimgs]

    num_landmarks = lm_preds_max.shape[1]

    nme_per_lm = None
    if landmarks is None:
        lm_gt = np.zeros((nimgs, num_landmarks, 2))
    else:
        lm_gt = nn.atleast3d(to_numpy(landmarks))[:nimgs]
        nme_per_lm = calc_landmark_nme(lm_gt, lm_preds_max[:nimgs], ocular_norm=ocular_norm,
                                       image_size=image_size)

    # Show the landmark heatmaps.
    pred_heatmaps = None
    if X_lm_hm is not None:
        pred_heatmaps = [_resize(im) for im in to_single_channel_heatmap(to_numpy(X_lm_hm[:nimgs]))]
        gt_heatmaps = None
        if lm_heatmaps is not None:
            gt_heatmaps = np.array(
                [_resize(im) for im in to_single_channel_heatmap(to_numpy(lm_heatmaps[:nimgs]))])
        show_landmark_heatmaps(pred_heatmaps, gt_heatmaps, nimgs, f=1)

    # Scale the landmarks to display resolution. This is deliberately out of place:
    # lm_gt is a slice of to_numpy(landmarks) and may share memory with the caller's
    # data, and an in-place '*=' would also fail for integer landmark arrays.
    lm_preds_max = lm_preds_max[:nimgs] * f
    lm_gt = lm_gt * f

    input_images = vis.to_disp_images(images, denorm=True)
    if target_images is not None:
        disp_images = vis.to_disp_images(target_images[:nimgs], denorm=True)
    else:
        disp_images = vis.to_disp_images(images, denorm=True)
    disp_images = [_resize(im) for im in disp_images]

    recon_images = vis.to_disp_images(X_recon[:nimgs], denorm=True)
    disp_X_recon = [_resize(im) for im in recon_images.copy()]

    # Optionally blend the predicted heatmaps onto the images.
    if pred_heatmaps is not None and overlay_heatmaps_input:
        disp_images = [vis.overlay_heatmap(im, hm) for im, hm in zip(disp_images, pred_heatmaps)]
    if pred_heatmaps is not None and overlay_heatmaps_recon:
        disp_X_recon = [vis.overlay_heatmap(im, hm) for im, hm in zip(disp_X_recon, pred_heatmaps)]

    # Input row.
    disp_images = vis.add_landmarks_to_images(disp_images, lm_gt[:nimgs], color=gt_color)
    disp_images = vis.add_landmarks_to_images(disp_images, lm_preds_max[:nimgs], lm_errs=nme_per_lm,
                                              color=pred_color, draw_wireframe=False,
                                              gt_landmarks=lm_gt, draw_gt_offsets=True)

    # Reconstruction row: error overlays, drawn before the landmarks.
    if not clean:
        X_recon_errs = 255.0 * torch.abs(images - X_recon[:nimgs]).reshape(
            len(images), -1).mean(dim=1)
        disp_X_recon = vis.add_error_to_images(disp_X_recon[:nimgs], errors=X_recon_errs,
                                               format_string='{:>4.1f}')

        all_landmarks = list(landmarks_only_outline) + list(landmarks_no_outline)
        lm_errs_max = calc_landmark_nme_per_img(lm_gt, lm_preds_max, ocular_norm,
                                                landmarks_no_outline, image_size=image_size)
        lm_errs_max_outline = calc_landmark_nme_per_img(lm_gt, lm_preds_max, ocular_norm,
                                                        landmarks_only_outline,
                                                        image_size=image_size)
        lm_errs_max_all = calc_landmark_nme_per_img(lm_gt, lm_preds_max, ocular_norm,
                                                    all_landmarks, image_size=image_size)
        disp_X_recon = vis.add_error_to_images(disp_X_recon, lm_errs_max, loc='br-2',
                                               format_string='{:>5.2f}', vmax=15)
        disp_X_recon = vis.add_error_to_images(disp_X_recon, lm_errs_max_outline, loc='br-1',
                                               format_string='{:>5.2f}', vmax=15)
        disp_X_recon = vis.add_error_to_images(disp_X_recon, lm_errs_max_all, loc='br',
                                               format_string='{:>5.2f}', vmax=15)

    disp_X_recon = vis.add_landmarks_to_images(disp_X_recon, lm_gt, color=gt_color,
                                               draw_wireframe=True)
    disp_X_recon = vis.add_landmarks_to_images(disp_X_recon, lm_preds_max[:nimgs],
                                               color=pred_color, draw_wireframe=True,
                                               gt_landmarks=lm_gt, draw_gt_offsets=True,
                                               lm_errs=nme_per_lm, draw_dots=True, radius=2)

    if not clean:
        # 1 - SSIM between the unscaled display input and reconstruction.
        ssim = np.zeros(nimgs)
        for i in range(nimgs):
            ssim[i] = compare_ssim(input_images[i], recon_images[i], data_range=1.0,
                                   multichannel=True)
        disp_X_recon = vis.add_error_to_images(disp_X_recon, 1 - ssim, loc='bl-1',
                                               format_string='{:>4.2f}', vmax=0.8, vmin=0.2)
        if ssim_maps is not None:
            disp_X_recon = vis.add_error_to_images(
                disp_X_recon, ssim_maps.reshape(len(ssim_maps), -1).mean(axis=1),
                loc='bl-2', format_string='{:>4.2f}', vmin=0.0, vmax=0.4)

    rows = [vis.make_grid(disp_images, nCols=nimgs, normalize=False)]
    rows.append(vis.make_grid(disp_X_recon, nCols=nimgs))

    if ssim_maps is not None:
        disp_ssim_maps = to_numpy(nn.denormalized(ssim_maps)[:nimgs].transpose(0, 2, 3, 1))
        for i in range(len(disp_ssim_maps)):
            disp_ssim_maps[i] = vis.color_map(disp_ssim_maps[i].mean(axis=2), vmin=0.0, vmax=2.0)
        grid_ssim_maps = vis.make_grid(disp_ssim_maps, nCols=nimgs, fx=f, fy=f)
        cv2.imshow('ssim errors', cv2.cvtColor(grid_ssim_maps, cv2.COLOR_RGB2BGR))

    if horizontal:
        assert (nimgs == 1)
        disp_rows = vis.make_grid(rows, nCols=2)
    else:
        disp_rows = vis.make_grid(rows, nCols=1)

    wnd_title = 'Predicted Landmarks '
    if ds is not None:
        wnd_title += ds.__class__.__name__
    cv2.imshow(wnd_title, cv2.cvtColor(disp_rows, cv2.COLOR_RGB2BGR))
    cv2.waitKey(wait)
def _run_batch(self, data, train_autoencoder, eval=False, ds=None):
    """Process one mini-batch: encode, reconstruct, predict landmarks, update, log.

    Per-iteration statistics go into a dict that is appended to
    ``self.epoch_stats``. The landmark head (and, with ``self.args.train_encoder``,
    the encoder optimizer) is stepped only when ``eval`` is False,
    ``self.args.train_coords`` is False and a landmark loss could be computed.
    On print-out iterations, the statistics are printed and the batch is visualised.

    Args:
        data: raw batch passed to ``Batch(data, eval=eval)``.
        train_autoencoder: not used in this body.
        eval: evaluation mode. No gradients are recorded and no optimizer is stepped.
        ds: optional dataset, forwarded to the visualisation.
    """
    time_dataloading = time.time() - self.iter_starttime
    time_proc_start = time.time()
    iter_stats = {'time_dataloading': time_dataloading}

    batch = Batch(data, eval=eval)

    self.saae.zero_grad()
    self.saae.eval()

    input_images = batch.images_mod if batch.images_mod is not None else batch.images
    # Same grad policy as the decoder below: nothing is back-propagated when eval=True.
    with torch.set_grad_enabled(self.args.train_encoder and not eval):
        z_sample = self.saae.Q(input_images)

    iter_stats.update({'z_recon_mean': z_sample.mean().item()})

    #######################
    # Reconstruction phase
    #######################
    with torch.set_grad_enabled(self.args.train_encoder and not eval):
        X_recon = self.saae.P(z_sample)

    # The reconstruction loss is only logged here, never back-propagated.
    with torch.no_grad():
        diff = torch.abs(batch.images - X_recon) * 255
        loss_recon_l1 = torch.mean(diff)
        loss_Q = loss_recon_l1 * cfg.W_RECON
        iter_stats['loss_recon'] = loss_recon_l1.item()
        l1_dist_per_img = diff.reshape(len(batch.images), -1).mean(dim=1)
        iter_stats['l1_recon_errors'] = to_numpy(l1_dist_per_img)

    #######################
    # Landmark predictions
    #######################
    train_lmhead = not eval and not self.args.train_coords
    lm_preds_max = None
    # Stays None when the batch carries no target heatmaps.
    loss_lms = None
    with torch.set_grad_enabled(train_lmhead):
        self.saae.LMH.train(train_lmhead)
        # NOTE(review): the landmark head receives the decoder module itself rather
        # than X_recon — presumably it reads the decoder's intermediate activations.
        X_lm_hm = self.saae.LMH(self.saae.P)
        if batch.lm_heatmaps is not None:
            loss_lms = F.mse_loss(batch.lm_heatmaps, X_lm_hm) * 100 * 3

        if eval or self._is_printout_iter():
            # Expensive, so only done when evaluating or on print-out iterations.
            X_lm_hm = lmutils.decode_heatmap_blob(X_lm_hm)
            X_lm_hm = lmutils.smooth_heatmaps(X_lm_hm)
            lm_preds_max = self.saae.heatmaps_to_landmarks(X_lm_hm)

    iter_stats.update({'loss_Q': loss_Q.item()})
    if not self.args.train_coords and loss_lms is not None:
        iter_stats.update({'loss_lms': loss_lms.item()})

    if eval or self._is_printout_iter():
        lm_gt = to_numpy(batch.landmarks)
        ocular_norm = self.args.ocular_norm
        lm_errs_max = lmutils.calc_landmark_nme_per_img(
            lm_gt, lm_preds_max, ocular_norm=ocular_norm,
            landmarks_to_eval=lmcfg.LANDMARKS_NO_OUTLINE)
        lm_errs_max_outline = lmutils.calc_landmark_nme_per_img(
            lm_gt, lm_preds_max, ocular_norm=ocular_norm,
            landmarks_to_eval=lmcfg.LANDMARKS_ONLY_OUTLINE)
        lm_errs_max_all = lmutils.calc_landmark_nme_per_img(
            lm_gt, lm_preds_max, ocular_norm=ocular_norm,
            landmarks_to_eval=lmcfg.ALL_LANDMARKS)
        nmes = lmutils.calc_landmark_nme(lm_gt, lm_preds_max, ocular_norm=ocular_norm)
        iter_stats.update({'lm_errs_max': lm_errs_max,
                           'lm_errs_max_all': lm_errs_max_all,
                           'lm_errs_max_outline': lm_errs_max_outline,
                           'nmes': nmes})

    if train_lmhead and loss_lms is not None:
        if self.args.train_encoder:
            loss_lms = loss_lms * 80.0
        loss_lms.backward()
        self.optimizer_lm_head.step()
        if self.args.train_encoder:
            self.optimizer_Q.step()

    # statistics
    iter_stats.update({'epoch': self.epoch, 'timestamp': time.time(),
                       'iter_time': time.time() - self.iter_starttime,
                       'time_processing': time.time() - time_proc_start,
                       'iter': self.iter_in_epoch, 'total_iter': self.total_iter,
                       'batch_size': len(batch)})
    self.iter_starttime = time.time()
    self.epoch_stats.append(iter_stats)

    # print stats every N mini-batches
    if self._is_printout_iter():
        self._print_iter_stats(self.epoch_stats[-self.print_interval:])

    # Batch visualization
    if self._is_printout_iter():
        f = 2.0 / cfg.INPUT_SCALE_FACTOR
        # NOTE(review): `images_mod=` is not a parameter of the visualize_batch
        # signature visible alongside this code (it has `target_images`) — confirm
        # that lmutils.visualize_batch accepts it.
        lmutils.visualize_batch(batch.images, batch.landmarks, X_recon, X_lm_hm, lm_preds_max,
                                lm_heatmaps=batch.lm_heatmaps,
                                images_mod=batch.images_mod,
                                ds=ds, wait=self.wait,
                                landmarks_to_draw=lmcfg.ALL_LANDMARKS,
                                ocular_norm=self.args.ocular_norm,
                                f=f, clean=False,
                                overlay_heatmaps_input=False,
                                overlay_heatmaps_recon=False)