def show_pairs(images, features, pairs):
    """Display two rows of paired images annotated with their feature distance.

    For every pair the Euclidean distance between the two feature vectors is
    computed. The second image row is annotated with that distance, the pair
    label, and a green/red marker telling whether thresholding the distance at
    0.4 agrees with the label.

    Args:
        images: Item-assignable container holding the two image batches.
            NOTE: entries 0 and 1 are overwritten with their display versions.
        features: Indexable with the two feature batches at positions 0 and 1;
            they are reduced with NumPy along axis 1.
        pairs: Torch tensor with one label per pair. A label is treated as
            "correct" when it equals ``dist < 0.4`` -- presumably 1/True means
            "same identity"; confirm against the caller.
    """
    dists = np.sqrt(np.sum((features[0] - features[1]) ** 2, axis=1))

    # The return values are discarded, so denormalize() presumably works in
    # place -- TODO confirm against ds_utils.
    ds_utils.denormalize(images[0])
    ds_utils.denormalize(images[1])

    # Convert the labels once. Going through .cpu() keeps this working for
    # tensors that live on a CUDA device, where a bare .numpy() raises.
    pair_labels = pairs.detach().cpu().numpy()

    images[1] = vis.add_error_to_images(images[1], dists, size=2.0,
                                        thickness=2, vmin=0, vmax=1)
    images[1] = vis.add_id_to_images(images[1], pair_labels, size=1.2,
                                     thickness=2, color=(1, 0, 1))

    thresh = 0.4
    corrects = (dists < thresh) == pair_labels
    # green marker: thresholded distance agrees with the label, red otherwise
    colors = [(0, 1, 0) if ok else (1, 0, 0) for ok in corrects]
    images[1] = vis.add_cirle_to_images(images[1], colors)
    images[0] = vis._to_disp_images(images[0])

    img_rows = [
        vis.make_grid(imgs, fx=0.75, fy=0.75, nCols=len(dists),
                      normalize=False) for imgs in images
    ]
    vis.vis_square(img_rows, nCols=1, normalize=False)
def add_confs(disp_X_recon, lmids, loc):
    """Print the mean confidence of the selected landmarks onto each image.

    Reads ``lm_confs`` from the enclosing/global scope (it is not a
    parameter), averages the columns selected by ``lmids`` per row and draws
    the result at position ``loc``. The text colour comes from the jet colour
    map applied to ``1 - mean`` in the range [0.0, 0.4].
    """
    mean_confs = lm_confs[:, lmids].mean(axis=1)
    inverted_confs = to_numpy(1 - mean_confs)
    text_colors = vis.color_map(inverted_confs,
                                cmap=plt.cm.jet,
                                vmin=0.0,
                                vmax=0.4)
    annotated = vis.add_error_to_images(disp_X_recon,
                                        mean_confs,
                                        loc=loc,
                                        format_string='{:>4.2f}',
                                        colors=text_colors)
    return annotated
def generate_images(self, z):
    """Decode latent codes ``z`` into display images tagged with a GAN score.

    ``self.saae.P`` generates images from ``z`` (only the first three output
    channels are used) and ``self.saae.D`` scores them; ``1 - score`` is
    printed in the top-right corner ("tr") of every image.

    Both networks are switched to eval mode for the forward pass and are put
    back into whatever mode they were in before, even if the forward pass
    raises.

    Args:
        z: Batch of latent codes accepted by ``self.saae.P``.

    Returns:
        Whatever ``vis.add_error_to_images`` returns for the annotated images.
    """
    train_state_D = self.saae.D.training
    train_state_P = self.saae.P.training
    self.saae.D.eval()
    self.saae.P.eval()

    loc_err_gan = "tr"
    try:
        with torch.no_grad():
            X_gen_vis = self.saae.P(z)[:, :3]
            err_gan_gen = self.saae.D(X_gen_vis)
        imgs = vis.reconstruct_images(X_gen_vis)
    finally:
        # Restore the previous modes so a failure here cannot leave the
        # models stuck in eval mode for the rest of training.
        self.saae.D.train(train_state_D)
        self.saae.P.train(train_state_P)

    return vis.add_error_to_images(
        imgs,
        errors=1 - err_gan_gen,
        loc=loc_err_gan,
        format_string="{:.2f}",
        vmax=1.0,
    )
def visualize_batch(self, batch, X_recon, ssim_maps, nimgs=8, ds=None, wait=0):
    """Save a grid of inputs and reconstructions with error overlays as JPEG.

    Row one shows the target images (``batch.target_images`` if set, else
    ``batch.images``); row two shows the reconstructions annotated with

    * the mean absolute pixel difference (scaled by 255),
    * ``1 - SSIM`` between input and reconstruction (``bl-1``),
    * the mean of ``ssim_maps`` per image, if given (``bl-2``),
    * ``1 - D(x)`` in the top-right corner when ``self.args.with_gan`` is set.

    The grid is written with ``cv2.imwrite`` to ``"recon errors <ds class>.jpg"``
    and, if ``ssim_maps`` is given, a colour-mapped grid of the maps goes to
    ``"ssim errors.jpg"``.

    The D, Q and P networks run in eval mode while the images are built and
    are restored to their previous mode afterwards, even on failure.

    Args:
        batch: Object exposing ``images``, ``target_images`` and ``len()``.
        X_recon: Reconstructions corresponding to ``batch.images``.
        ssim_maps: Optional per-image SSIM maps; NumPy-style ``reshape``,
            ``mean(axis=...)`` and ``transpose`` are called on them.
        nimgs: Maximum number of images to show (clamped to ``len(batch)``).
        ds: Optional dataset; only its class name is used in the file name.
        wait: Passed to ``cv2.waitKey`` after the files have been written.
    """
    nimgs = min(nimgs, len(batch))

    # Remember the current train/eval mode of each sub-network so that it can
    # be restored after the visualisation forward passes.
    train_state_D = self.saae.D.training
    train_state_Q = self.saae.Q.training
    train_state_P = self.saae.P.training
    self.saae.D.eval()
    self.saae.Q.eval()
    self.saae.P.eval()

    loc_err_gan = "tr"
    text_size_errors = 0.65

    try:
        input_images = vis.reconstruct_images(batch.images[:nimgs])
        target_images = (batch.target_images
                         if batch.target_images is not None else batch.images)
        disp_images = vis.reconstruct_images(target_images[:nimgs])

        # draw GAN score
        if self.args.with_gan:
            with torch.no_grad():
                err_gan_inputs = self.saae.D(batch.images[:nimgs])
            disp_images = vis.add_error_to_images(
                disp_images,
                errors=1 - err_gan_inputs,
                loc=loc_err_gan,
                format_string="{:>5.2f}",
                vmax=1.0,
            )

        rows = [vis.make_grid(disp_images, nCols=nimgs, normalize=False)]

        recon_images = vis.reconstruct_images(X_recon[:nimgs])
        disp_X_recon = recon_images.copy()

        print_stats = True
        if print_stats:
            # mean absolute pixel error per image (computed on the full batch)
            X_recon_errs = 255.0 * torch.abs(batch.images - X_recon).reshape(
                len(batch.images), -1).mean(dim=1)
            disp_X_recon = vis.add_error_to_images(
                disp_X_recon[:nimgs],
                errors=X_recon_errs,
                size=text_size_errors,
                format_string="{:>4.1f}",
            )

            if self.args.with_gan:
                with torch.no_grad():
                    err_gan = self.saae.D(X_recon[:nimgs])
                disp_X_recon = vis.add_error_to_images(
                    disp_X_recon,
                    errors=1 - err_gan,
                    loc=loc_err_gan,
                    format_string="{:>5.2f}",
                    vmax=1.0,
                )

            ssim = np.zeros(nimgs)
            if nimgs > 0:
                # The value range depends only on the image dtype, so it is
                # determined once instead of once per image.
                data_range = (255.0
                              if input_images[0].dtype == np.uint8 else 1.0)
                for i in range(nimgs):
                    ssim[i] = compare_ssim(
                        input_images[i],
                        recon_images[i],
                        data_range=data_range,
                        multichannel=True,
                    )
            disp_X_recon = vis.add_error_to_images(
                disp_X_recon,
                1 - ssim,
                loc="bl-1",
                size=text_size_errors,
                format_string="{:>4.2f}",
                vmin=0.2,
                vmax=0.8,
            )

            if ssim_maps is not None:
                disp_X_recon = vis.add_error_to_images(
                    disp_X_recon,
                    ssim_maps.reshape(len(ssim_maps), -1).mean(axis=1),
                    size=text_size_errors,
                    loc="bl-2",
                    format_string="{:>4.2f}",
                    vmin=0.0,
                    vmax=0.4,
                )

        rows.append(vis.make_grid(disp_X_recon, nCols=nimgs))

        if ssim_maps is not None:
            disp_ssim_maps = to_numpy(
                nn.denormalized(ssim_maps)[:nimgs].transpose(0, 2, 3, 1))
            if disp_ssim_maps.shape[3] == 1:
                # expand single-channel maps so a colour image fits per slot
                disp_ssim_maps = disp_ssim_maps.repeat(3, axis=3)
            for i in range(len(disp_ssim_maps)):
                disp_ssim_maps[i] = vis.color_map(
                    disp_ssim_maps[i].mean(axis=2), vmin=0.0, vmax=2.0)
            grid_ssim_maps = vis.make_grid(disp_ssim_maps, nCols=nimgs)
            cv2.imwrite("ssim errors.jpg",
                        cv2.cvtColor(grid_ssim_maps, cv2.COLOR_RGB2BGR))
    finally:
        # Always hand the networks back in the mode they were found in.
        self.saae.D.train(train_state_D)
        self.saae.Q.train(train_state_Q)
        self.saae.P.train(train_state_P)

    f = 1
    disp_rows = vis.make_grid(rows, nCols=1, normalize=False, fx=f, fy=f)
    wnd_title = "recon errors "
    if ds is not None:
        wnd_title += ds.__class__.__name__
    cv2.imwrite(wnd_title + ".jpg", cv2.cvtColor(disp_rows,
                                                 cv2.COLOR_RGB2BGR))
    cv2.waitKey(wait)
def __train_disenglement_parallel(self, z, Y=None, train=True):
    """Run one disentanglement step on latent codes ``z``.

    Steps, in order:

    1. Encode ``z`` with ``self.E`` into four feature vectors, decode them
       with ``self.G`` and compute an L1 reconstruction loss on ``z``.
    2. Pick one factor (index into ``self.factors``) based on which labels in
       ``Y`` are available and on ``self.iter``, and add a supervised feature
       loss for it (triplet loss, or MSE for ``'pose'``). For
       ``'expression'`` the classifier ``self.znet`` is trained as well.
    3. If ``cfg.WITH_DISENT_CYCLE_LOSS``: shuffle the chosen feature across
       the batch, run G -> P -> Q -> E and penalise the L1 difference of the
       feature vectors (columns from index 3 onwards), optionally adding the
       feature loss on the cycled features.
    4. Periodically show a debug window with inputs and reconstructions.
    5. Scale the loss by ``cfg.W_DISENT`` and backpropagate if ``train``.

    Args:
        z: Batch of latent codes.
        Y: Indexable of labels; ``Y[0]``..``Y[3]`` are read. NOTE(review):
            the default ``None`` cannot work because ``Y[3]`` is indexed
            unconditionally.
        train: Puts E/G into train mode and enables the backward passes.

    Returns:
        Tuple ``(z_recon, iter_stats, loss_G[0])`` where ``iter_stats`` is a
        dict of scalar statistics collected during the step.

    NOTE(review): this block was reviewed from a source whose line structure
    was lost; the nesting shown here was reconstructed from data dependencies
    and should be confirmed against version control.
    """
    iter_stats = {}

    self.E.train(train)
    self.G.train(train)

    self.optimizer_E.zero_grad()
    self.optimizer_G.zero_grad()

    #
    # Autoencoding phase
    #

    fts = self.E(z)
    fp, fi, fs, fe = fts

    z_recon = self.G(fp, fi, fs, fe)

    loss_z_recon = F.l1_loss(z, z_recon) * cfg.W_Z_RECON
    if not cfg.WITH_Z_RECON_LOSS:
        loss_z_recon *= 0

    #
    # Info min/max phase
    #

    # loss_I is the same tensor object as loss_z_recon, so the in-place
    # "+=" / "*=" updates below also change loss_z_recon.
    loss_I = loss_z_recon
    loss_G = torch.zeros(1, requires_grad=True).cuda()

    def calc_err(outputs, target):
        # mean absolute error per output column, converted from radians to degrees
        return np.abs(np.rad2deg(F.l1_loss(outputs, target, reduction='none').detach().cpu().numpy().mean(axis=0)))

    # NOTE(review): only referenced from a commented-out line below.
    def cosine_loss(outputs, targets):
        return (1 - F.cosine_similarity(outputs, targets, dim=1)).mean()

    # Choose the candidate factor indices from the labels that are present.
    # NOTE(review): if no branch matches, available_factors stays undefined
    # and the lookup below raises.
    if Y[3] is not None and Y[3].sum() > 0:  # Has expression -> AffectNet
        available_factors = [3,3,3]
        if cfg.WITH_POSE:
            available_factors = [0] + available_factors
    elif Y[2][1] is not None:  # has vids -> VoxCeleb
        available_factors = [2]
    elif Y[1] is not None:  # Has identities
        available_factors = [1,1,1]
        if cfg.WITH_POSE:
            available_factors = [0] + available_factors
    elif Y[0] is not None:  # Any dataset with pose
        available_factors = [0,1,3]

    # cycle through the candidates as the iteration counter advances
    lvl = available_factors[self.iter % len(available_factors)]

    name = self.factors[lvl]
    try:
        y = Y[lvl]
    except TypeError:
        y = None

    # if y is not None and name != 'shape':
    def calc_feature_loss(name, y_f, y, show_triplets=False, wnd_title=None):
        # Returns (loss, error) for one factor: triplet loss for
        # id/shape/expression, MSE plus error in degrees for pose.
        if name == 'id' or name == 'shape' or name == 'expression':
            display_images = None
            if show_triplets:
                display_images = self.images
            loss_I_f, err_f = calc_triplet_loss(y_f, y, return_acc=True, images=display_images, feature_name=name,
                                                wnd_title=wnd_title)
            if name == 'expression':
                loss_I_f *= 2.0
        elif name == 'pose':
            # loss_I_f, err_f = F.l1_loss(y_f, y), calc_err(y_f, y)
            loss_I_f, err_f = F.mse_loss(y_f, y)*1, calc_err(y_f, y)
            # loss_I_f, err_f = cosine_loss(y_f, y), calc_err(y_f, y)
        else:
            raise ValueError("Unknown feature name!")
        return loss_I_f, err_f

    if y is not None and cfg.WITH_FEATURE_LOSS:
        # NOTE(review): show_triplets is only defined inside this branch but
        # is also read by the augmentation loss further down.
        show_triplets = (self.iter + 1) % self.print_interval == 0

        y_f = fts[lvl]
        loss_I_f, err_f = calc_feature_loss(name, y_f, y, show_triplets=show_triplets)

        loss_I += cfg.W_FEAT * loss_I_f

        iter_stats[name+'_loss_f'] = loss_I_f.item()
        iter_stats[name+'_err_f'] = np.mean(err_f)

        # train expression classifier
        if name == 'expression':
            self.znet.zero_grad()
            emotion_labels = y[:,0].long()
            clprobs = self.znet(y_f.detach())  # train only znet
            # clprobs = self.znet(y_f)  # train encoder and znet
            # loss_cls = self.cross_entropy_loss(clprobs, emotion_labels)
            loss_cls = self.weighted_CE_loss(clprobs, emotion_labels)

            acc_cls = calc_acc(clprobs, emotion_labels)
            if train:
                loss_cls.backward(retain_graph=False)
                # NOTE(review): whether step() belongs inside "if train" could
                # not be determined from the reviewed source -- confirm.
                self.optimizer_znet.step()
            iter_stats['loss_cls'] = loss_cls.item()
            iter_stats['acc_cls'] = acc_cls
            iter_stats['expression_y_probs'] = to_numpy(clprobs)
            iter_stats['expression_y'] = to_numpy(y)

    # cycle loss

    # other_levels = [0,1,2,3]
    # other_levels.remove(lvl)
    # shuffle_lvl = np.random.permutation(other_levels)[0]

    shuffle_lvl = lvl
    # print("shuffling level {}...".format(shuffle_lvl))
    if cfg.WITH_DISENT_CYCLE_LOSS:

        # z_random = torch.rand_like(z).cuda()
        # fts_random = self.E(z_random)

        # create modified feature vectors
        # (detached, so the cycle loss does not backpropagate through the
        # first encoder pass; requires fts to support item assignment)
        fts[0] = fts[0].detach()
        fts[1] = fts[1].detach()
        fts[2] = fts[2].detach()
        fts[3] = fts[3].detach()
        fts_mod = fts.copy()
        shuffled_ids = torch.randperm(len(fts[shuffle_lvl]))
        y_mod = None
        if y is not None:
            if name == 'shape':
                y_mod = [y[0][shuffled_ids], y[1][shuffled_ids]]
            else:
                y_mod = y[shuffled_ids]

        # swap the selected feature between samples of the batch
        fts_mod[shuffle_lvl] = fts[shuffle_lvl][shuffled_ids]

        # predict full cycle
        z_random_mod = self.G(*fts_mod)
        X_random_mod = self.P(z_random_mod)[:,:3]
        z_random_mod_recon = self.Q(X_random_mod)
        fts2 = self.E(z_random_mod_recon)

        # recon error in unmodified part
        # h = torch.cat([fts_mod[i] for i in range(len(fts_mod)) if i != lvl], dim=1)
        # h2 = torch.cat([fts2[i] for i in range(len(fts2)) if i != lvl], dim=1)
        # l1_err_h = torch.abs(h - h2).mean(dim=1)
        # l1_err_h = torch.abs(torch.cat(fts_mod, dim=1) - torch.cat(fts2, dim=1)).mean(dim=1)

        # recon error in modified part
        # l1_err_f = np.rad2deg(to_numpy(torch.abs(fts_mod[lvl] - fts2[lvl]).mean(dim=1)))

        # recon error in entire vector
        # (the first three columns of the concatenated features are skipped;
        # presumably these are the pose entries -- confirm)
        l1_err = torch.abs(torch.cat(fts_mod, dim=1)[:,3:] - torch.cat(fts2, dim=1)[:,3:]).mean(dim=1)
        loss_dis_cycle = F.l1_loss(torch.cat(fts_mod, dim=1)[:,3:], torch.cat(fts2, dim=1)[:,3:]) * cfg.W_CYCLE
        iter_stats['loss_dis_cycle'] = loss_dis_cycle.item()

        loss_I += loss_dis_cycle

        # cycle augmentation loss
        if cfg.WITH_AUGMENTATION_LOSS and y_mod is not None:
            y_f_2 = fts2[lvl]
            loss_I_f_2, err_f_2 = calc_feature_loss(name, y_f_2, y_mod, show_triplets=show_triplets, wnd_title='aug')
            loss_I += loss_I_f_2 * cfg.W_AUG

            iter_stats[name+'_loss_f_2'] = loss_I_f_2.item()
            iter_stats[name+'_err_f_2'] = np.mean(err_f_2)

        #
        # Adversarial loss of modified generations
        #

        # Hard-coded off: with GAN = False this whole branch never runs and
        # loss_G stays at its initial zero value.
        GAN = False
        if GAN and train:
            eps = 0.00001

            # #######################
            # # GAN discriminator phase
            # #######################
            update_D = False
            if update_D:
                self.D.zero_grad()
                err_real = self.D(self.images)
                err_fake = self.D(X_random_mod.detach())
                # err_fake = self.D(X_z_recon.detach())
                loss_D = -torch.mean(torch.log(err_real + eps) + torch.log(1.0 - err_fake + eps)) * 0.1
                loss_D.backward()
                self.optimizer_D.step()
                iter_stats.update({'loss_D': loss_D.item()})

            #######################
            # Generator loss
            #######################
            self.D.zero_grad()
            err_fake = self.D(X_random_mod)
            # err_fake = self.D(X_z_recon)
            loss_G += -torch.mean(torch.log(err_fake + eps))

            iter_stats.update({'loss_G': loss_G.item()})
            # iter_stats.update({'err_real': err_real.mean().item(), 'err_fake': loss_G.mean().item()})

        # debug visualization
        # (uses shuffled_ids, fts_mod, fts2, l1_err and X_random_mod from the
        # cycle-loss code above)
        show = True
        if show:
            if (self.iter+1) % self.print_interval in [0,1]:
                if Y[3] is None:
                    emotion_gt = np.zeros(len(z), dtype=int)
                    emotion_gt_mod = np.zeros(len(z), dtype=int)
                else:
                    emotion_gt = Y[3][:,0].long()
                    emotion_gt_mod = Y[3][shuffled_ids,0].long()
                with torch.no_grad():
                    self.znet.eval()
                    self.G.eval()
                    emotion_preds = torch.max(self.znet(fe.detach()), 1)[1]
                    emotion_mod = torch.max(self.znet(fts_mod[3].detach()), 1)[1]
                    emotion_mod_pred = torch.max(self.znet(fts2[3].detach()), 1)[1]
                    X_recon = self.P(z)[:,:3]
                    X_z_recon = self.P(z_recon)[:,:3]
                    X_random_mod_recon = self.P(self.G(*fts2))[:,:3]
                    self.znet.train(train)
                    self.G.train(train)
                X_recon_errs = 255.0 * torch.abs(self.images - X_recon).reshape(len(self.images), -1).mean(dim=1)
                X_z_recon_errs = 255.0 * torch.abs(self.images - X_z_recon).reshape(len(self.images), -1).mean(dim=1)

                nimgs = 8

                disp_input = vis.add_pose_to_images(ds_utils.denormalized(self.images)[:nimgs], Y[0], color=(0, 0, 1.0))
                if name == 'expression':
                    disp_input = vis.add_emotion_to_images(disp_input, to_numpy(emotion_gt))
                elif name == 'id':
                    disp_input = vis.add_id_to_images(disp_input, to_numpy(Y[1]))

                disp_recon = vis.add_pose_to_images(ds_utils.denormalized(X_recon)[:nimgs], fts[0])
                disp_recon = vis.add_error_to_images(disp_recon, errors=X_recon_errs, format_string='{:.1f}')

                disp_z_recon = vis.add_pose_to_images(ds_utils.denormalized(X_z_recon)[:nimgs], fts[0])
                disp_z_recon = vis.add_emotion_to_images(disp_z_recon, to_numpy(emotion_preds),
                                                         gt_emotions=to_numpy(emotion_gt) if name=='expression' else None)
                disp_z_recon = vis.add_error_to_images(disp_z_recon, errors=X_z_recon_errs, format_string='{:.1f}')

                disp_input_shuffle = vis.add_pose_to_images(ds_utils.denormalized(self.images[shuffled_ids])[:nimgs], fts[0][shuffled_ids])
                disp_input_shuffle = vis.add_emotion_to_images(disp_input_shuffle, to_numpy(emotion_gt_mod))
                if name == 'id':
                    disp_input_shuffle = vis.add_id_to_images(disp_input_shuffle, to_numpy(Y[1][shuffled_ids]))

                disp_recon_shuffle = vis.add_pose_to_images(ds_utils.denormalized(X_random_mod)[:nimgs], fts_mod[0], color=(0, 0, 1.0))
                disp_recon_shuffle = vis.add_emotion_to_images(disp_recon_shuffle, to_numpy(emotion_mod))

                disp_cycle = vis.add_pose_to_images(ds_utils.denormalized(X_random_mod_recon)[:nimgs], fts2[0])
                disp_cycle = vis.add_emotion_to_images(disp_cycle, to_numpy(emotion_mod_pred))
                disp_cycle = vis.add_error_to_images(disp_cycle, errors=l1_err, format_string='{:.3f}', size=0.6,
                                                     thickness=2, vmin=0, vmax=0.1)

                rows = [
                    # original input images
                    vis.make_grid(disp_input, nCols=nimgs),

                    # reconstructions without disentanglement
                    vis.make_grid(disp_recon, nCols=nimgs),

                    # reconstructions with disentanglement
                    vis.make_grid(disp_z_recon, nCols=nimgs),

                    # source for feature transfer
                    vis.make_grid(disp_input_shuffle, nCols=nimgs),

                    # reconstructions with modified feature vector (direct)
                    vis.make_grid(disp_recon_shuffle, nCols=nimgs),

                    # reconstructions with modified feature vector (1 iters)
                    vis.make_grid(disp_cycle, nCols=nimgs)
                ]
                f = 1.0 / cfg.INPUT_SCALE_FACTOR
                disp_img = vis.make_grid(rows, nCols=1, normalize=False, fx=f, fy=f)

                wnd_title = name
                if self.current_dataset is not None:
                    wnd_title += ' ' + self.current_dataset.__class__.__name__
                cv2.imshow(wnd_title, cv2.cvtColor(disp_img, cv2.COLOR_RGB2BGR))
                cv2.waitKey(10)

    loss_I *= cfg.W_DISENT

    iter_stats['loss_disent'] = loss_I.item()

    if train:
        # The optimizer steps for E and G are not taken here; the graph is
        # retained, presumably for a further backward pass in the caller.
        loss_I.backward(retain_graph=True)

    return z_recon, iter_stats, loss_G[0]
def visualize_batch(images, landmarks, X_recon, X_lm_hm, lm_preds_max,
                    lm_heatmaps=None, images_mod=None, lm_preds_cnn=None,
                    ds=None, wait=0, ssim_maps=None,
                    landmarks_to_draw=lmcfg.ALL_LANDMARKS, ocular_norm='outer',
                    horizontal=False, f=1.0, overlay_heatmaps_input=False,
                    overlay_heatmaps_recon=False, clean=False):
    """Show inputs and reconstructions with ground-truth/predicted landmarks.

    At most 10 images are shown in an OpenCV window titled
    ``'Predicted Landmarks <ds class>'``. Row one holds the input images (or
    ``images_mod`` if given), row two the reconstructions. Unless ``clean`` is
    set, the reconstructions are annotated with pixel error, landmark NME
    (without outline / outline only / all landmarks) and SSIM errors.

    Args:
        images: Input image batch (torch tensor; promoted to 4D).
        landmarks: Ground-truth landmarks or None. If None, all-zero
            landmarks are drawn and no per-landmark error is computed.
        X_recon: Reconstructed image batch.
        X_lm_hm: Predicted landmark heatmaps or None.
        lm_preds_max: Predicted landmark coordinates.
        lm_heatmaps: Optional ground-truth heatmaps, shown next to the
            predicted ones.
        images_mod: Optional alternative images to display in the first row.
        lm_preds_cnn: Optional second set of predictions; only rescaled here.
        ds: Optional dataset whose class name is appended to the window title.
        wait: Passed to ``cv2.waitKey``.
        ssim_maps: Optional SSIM maps, shown in a separate window.
        landmarks_to_draw: Currently not used by any active code path.
        ocular_norm: Normalisation mode forwarded to the NME helpers.
        horizontal: Put both rows side by side; requires exactly one image.
        f: Display scale factor applied to images and landmark coordinates.
        overlay_heatmaps_input: Overlay predicted heatmaps on the input row.
        overlay_heatmaps_recon: Overlay predicted heatmaps on the recon row.
        clean: Suppress all numeric overlays.
    """
    gt_color = (0, 255, 0)
    pred_color = (0, 0, 255)

    nimgs = min(10, len(images))
    images = nn.atleast4d(images)[:nimgs]

    nme_per_lm = None
    if landmarks is None:
        lm_gt = np.zeros((nimgs, lmcfg.NUM_LANDMARKS, 2))
    else:
        lm_gt = nn.atleast3d(to_numpy(landmarks))[:nimgs]
        nme_per_lm = calc_landmark_nme(lm_gt, lm_preds_max[:nimgs],
                                       ocular_norm=ocular_norm)
        # Currently only consumed by a disabled drawing call further down.
        lm_ssim_errs = 1 - calc_landmark_ssim_score(images, X_recon[:nimgs],
                                                    lm_gt)

    lm_confs = None
    # show landmark heatmaps
    pred_heatmaps = None
    if X_lm_hm is not None:
        pred_heatmaps = to_single_channel_heatmap(to_numpy(X_lm_hm[:nimgs]))
        pred_heatmaps = [
            cv2.resize(im, None, fx=f, fy=f, interpolation=cv2.INTER_NEAREST)
            for im in pred_heatmaps
        ]
        gt_heatmaps = None
        if lm_heatmaps is not None:
            gt_heatmaps = to_single_channel_heatmap(
                to_numpy(lm_heatmaps[:nimgs]))
            gt_heatmaps = np.array([
                cv2.resize(im, None, fx=f, fy=f,
                           interpolation=cv2.INTER_NEAREST)
                for im in gt_heatmaps
            ])
        show_landmark_heatmaps(pred_heatmaps, gt_heatmaps, nimgs, f=1)
        # per-landmark confidence = peak value of each heatmap channel
        lm_confs = to_numpy(X_lm_hm).reshape(X_lm_hm.shape[0],
                                             X_lm_hm.shape[1], -1).max(axis=2)

    # resize images for display and scale landmarks accordingly
    lm_preds_max = lm_preds_max[:nimgs] * f
    if lm_preds_cnn is not None:
        lm_preds_cnn = lm_preds_cnn[:nimgs] * f
    # Scale into a new array: lm_gt may be a view of the caller's landmarks,
    # and an in-place "*=" would silently rescale the caller's data (and
    # fails outright for integer arrays when f is a float).
    lm_gt = lm_gt * f

    input_images = vis._to_disp_images(images[:nimgs], denorm=True)
    if images_mod is not None:
        disp_images = vis._to_disp_images(images_mod[:nimgs], denorm=True)
    else:
        disp_images = vis._to_disp_images(images[:nimgs], denorm=True)
    disp_images = [
        cv2.resize(im, None, fx=f, fy=f, interpolation=cv2.INTER_NEAREST)
        for im in disp_images
    ]

    recon_images = vis._to_disp_images(X_recon[:nimgs], denorm=True)
    disp_X_recon = [
        cv2.resize(im, None, fx=f, fy=f, interpolation=cv2.INTER_NEAREST)
        for im in recon_images.copy()
    ]

    # overlay landmarks on input images
    if pred_heatmaps is not None and overlay_heatmaps_input:
        disp_images = [
            vis.overlay_heatmap(disp_images[i], pred_heatmaps[i])
            for i in range(len(pred_heatmaps))
        ]
    if pred_heatmaps is not None and overlay_heatmaps_recon:
        disp_X_recon = [
            vis.overlay_heatmap(disp_X_recon[i], pred_heatmaps[i])
            for i in range(len(pred_heatmaps))
        ]

    #
    # Show input images
    #
    disp_images = vis.add_landmarks_to_images(disp_images, lm_gt[:nimgs],
                                              color=gt_color)
    disp_images = vis.add_landmarks_to_images(disp_images,
                                              lm_preds_max[:nimgs],
                                              lm_errs=nme_per_lm,
                                              color=pred_color,
                                              draw_wireframe=False,
                                              gt_landmarks=lm_gt,
                                              draw_gt_offsets=True)

    #
    # Show reconstructions
    #
    X_recon_errs = 255.0 * torch.abs(images - X_recon[:nimgs]).reshape(
        len(images), -1).mean(dim=1)
    if not clean:
        disp_X_recon = vis.add_error_to_images(disp_X_recon[:nimgs],
                                               errors=X_recon_errs,
                                               format_string='{:>4.1f}')

    # modes of heatmaps
    if not clean:
        lm_errs_max = calc_landmark_nme_per_img(
            lm_gt,
            lm_preds_max,
            ocular_norm=ocular_norm,
            landmarks_to_eval=lmcfg.LANDMARKS_NO_OUTLINE)
        lm_errs_max_outline = calc_landmark_nme_per_img(
            lm_gt,
            lm_preds_max,
            ocular_norm=ocular_norm,
            landmarks_to_eval=lmcfg.LANDMARKS_ONLY_OUTLINE)
        lm_errs_max_all = calc_landmark_nme_per_img(
            lm_gt,
            lm_preds_max,
            ocular_norm=ocular_norm,
            landmarks_to_eval=lmcfg.ALL_LANDMARKS)
        disp_X_recon = vis.add_error_to_images(disp_X_recon, lm_errs_max,
                                               loc='br-2',
                                               format_string='{:>5.2f}',
                                               vmax=15)
        disp_X_recon = vis.add_error_to_images(disp_X_recon,
                                               lm_errs_max_outline,
                                               loc='br-1',
                                               format_string='{:>5.2f}',
                                               vmax=15)
        disp_X_recon = vis.add_error_to_images(disp_X_recon, lm_errs_max_all,
                                               loc='br',
                                               format_string='{:>5.2f}',
                                               vmax=15)
    disp_X_recon = vis.add_landmarks_to_images(disp_X_recon, lm_gt,
                                               color=gt_color,
                                               draw_wireframe=True)
    disp_X_recon = vis.add_landmarks_to_images(disp_X_recon,
                                               lm_preds_max[:nimgs],
                                               color=pred_color,
                                               draw_wireframe=True,
                                               gt_landmarks=lm_gt,
                                               draw_gt_offsets=True,
                                               lm_errs=nme_per_lm,
                                               draw_dots=True,
                                               radius=2)

    def add_confs(disp_X_recon, lmids, loc):
        # Print the mean heatmap confidence of the landmarks in lmids.
        means = lm_confs[:, lmids].mean(axis=1)
        colors = vis.color_map(to_numpy(1 - means), cmap=plt.cm.jet,
                               vmin=0.0, vmax=0.4)
        return vis.add_error_to_images(disp_X_recon, means, loc=loc,
                                       format_string='{:>4.2f}',
                                       colors=colors)

    # Confidence overlays are currently disabled:
    # disp_X_recon = add_confs(disp_X_recon, lmcfg.LANDMARKS_NO_OUTLINE, 'bm-2')
    # disp_X_recon = add_confs(disp_X_recon, lmcfg.LANDMARKS_ONLY_OUTLINE, 'bm-1')
    # disp_X_recon = add_confs(disp_X_recon, lmcfg.ALL_LANDMARKS, 'bm')

    # print ssim errors
    ssim = np.zeros(nimgs)
    for i in range(nimgs):
        ssim[i] = compare_ssim(input_images[i], recon_images[i],
                               data_range=1.0, multichannel=True)
    if not clean:
        disp_X_recon = vis.add_error_to_images(disp_X_recon, 1 - ssim,
                                               loc='bl-1',
                                               format_string='{:>4.2f}',
                                               vmax=0.8, vmin=0.2)

    # print ssim torch errors
    if ssim_maps is not None and not clean:
        disp_X_recon = vis.add_error_to_images(
            disp_X_recon,
            ssim_maps.reshape(len(ssim_maps), -1).mean(axis=1),
            loc='bl-2',
            format_string='{:>4.2f}',
            vmin=0.0,
            vmax=0.4)

    rows = [vis.make_grid(disp_images, nCols=nimgs, normalize=False)]
    rows.append(vis.make_grid(disp_X_recon, nCols=nimgs))

    if ssim_maps is not None:
        disp_ssim_maps = to_numpy(
            ds_utils.denormalized(ssim_maps)[:nimgs].transpose(0, 2, 3, 1))
        for i in range(len(disp_ssim_maps)):
            disp_ssim_maps[i] = vis.color_map(disp_ssim_maps[i].mean(axis=2),
                                              vmin=0.0, vmax=2.0)
        grid_ssim_maps = vis.make_grid(disp_ssim_maps, nCols=nimgs,
                                       fx=f, fy=f)
        cv2.imshow('ssim errors',
                   cv2.cvtColor(grid_ssim_maps, cv2.COLOR_RGB2BGR))

    if horizontal:
        assert (nimgs == 1)
        disp_rows = vis.make_grid(rows, nCols=2)
    else:
        disp_rows = vis.make_grid(rows, nCols=1)
    wnd_title = 'Predicted Landmarks '
    if ds is not None:
        wnd_title += ds.__class__.__name__
    cv2.imshow(wnd_title, cv2.cvtColor(disp_rows, cv2.COLOR_RGB2BGR))
    cv2.waitKey(wait)
def visualize_batch(batch, X_recon, X_lm_hm, lm_preds_max, lm_preds_cnn=None, ds=None, wait=0, ssim_maps=None,
                    landmarks_to_draw=lmcfg.LANDMARKS_TO_EVALUATE, ocular_norm='pupil', horizontal=False, f=1.0):
    """Show a batch with ground-truth and predicted landmarks in OpenCV windows.

    At most 10 samples are drawn. Row one contains the input images (or
    ``batch.images_mod`` when set) with the predicted heatmaps overlaid and
    both landmark sets drawn; row two contains the reconstructions annotated
    with pixel error, landmark NME values and SSIM errors. The result is
    shown in a window titled ``'recon errors <ds class>'``.

    Args:
        batch: Object exposing ``images``, ``images_mod``, ``landmarks``,
            ``lm_heatmaps`` and ``len()``.
        X_recon: Reconstructed images.
        X_lm_hm: Predicted landmark heatmaps or None.
        lm_preds_max: Predicted landmark coordinates.
        lm_preds_cnn: Optional second set of predictions, drawn in yellow
            with its own NME values in the top-right corner.
        ds: Optional dataset whose class name is appended to the window title.
        wait: Passed to ``cv2.waitKey``.
        ssim_maps: Optional SSIM maps, shown in a separate window.
        landmarks_to_draw: Landmark indices forwarded to the drawing helpers.
        ocular_norm: Normalisation mode forwarded to the NME helpers.
        horizontal: Put the rows side by side; requires exactly one image.
        f: Display scale factor for images and landmark coordinates.

    NOTE(review): this block was reviewed from a source whose line structure
    was lost; the nesting shown here was reconstructed from data dependencies
    and should be confirmed against version control.
    """

    nimgs = min(10, len(batch))
    gt_color = (0, 1, 0)

    # NOTE(review): lm_confs is computed below but only referenced from a
    # commented-out keyword argument.
    lm_confs = None
    # show landmark heatmaps
    pred_heatmaps = None
    if X_lm_hm is not None:
        pred_heatmaps = to_single_channel_heatmap(to_numpy(X_lm_hm[:nimgs]))
        pred_heatmaps = [
            cv2.resize(im, None, fx=f, fy=f, interpolation=cv2.INTER_NEAREST)
            for im in pred_heatmaps
        ]
        if batch.lm_heatmaps is not None:
            gt_heatmaps = to_single_channel_heatmap(
                to_numpy(batch.lm_heatmaps[:nimgs]))
            gt_heatmaps = np.array([
                cv2.resize(im,
                           None,
                           fx=f,
                           fy=f,
                           interpolation=cv2.INTER_NEAREST)
                for im in gt_heatmaps
            ])
            show_landmark_heatmaps(pred_heatmaps, gt_heatmaps, nimgs, f=1)
        # peak value of each heatmap channel
        lm_confs = to_numpy(X_lm_hm).reshape(X_lm_hm.shape[0],
                                             X_lm_hm.shape[1], -1).max(axis=2)

    # scale landmarks
    lm_preds_max = lm_preds_max[:nimgs] * f
    if lm_preds_cnn is not None:
        lm_preds_cnn = lm_preds_cnn[:nimgs] * f
    lm_gt = to_numpy(batch.landmarks[:nimgs]) * f
    # 98-point annotations are mapped onto the 68-point layout
    if lm_gt.shape[1] == 98:
        lm_gt = convert_landmarks(lm_gt, LM98_TO_LM68)

    input_images = vis._to_disp_images(batch.images[:nimgs], denorm=True)
    if batch.images_mod is not None:
        disp_images = vis._to_disp_images(batch.images_mod[:nimgs],
                                          denorm=True)
    else:
        disp_images = vis._to_disp_images(batch.images[:nimgs], denorm=True)
    disp_images = [
        cv2.resize(im, None, fx=f, fy=f, interpolation=cv2.INTER_NEAREST)
        for im in disp_images
    ]

    recon_images = vis._to_disp_images(X_recon[:nimgs], denorm=True)
    disp_X_recon = [
        cv2.resize(im, None, fx=f, fy=f, interpolation=cv2.INTER_NEAREST)
        for im in recon_images.copy()
    ]

    # draw landmarks to input images
    if pred_heatmaps is not None:
        disp_images = [
            vis.overlay_heatmap(disp_images[i], pred_heatmaps[i])
            for i in range(len(pred_heatmaps))
        ]

    nme_per_lm = calc_landmark_nme(lm_gt, lm_preds_max,
                                   ocular_norm=ocular_norm)
    lm_ssim_errs = calc_landmark_ssim_error(batch.images[:nimgs],
                                            X_recon[:nimgs],
                                            batch.landmarks[:nimgs])

    #
    # Show input images
    #
    disp_images = vis.add_landmarks_to_images(
        disp_images,
        lm_gt[:nimgs],
        color=gt_color,
        draw_dots=True,
        draw_wireframe=False,
        landmarks_to_draw=landmarks_to_draw)
    disp_images = vis.add_landmarks_to_images(
        disp_images,
        lm_preds_max[:nimgs],
        lm_errs=nme_per_lm,
        color=(0.0, 0.0, 1.0),
        draw_dots=True,
        draw_wireframe=False,
        gt_landmarks=lm_gt,
        draw_gt_offsets=True,
        landmarks_to_draw=landmarks_to_draw)

    # if lm_preds_cnn is not None:
    #     disp_images = vis.add_landmarks_to_images(disp_images, lm_preds_cnn, color=(1, 1, 0),
    #                                               gt_landmarks=lm_gt, draw_gt_offsets=False,
    #                                               draw_wireframe=True, landmarks_to_draw=landmarks_to_draw)

    rows = [vis.make_grid(disp_images, nCols=nimgs, normalize=False)]

    #
    # Show reconstructions
    #
    # mean absolute pixel error per image (computed on the full batch)
    X_recon_errs = 255.0 * torch.abs(batch.images - X_recon).reshape(
        len(batch.images), -1).mean(dim=1)
    disp_X_recon = vis.add_error_to_images(disp_X_recon[:nimgs],
                                           errors=X_recon_errs,
                                           format_string='{:>4.1f}')

    # modes of heatmaps
    # disp_X_recon = [overlay_heatmap(disp_X_recon[i], pred_heatmaps[i]) for i in range(len(pred_heatmaps))]
    lm_errs_max = calc_landmark_nme_per_img(
        lm_gt,
        lm_preds_max,
        ocular_norm=ocular_norm,
        landmarks_to_eval=lmcfg.LANDMARKS_NO_OUTLINE)
    lm_errs_max_outline = calc_landmark_nme_per_img(
        lm_gt,
        lm_preds_max,
        ocular_norm=ocular_norm,
        landmarks_to_eval=lmcfg.LANDMARKS_ONLY_OUTLINE)
    lm_errs_max_all = calc_landmark_nme_per_img(
        lm_gt,
        lm_preds_max,
        ocular_norm=ocular_norm,
        landmarks_to_eval=lmcfg.ALL_LANDMARKS)
    disp_X_recon = vis.add_error_to_images(disp_X_recon,
                                           lm_errs_max,
                                           loc='br-2',
                                           format_string='{:>5.2f}',
                                           vmax=15)
    disp_X_recon = vis.add_error_to_images(disp_X_recon,
                                           lm_errs_max_outline,
                                           loc='br-1',
                                           format_string='{:>5.2f}',
                                           vmax=15)
    disp_X_recon = vis.add_error_to_images(disp_X_recon,
                                           lm_errs_max_all,
                                           loc='br',
                                           format_string='{:>5.2f}',
                                           vmax=15)
    disp_X_recon = vis.add_landmarks_to_images(
        disp_X_recon,
        lm_preds_max[:nimgs],
        color=(0, 0, 1),
        landmarks_to_draw=landmarks_to_draw,
        draw_wireframe=False,
        lm_errs=nme_per_lm,
        # lm_confs=lm_confs,
        lm_confs=1 - lm_ssim_errs,
        gt_landmarks=lm_gt,
        draw_gt_offsets=True,
        draw_dots=True)
    disp_X_recon = vis.add_landmarks_to_images(
        disp_X_recon,
        lm_gt,
        color=gt_color,
        draw_wireframe=False,
        landmarks_to_draw=landmarks_to_draw)

    # landmarks from CNN prediction
    if lm_preds_cnn is not None:
        nme_per_lm = calc_landmark_nme(lm_gt,
                                       lm_preds_cnn,
                                       ocular_norm=ocular_norm)
        disp_X_recon = vis.add_landmarks_to_images(
            disp_X_recon,
            lm_preds_cnn,
            color=(1, 1, 0),
            landmarks_to_draw=lmcfg.ALL_LANDMARKS,
            draw_wireframe=False,
            lm_errs=nme_per_lm,
            gt_landmarks=lm_gt,
            draw_gt_offsets=True,
            draw_dots=True,
            offset_line_color=(1, 1, 1))
        lm_errs_cnn = calc_landmark_nme_per_img(
            lm_gt,
            lm_preds_cnn,
            ocular_norm=ocular_norm,
            landmarks_to_eval=landmarks_to_draw)
        lm_errs_cnn_outline = calc_landmark_nme_per_img(
            lm_gt,
            lm_preds_cnn,
            ocular_norm=ocular_norm,
            landmarks_to_eval=lmcfg.LANDMARKS_ONLY_OUTLINE)
        lm_errs_cnn_all = calc_landmark_nme_per_img(
            lm_gt,
            lm_preds_cnn,
            ocular_norm=ocular_norm,
            landmarks_to_eval=lmcfg.ALL_LANDMARKS)
        disp_X_recon = vis.add_error_to_images(disp_X_recon,
                                               lm_errs_cnn,
                                               loc='tr',
                                               format_string='{:>5.2f}',
                                               vmax=15)
        disp_X_recon = vis.add_error_to_images(disp_X_recon,
                                               lm_errs_cnn_outline,
                                               loc='tr+1',
                                               format_string='{:>5.2f}',
                                               vmax=15)
        disp_X_recon = vis.add_error_to_images(disp_X_recon,
                                               lm_errs_cnn_all,
                                               loc='tr+2',
                                               format_string='{:>5.2f}',
                                               vmax=15)

    # print ssim errors
    ssim = np.zeros(nimgs)
    for i in range(nimgs):
        ssim[i] = compare_ssim(input_images[i],
                               recon_images[i],
                               data_range=1.0,
                               multichannel=True)
    disp_X_recon = vis.add_error_to_images(disp_X_recon,
                                           1 - ssim,
                                           loc='bl-1',
                                           format_string='{:>4.2f}',
                                           vmax=0.8,
                                           vmin=0.2)
    # print ssim torch errors
    if ssim_maps is not None:
        disp_X_recon = vis.add_error_to_images(disp_X_recon,
                                               ssim_maps.reshape(
                                                   len(ssim_maps),
                                                   -1).mean(axis=1),
                                               loc='bl-2',
                                               format_string='{:>4.2f}',
                                               vmin=0.0,
                                               vmax=0.4)

    rows.append(vis.make_grid(disp_X_recon, nCols=nimgs))

    if ssim_maps is not None:
        disp_ssim_maps = to_numpy(
            ds_utils.denormalized(ssim_maps)[:nimgs].transpose(0, 2, 3, 1))
        for i in range(len(disp_ssim_maps)):
            disp_ssim_maps[i] = vis.color_map(disp_ssim_maps[i].mean(axis=2),
                                              vmin=0.0,
                                              vmax=2.0)
        grid_ssim_maps = vis.make_grid(disp_ssim_maps,
                                       nCols=nimgs,
                                       fx=f,
                                       fy=f)
        cv2.imshow('ssim errors',
                   cv2.cvtColor(grid_ssim_maps, cv2.COLOR_RGB2BGR))

    X_gen_lm_hm = None
    X_gen_vis = None

    # NOTE(review): hard-coded off. The branch below references ``self``,
    # which does not exist in this free function, so enabling the flag as-is
    # would raise NameError.
    show_random_faces = False
    if show_random_faces:
        with torch.no_grad():
            z_random = self.enc_rand(nimgs, self.saae.z_dim).cuda()
            outputs = self.saae.P(z_random)
            X_gen_vis = outputs[:, :3]
            if outputs.shape[1] > 3:
                X_gen_lm_hm = outputs[:, 3:]
        disp_X_gen = to_numpy(
            ds_utils.denormalized(X_gen_vis)[:nimgs].permute(0, 2, 3, 1))

        if X_gen_lm_hm is not None:
            if lmcfg.LANDMARK_TARGET == 'colored':
                gen_heatmaps = [to_image(X_gen_lm_hm[i]) for i in range(nimgs)]
            elif lmcfg.LANDMARK_TARGET == 'multi_channel':
                X_gen_lm_hm = X_gen_lm_hm.max(dim=1)[0]
                gen_heatmaps = [to_image(X_gen_lm_hm[i]) for i in range(nimgs)]
            else:
                gen_heatmaps = [
                    to_image(X_gen_lm_hm[i, 0]) for i in range(nimgs)
                ]

            disp_X_gen = [
                vis.overlay_heatmap(disp_X_gen[i], gen_heatmaps[i])
                for i in range(len(pred_heatmaps))
            ]

            # inputs = torch.cat([X_gen_vis, X_gen_lm_hm.detach()], dim=1)
            inputs = X_gen_lm_hm.detach()

            # disabled for multi_channel LM targets
            # lm_gen_preds = self.saae.lm_coords(inputs).reshape(len(inputs), -1, 2)
            # disp_X_gen = vis.add_landmarks_to_images(disp_X_gen, lm_gen_preds[:nimgs], color=(0,1,1))

            disp_gen_heatmaps = [
                vis.color_map(hm, vmin=0, vmax=1.0) for hm in gen_heatmaps
            ]
            img_gen_heatmaps = cv2.resize(vis.make_grid(disp_gen_heatmaps,
                                                        nCols=nimgs,
                                                        padval=0),
                                          None,
                                          fx=1.0,
                                          fy=1.0)
            cv2.imshow('generated landmarks',
                       cv2.cvtColor(img_gen_heatmaps, cv2.COLOR_RGB2BGR))

        rows.append(vis.make_grid(disp_X_gen, nCols=nimgs))

    # self.saae.D.train(train_state_D)
    # self.saae.Q.train(train_state_Q)
    # self.saae.P.train(train_state_P)

    if horizontal:
        assert (nimgs == 1)
        disp_rows = vis.make_grid(rows, nCols=2)
    else:
        disp_rows = vis.make_grid(rows, nCols=1)
    wnd_title = 'recon errors '
    if ds is not None:
        wnd_title += ds.__class__.__name__
    cv2.imshow(wnd_title, cv2.cvtColor(disp_rows, cv2.COLOR_RGB2BGR))
    cv2.waitKey(wait)
def draw_results(X_resized, X_recon, levels_z=None, landmarks=None,
                 landmarks_pred=None, cs_errs=None, ncols=15, fx=0.5, fy=0.5,
                 additional_status_text='', clean_images=True):
    """Compose inputs, reconstructions and a status bar into a single image.

    The inputs and the reconstructions are each arranged in a grid with
    ``ncols`` columns, resized by ``fx``/``fy`` and stacked vertically on top
    of a status bar that reports the mean L1 error, the mean of ``cs_errs``,
    the mean SSIM error and the mean landmark error.

    Args:
        X_resized: Batch of input images (torch tensor).
        X_recon: Batch of reconstructions; its last dimension is taken as the
            image size.
        levels_z: Latent vectors, only used by the disabled hidden-layer
            visualisation.
        landmarks: Optional ground-truth landmarks; ignored while
            ``clean_images`` is true.
        landmarks_pred: Optional predicted landmarks, drawn in red.
        cs_errs: Optional per-image errors; their mean goes to the status bar.
        ncols: Maximum number of grid columns.
        fx, fy: Scale factors applied to both grids.
        additional_status_text: Text appended to the status bar.
        clean_images: If true (default, the previous hard-coded behaviour),
            ground-truth landmarks and all numeric overlays are left out.

    Returns:
        The vertically stacked image as produced by ``np.vstack``.
    """
    if clean_images:
        landmarks = None

    nimgs = len(X_resized)
    ncols = min(ncols, nimgs)
    img_size = X_recon.shape[-1]

    disp_X = vis.to_disp_images(X_resized, denorm=True)
    disp_X_recon = vis.to_disp_images(X_recon, denorm=True)

    # reconstruction error in pixels
    l1_dists = 255.0 * to_numpy(
        (X_resized - X_recon).abs().reshape(len(disp_X), -1).mean(dim=1))

    # SSIM errors
    ssim = np.zeros(nimgs)
    for i in range(nimgs):
        ssim[i] = compare_ssim(disp_X[i], disp_X_recon[i], data_range=1.0,
                               multichannel=True)

    # to_numpy is expected to pass None through, since both values are
    # checked against None below -- TODO confirm.
    landmarks = to_numpy(landmarks)
    cs_errs = to_numpy(cs_errs)

    text_size = img_size / 256
    text_thickness = 2

    #
    # Visualise resized input images and reconstructed images
    #
    if landmarks is not None:
        disp_X = vis.add_landmarks_to_images(
            disp_X, landmarks, draw_wireframe=False,
            landmarks_to_draw=lmcfg.LANDMARKS_19)
        disp_X_recon = vis.add_landmarks_to_images(
            disp_X_recon, landmarks, draw_wireframe=False,
            landmarks_to_draw=lmcfg.LANDMARKS_19)

    if landmarks_pred is not None:
        disp_X = vis.add_landmarks_to_images(disp_X, landmarks_pred,
                                             color=(1, 0, 0))
        disp_X_recon = vis.add_landmarks_to_images(disp_X_recon,
                                                   landmarks_pred,
                                                   color=(1, 0, 0))

    if not clean_images:
        disp_X_recon = vis.add_error_to_images(disp_X_recon, l1_dists,
                                               format_string='{:.1f}',
                                               size=text_size,
                                               thickness=text_thickness)
        disp_X_recon = vis.add_error_to_images(disp_X_recon, 1 - ssim,
                                               loc='bl-1',
                                               format_string='{:>4.2f}',
                                               vmax=0.8, vmin=0.2,
                                               size=text_size,
                                               thickness=text_thickness)
        if cs_errs is not None:
            # The colour range used to be passed with its bounds swapped
            # (vmax=0.0, vmin=0.4); [0.0, 0.4] matches the other 'bl-2'
            # overlays in this file.
            disp_X_recon = vis.add_error_to_images(disp_X_recon, cs_errs,
                                                   loc='bl-2',
                                                   format_string='{:>4.2f}',
                                                   vmin=0.0, vmax=0.4,
                                                   size=text_size,
                                                   thickness=text_thickness)

    # landmark errors
    lm_errs = np.zeros(1)
    if landmarks is not None:
        try:
            from landmarks import lmutils
            lm_errs = lmutils.calc_landmark_nme_per_img(
                landmarks, landmarks_pred)
            disp_X_recon = vis.add_error_to_images(disp_X_recon, lm_errs,
                                                   loc='br',
                                                   format_string='{:>5.2f}',
                                                   vmax=15,
                                                   size=text_size,
                                                   thickness=text_thickness)
        except Exception:
            # Deliberately best-effort: a missing landmarks package or
            # unusable predictions must not break the visualisation. Unlike
            # a bare "except", this lets KeyboardInterrupt/SystemExit through.
            pass

    img_input = vis.make_grid(disp_X, nCols=ncols, normalize=False)
    img_recon = vis.make_grid(disp_X_recon, nCols=ncols, normalize=False)

    img_input = cv2.resize(img_input, None, fx=fx, fy=fy,
                           interpolation=cv2.INTER_CUBIC)
    img_recon = cv2.resize(img_recon, None, fx=fx, fy=fy,
                           interpolation=cv2.INTER_CUBIC)

    img_stack = [img_input, img_recon]

    #
    # Visualise hidden layers
    #
    VIS_HIDDEN = False
    if VIS_HIDDEN:
        img_z = vis.draw_z_vecs(levels_z, size=(img_size, 30), ncols=ncols)
        img_z = cv2.resize(img_z,
                           dsize=(img_input.shape[1], img_z.shape[0]),
                           interpolation=cv2.INTER_NEAREST)
        img_stack.append(img_z)

    cs_errs_mean = np.mean(cs_errs) if cs_errs is not None else np.nan
    status_bar_text = ("l1 recon err: {:.2f}px "
                       "ssim: {:.3f}({:.3f}) "
                       "lms err: {:2} {}").format(l1_dists.mean(),
                                                  cs_errs_mean,
                                                  1 - ssim.mean(),
                                                  lm_errs.mean(),
                                                  additional_status_text)

    img_status_bar = vis.draw_status_bar(status_bar_text,
                                         status_bar_width=img_input.shape[1],
                                         status_bar_height=30,
                                         dtype=img_input.dtype)
    img_stack.append(img_status_bar)

    return np.vstack(img_stack)
def calc_triplet_loss(outputs, c, return_acc=False, images=None,
                      feature_name=None, wnd_title=None):
    """Compute a margin-based triplet loss on a batch of feature vectors.

    ``make_triplets`` chooses one positive and one negative index per sample.
    The per-triplet loss is ``d(x, pos) - d(x, neg) + margin`` clamped to
    ``[0, 2 * margin]`` with ``margin = 0.2``; the batch mean is returned.

    Args:
        outputs: Feature vectors, one row per sample.
        c: Labels passed to ``make_triplets``. A non-list value with a second
            dimension of size 3 is treated as expression labels, whose first
            column is the class shown in the debug view.
        return_acc: Also return a second value (see Returns).
        images: If given, the first 20 anchor/positive/negative images are
            displayed with their distances and labels.
        feature_name: Appended to the window title; ``'id'`` additionally
            draws the identity labels.
        wnd_title: Optional extra suffix for the window title.

    Returns:
        ``loss``, or ``(loss, rate)`` if ``return_acc`` is set. NOTE: despite
        the parameter name, ``rate`` is the fraction of triplets whose
        positive distance is not smaller than the negative distance, i.e. an
        error rate (lower is better).
    """
    margin = 0.2
    eps = 1e-8

    debug = False
    is_expressions = (not isinstance(
        c, list)) and len(c.shape) > 1 and c.shape[1] == 3

    pos_id, neg_id = make_triplets(outputs, c, debug=debug)

    X, P, N = outputs[:, :], outputs[pos_id, :], outputs[neg_id, :]
    # eps keeps the square root differentiable when two vectors coincide
    dpos = torch.sqrt(torch.sum((X - P)**2, dim=1) + eps)
    dneg = torch.sqrt(torch.sum((X - N)**2, dim=1) + eps)

    # Computed once and shared by the loss and the debug view.
    triplet_losses = torch.clamp(dpos - dneg + margin, min=0.0,
                                 max=margin * 2.0)
    loss = torch.mean(triplet_losses)

    # show triplets
    if images is not None:
        from utils import vis
        from datasets import ds_utils
        if debug and is_expressions:
            for i in range(10):
                print(c[:, 0][i].item(), c[pos_id, 0][i].item(),
                      c[neg_id, 0][i].item())

        nimgs = 20
        losses = to_numpy(triplet_losses)

        images_ref = ds_utils.denormalized(images[:nimgs].clone())
        images_pos = ds_utils.denormalized(images[pos_id][:nimgs].clone())
        images_neg = ds_utils.denormalized(images[neg_id][:nimgs].clone())
        # green: positive is closer than the negative, red otherwise
        colors = [(0, 1, 0) if is_closer else (1, 0, 0)
                  for is_closer in (dpos < dneg).tolist()]
        f = 0.75
        images_ref = vis.add_error_to_images(
            vis.add_cirle_to_images(images_ref, colors), losses, size=1.0,
            vmin=0, vmax=0.5, thickness=2, format_string='{:.2f}')
        images_pos = vis.add_error_to_images(images_pos, to_numpy(dpos),
                                             size=1.0, vmin=0.5, vmax=1.0,
                                             thickness=2,
                                             format_string='{:.2f}')
        images_neg = vis.add_error_to_images(images_neg, to_numpy(dneg),
                                             size=1.0, vmin=0.5, vmax=1.0,
                                             thickness=2,
                                             format_string='{:.2f}')
        if is_expressions:
            emotions = to_numpy(c[:, 0]).astype(int)
            images_ref = vis.add_emotion_to_images(images_ref, emotions)
            images_pos = vis.add_emotion_to_images(images_pos,
                                                   emotions[pos_id])
            images_neg = vis.add_emotion_to_images(images_neg,
                                                   emotions[neg_id])
        elif feature_name == 'id':
            ids = to_numpy(c).astype(int)
            images_ref = vis.add_id_to_images(images_ref, ids, loc='tr')
            images_pos = vis.add_id_to_images(images_pos, ids[pos_id],
                                              loc='tr')
            images_neg = vis.add_id_to_images(images_neg, ids[neg_id],
                                              loc='tr')

        img_ref = vis.make_grid(images_ref, nCols=nimgs, padsize=1,
                                fx=f, fy=f)
        img_pos = vis.make_grid(images_pos, nCols=nimgs, padsize=1,
                                fx=f, fy=f)
        img_neg = vis.make_grid(images_neg, nCols=nimgs, padsize=1,
                                fx=f, fy=f)
        title = 'triplets'
        if feature_name is not None:
            title += " " + feature_name
        if wnd_title is not None:
            title += " " + wnd_title
        vis.vis_square([img_ref, img_pos, img_neg], nCols=1, padsize=1,
                       normalize=False, wait=10, title=title)

    if return_acc:
        # Count violations with a single tensor reduction instead of the
        # Python built-in sum(), which iterates element by element.
        num_violations = (dpos >= dneg).sum().item()
        return loss, num_violations / float(len(dpos))
    return loss