Example #1
def show_pairs(images, features, pairs):
    dists = np.sqrt(np.sum((features[0] - features[1])**2, axis=1))
    ds_utils.denormalize(images[0])
    ds_utils.denormalize(images[1])
    images[1] = vis.add_error_to_images(images[1],
                                        dists,
                                        size=2.0,
                                        thickness=2,
                                        vmin=0,
                                        vmax=1)
    images[1] = vis.add_id_to_images(images[1],
                                     pairs.numpy(),
                                     size=1.2,
                                     thickness=2,
                                     color=(1, 0, 1))
    thresh = 0.4
    corrects = (dists < thresh) == pairs.cpu().numpy()
    colors = [(0, 1, 0) if c else (1, 0, 0) for c in corrects]
    images[1] = vis.add_cirle_to_images(images[1], colors)
    images[0] = vis._to_disp_images(images[0])
    img_rows = [
        vis.make_grid(imgs,
                      fx=0.75,
                      fy=0.75,
                      nCols=len(dists),
                      normalize=False) for imgs in images
    ]
    vis.vis_square(img_rows, nCols=1, normalize=False)
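Stripped of the drawing helpers, show_pairs performs a verification check: it measures the Euclidean distance between two batches of embeddings and calls a pair correct when the thresholded distance agrees with the ground-truth same/different label. A minimal, self-contained sketch of that logic (the 0.4 threshold comes from the example above; the helper name and toy data are illustrative, not part of the project):

import numpy as np

def pair_correctness(feats_a, feats_b, is_same, thresh=0.4):
    # Euclidean distance between corresponding rows of the two feature batches
    dists = np.sqrt(np.sum((feats_a - feats_b) ** 2, axis=1))
    # a pair is predicted "same" when its distance falls below the threshold
    predicted_same = dists < thresh
    corrects = predicted_same == is_same.astype(bool)
    return dists, corrects

# toy usage: 8 pairs of 128-d embeddings, all of them genuine matches
rng = np.random.default_rng(0)
a = rng.normal(size=(8, 128))
b = a + rng.normal(scale=0.01, size=(8, 128))
dists, corrects = pair_correctness(a, b, np.ones(8))
colors = [(0, 1, 0) if c else (1, 0, 0) for c in corrects]  # green = correct, red = wrong
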
Example #2
File: lmutils.py  Project: PaperID8601/3FR
 def add_confs(disp_X_recon, lmids, loc):
     means = lm_confs[:, lmids].mean(axis=1)
     colors = vis.color_map(to_numpy(1 - means),
                            cmap=plt.cm.jet,
                            vmin=0.0,
                            vmax=0.4)
     return vis.add_error_to_images(disp_X_recon,
                                    means,
                                    loc=loc,
                                    format_string='{:>4.2f}',
                                    colors=colors)
 def generate_images(self, z):
     train_state_D = self.saae.D.training
     train_state_P = self.saae.P.training
     self.saae.D.eval()
     self.saae.P.eval()
     loc_err_gan = "tr"
     with torch.no_grad():
         X_gen_vis = self.saae.P(z)[:, :3]
         err_gan_gen = self.saae.D(X_gen_vis)
     imgs = vis.reconstruct_images(X_gen_vis)
     self.saae.D.train(train_state_D)
     self.saae.P.train(train_state_P)
     return vis.add_error_to_images(
         imgs,
         errors=1 - err_gan_gen,
         loc=loc_err_gan,
         format_string="{:.2f}",
         vmax=1.0,
     )
    def visualize_batch(self,
                        batch,
                        X_recon,
                        ssim_maps,
                        nimgs=8,
                        ds=None,
                        wait=0):

        nimgs = min(nimgs, len(batch))
        train_state_D = self.saae.D.training
        train_state_Q = self.saae.Q.training
        train_state_P = self.saae.P.training
        self.saae.D.eval()
        self.saae.Q.eval()
        self.saae.P.eval()

        loc_err_gan = "tr"
        text_size_errors = 0.65

        input_images = vis.reconstruct_images(batch.images[:nimgs])
        show_filenames = batch.filenames[:nimgs]
        target_images = (batch.target_images
                         if batch.target_images is not None else batch.images)
        disp_images = vis.reconstruct_images(target_images[:nimgs])

        # draw GAN score
        if self.args.with_gan:
            with torch.no_grad():
                err_gan_inputs = self.saae.D(batch.images[:nimgs])
            disp_images = vis.add_error_to_images(
                disp_images,
                errors=1 - err_gan_inputs,
                loc=loc_err_gan,
                format_string="{:>5.2f}",
                vmax=1.0,
            )

        # disp_images = vis.add_landmarks_to_images(disp_images, batch.landmarks[:nimgs], color=(0,1,0), radius=1,
        #                                           draw_wireframe=False)
        rows = [vis.make_grid(disp_images, nCols=nimgs, normalize=False)]

        recon_images = vis.reconstruct_images(X_recon[:nimgs])
        disp_X_recon = recon_images.copy()

        print_stats = True
        if print_stats:
            # lm_ssim_errs = None
            # if batch.landmarks is not None:
            #     lm_recon_errs = lmutils.calc_landmark_recon_error(batch.images[:nimgs], X_recon[:nimgs], batch.landmarks[:nimgs], reduction='none')
            #     disp_X_recon = vis.add_error_to_images(disp_X_recon, lm_recon_errs, size=text_size_errors, loc='bm',
            #                                            format_string='({:>3.1f})', vmin=0, vmax=10)
            #     lm_ssim_errs = lmutils.calc_landmark_ssim_error(batch.images[:nimgs], X_recon[:nimgs], batch.landmarks[:nimgs])
            #     disp_X_recon = vis.add_error_to_images(disp_X_recon, lm_ssim_errs.mean(axis=1), size=text_size_errors, loc='bm-1',
            #                                            format_string='({:>3.2f})', vmin=0.2, vmax=0.8)

            X_recon_errs = 255.0 * torch.abs(batch.images - X_recon).reshape(
                len(batch.images), -1).mean(dim=1)
            # disp_X_recon = vis.add_landmarks_to_images(disp_X_recon, batch.landmarks[:nimgs], radius=1, color=None,
            #                                            lm_errs=lm_ssim_errs, draw_wireframe=False)
            disp_X_recon = vis.add_error_to_images(
                disp_X_recon[:nimgs],
                errors=X_recon_errs,
                size=text_size_errors,
                format_string="{:>4.1f}",
            )
            if self.args.with_gan:
                with torch.no_grad():
                    err_gan = self.saae.D(X_recon[:nimgs])
                disp_X_recon = vis.add_error_to_images(
                    disp_X_recon,
                    errors=1 - err_gan,
                    loc=loc_err_gan,
                    format_string="{:>5.2f}",
                    vmax=1.0,
                )

            ssim = np.zeros(nimgs)
            for i in range(nimgs):
                data_range = 255.0 if input_images[0].dtype == np.uint8 else 1.0
                ssim[i] = compare_ssim(
                    input_images[i],
                    recon_images[i],
                    data_range=data_range,
                    multichannel=True,
                )
            disp_X_recon = vis.add_error_to_images(
                disp_X_recon,
                1 - ssim,
                loc="bl-1",
                size=text_size_errors,
                format_string="{:>4.2f}",
                vmin=0.2,
                vmax=0.8,
            )

            if ssim_maps is not None:
                disp_X_recon = vis.add_error_to_images(
                    disp_X_recon,
                    ssim_maps.reshape(len(ssim_maps), -1).mean(axis=1),
                    size=text_size_errors,
                    loc="bl-2",
                    format_string="{:>4.2f}",
                    vmin=0.0,
                    vmax=0.4,
                )

        rows.append(vis.make_grid(disp_X_recon, nCols=nimgs))

        if ssim_maps is not None:
            disp_ssim_maps = to_numpy(
                nn.denormalized(ssim_maps)[:nimgs].transpose(0, 2, 3, 1))
            if disp_ssim_maps.shape[3] == 1:
                disp_ssim_maps = disp_ssim_maps.repeat(3, axis=3)
            for i in range(len(disp_ssim_maps)):
                disp_ssim_maps[i] = vis.color_map(
                    disp_ssim_maps[i].mean(axis=2), vmin=0.0, vmax=2.0)
            grid_ssim_maps = vis.make_grid(disp_ssim_maps, nCols=nimgs)
            cv2.imwrite("ssim errors.jpg",
                        cv2.cvtColor(grid_ssim_maps, cv2.COLOR_RGB2BGR))

        self.saae.D.train(train_state_D)
        self.saae.Q.train(train_state_Q)
        self.saae.P.train(train_state_P)

        f = 1
        disp_rows = vis.make_grid(rows, nCols=1, normalize=False, fx=f, fy=f)
        wnd_title = "recon errors "
        if ds is not None:
            wnd_title += ds.__class__.__name__
        cv2.imwrite(wnd_title + ".jpg",
                    cv2.cvtColor(disp_rows, cv2.COLOR_RGB2BGR))
        cv2.waitKey(wait)
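
visualize_batch scores each reconstruction with compare_ssim from scikit-image. In current scikit-image releases that function is skimage.metrics.structural_similarity, and the multichannel flag has been replaced by channel_axis. A minimal per-image loop equivalent to the one above, on synthetic data (the batch shape is an assumption):

import numpy as np
from skimage.metrics import structural_similarity

def batch_ssim(inputs, recons, data_range=1.0):
    # per-image SSIM between two batches of HxWxC float images
    ssim = np.zeros(len(inputs))
    for i in range(len(inputs)):
        # channel_axis=-1 replaces the deprecated multichannel=True
        ssim[i] = structural_similarity(inputs[i], recons[i],
                                        data_range=data_range,
                                        channel_axis=-1)
    return ssim

# toy usage
rng = np.random.default_rng(0)
x = rng.random((4, 64, 64, 3)).astype(np.float32)
noisy = np.clip(x + rng.normal(scale=0.05, size=x.shape), 0.0, 1.0).astype(np.float32)
print(batch_ssim(x, noisy))   # values below 1.0
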
Example #5
    def __train_disenglement_parallel(self, z, Y=None, train=True):
        iter_stats = {}

        self.E.train(train)
        self.G.train(train)

        self.optimizer_E.zero_grad()
        self.optimizer_G.zero_grad()

        #
        # Autoencoding phase
        #

        fts = self.E(z)
        fp, fi, fs, fe = fts

        z_recon = self.G(fp, fi, fs, fe)

        loss_z_recon = F.l1_loss(z, z_recon) * cfg.W_Z_RECON
        if not cfg.WITH_Z_RECON_LOSS:
            loss_z_recon *= 0

        #
        # Info min/max phase
        #

        loss_I = loss_z_recon
        loss_G = torch.zeros(1, requires_grad=True).cuda()

        def calc_err(outputs, target):
            return np.abs(np.rad2deg(F.l1_loss(outputs, target, reduction='none').detach().cpu().numpy().mean(axis=0)))

        def cosine_loss(outputs, targets):
            return (1 - F.cosine_similarity(outputs, targets, dim=1)).mean()

        if Y[3] is not None and Y[3].sum() > 0:  # Has expression -> AffectNet
            available_factors = [3,3,3]
            if cfg.WITH_POSE:
                available_factors = [0] + available_factors
        elif Y[2][1] is not None:  # has vids -> VoxCeleb
            available_factors = [2]
        elif Y[1] is not None:  # Has identities
            available_factors = [1,1,1]
            if cfg.WITH_POSE:
                available_factors = [0] + available_factors
        elif Y[0] is not None: # Any dataset with pose
            available_factors = [0,1,3]

        lvl = available_factors[self.iter % len(available_factors)]

        name = self.factors[lvl]
        try:
            y = Y[lvl]
        except TypeError:
            y = None

        # if y is not None and name != 'shape':
        def calc_feature_loss(name, y_f, y, show_triplets=False, wnd_title=None):
            if name == 'id' or name == 'shape' or name == 'expression':
                display_images = None
                if show_triplets:
                    display_images = self.images
                loss_I_f, err_f = calc_triplet_loss(y_f, y, return_acc=True, images=display_images, feature_name=name,
                                                    wnd_title=wnd_title)
                if name == 'expression':
                    loss_I_f *= 2.0
            elif name == 'pose':
                # loss_I_f, err_f = F.l1_loss(y_f, y), calc_err(y_f, y)
                loss_I_f, err_f = F.mse_loss(y_f, y)*1, calc_err(y_f, y)
                # loss_I_f, err_f = cosine_loss(y_f, y), calc_err(y_f, y)
            else:
                raise ValueError("Unknown feature name!")
            return loss_I_f, err_f


        if y is not None and cfg.WITH_FEATURE_LOSS:

            show_triplets = (self.iter + 1) % self.print_interval == 0

            y_f = fts[lvl]
            loss_I_f, err_f = calc_feature_loss(name, y_f, y, show_triplets=show_triplets)

            loss_I += cfg.W_FEAT * loss_I_f

            iter_stats[name+'_loss_f'] = loss_I_f.item()
            iter_stats[name+'_err_f'] = np.mean(err_f)

            # train expression classifier
            if name == 'expression':
                self.znet.zero_grad()
                emotion_labels = y[:,0].long()
                clprobs = self.znet(y_f.detach())  # train only znet
                # clprobs = self.znet(y_f)  # train enoder and znet
                # loss_cls = self.cross_entropy_loss(clprobs, emotion_labels)
                loss_cls = self.weighted_CE_loss(clprobs, emotion_labels)

                acc_cls = calc_acc(clprobs, emotion_labels)
                if train:
                    loss_cls.backward(retain_graph=False)
                self.optimizer_znet.step()
                iter_stats['loss_cls'] = loss_cls.item()
                iter_stats['acc_cls'] = acc_cls
                iter_stats['expression_y_probs'] = to_numpy(clprobs)
                iter_stats['expression_y'] = to_numpy(y)


        # cycle loss
        # other_levels = [0,1,2,3]
        # other_levels.remove(lvl)
        # shuffle_lvl = np.random.permutation(other_levels)[0]
        shuffle_lvl = lvl
        # print("shuffling level {}...".format(shuffle_lvl))
        if cfg.WITH_DISENT_CYCLE_LOSS:
            # z_random = torch.rand_like(z).cuda()
            # fts_random = self.E(z_random)

            # create modified feature vectors
            fts[0] = fts[0].detach()
            fts[1] = fts[1].detach()
            fts[2] = fts[2].detach()
            fts[3] = fts[3].detach()
            fts_mod = fts.copy()
            shuffled_ids = torch.randperm(len(fts[shuffle_lvl]))
            y_mod = None
            if y is not None:
                if name == 'shape':
                    y_mod = [y[0][shuffled_ids], y[1][shuffled_ids]]
                else:
                    y_mod = y[shuffled_ids]

            fts_mod[shuffle_lvl] = fts[shuffle_lvl][shuffled_ids]

            # predict full cycle
            z_random_mod = self.G(*fts_mod)
            X_random_mod = self.P(z_random_mod)[:,:3]
            z_random_mod_recon = self.Q(X_random_mod)
            fts2 = self.E(z_random_mod_recon)

            # recon error in unmodified part
            # h = torch.cat([fts_mod[i] for i in range(len(fts_mod)) if i != lvl], dim=1)
            # h2 = torch.cat([fts2[i] for i in range(len(fts2)) if i != lvl], dim=1)
            # l1_err_h = torch.abs(h - h2).mean(dim=1)
            # l1_err_h = torch.abs(torch.cat(fts_mod, dim=1) - torch.cat(fts2, dim=1)).mean(dim=1)

            # recon error in modified part
            # l1_err_f = np.rad2deg(to_numpy(torch.abs(fts_mod[lvl] - fts2[lvl]).mean(dim=1)))

            # recon error in entire vector
            l1_err = torch.abs(torch.cat(fts_mod, dim=1)[:,3:] - torch.cat(fts2, dim=1)[:,3:]).mean(dim=1)
            loss_dis_cycle = F.l1_loss(torch.cat(fts_mod, dim=1)[:,3:], torch.cat(fts2, dim=1)[:,3:]) * cfg.W_CYCLE
            iter_stats['loss_dis_cycle'] = loss_dis_cycle.item()

            loss_I += loss_dis_cycle

            # cycle augmentation loss
            if cfg.WITH_AUGMENTATION_LOSS and y_mod is not None:
                y_f_2 = fts2[lvl]
                loss_I_f_2, err_f_2 = calc_feature_loss(name, y_f_2, y_mod, show_triplets=show_triplets, wnd_title='aug')
                loss_I += loss_I_f_2 * cfg.W_AUG
                iter_stats[name+'_loss_f_2'] = loss_I_f_2.item()
                iter_stats[name+'_err_f_2'] = np.mean(err_f_2)

            #
            # Adversarial loss of modified generations
            #

            GAN = False
            if GAN and train:
                eps = 0.00001

                # #######################
                # # GAN discriminator phase
                # #######################
                update_D = False
                if update_D:
                    self.D.zero_grad()
                    err_real = self.D(self.images)
                    err_fake = self.D(X_random_mod.detach())
                    # err_fake = self.D(X_z_recon.detach())
                    loss_D = -torch.mean(torch.log(err_real + eps) + torch.log(1.0 - err_fake + eps)) * 0.1
                    loss_D.backward()
                    self.optimizer_D.step()
                    iter_stats.update({'loss_D': loss_D.item()})

                #######################
                # Generator loss
                #######################
                self.D.zero_grad()
                err_fake = self.D(X_random_mod)
                # err_fake = self.D(X_z_recon)
                loss_G += -torch.mean(torch.log(err_fake + eps))

                iter_stats.update({'loss_G': loss_G.item()})
                # iter_stats.update({'err_real': err_real.mean().item(), 'err_fake': loss_G.mean().item()})

            # debug visualization
            show = True
            if show:
                if (self.iter+1) % self.print_interval in [0,1]:
                    if Y[3] is None:
                        emotion_gt = np.zeros(len(z), dtype=int)
                        emotion_gt_mod = np.zeros(len(z), dtype=int)
                    else:
                        emotion_gt = Y[3][:,0].long()
                        emotion_gt_mod = Y[3][shuffled_ids,0].long()
                    with torch.no_grad():
                        self.znet.eval()
                        self.G.eval()
                        emotion_preds = torch.max(self.znet(fe.detach()), 1)[1]
                        emotion_mod = torch.max(self.znet(fts_mod[3].detach()), 1)[1]
                        emotion_mod_pred = torch.max(self.znet(fts2[3].detach()), 1)[1]
                        X_recon = self.P(z)[:,:3]
                        X_z_recon = self.P(z_recon)[:,:3]
                        X_random_mod_recon = self.P(self.G(*fts2))[:,:3]
                        self.znet.train(train)
                        self.G.train(train)
                        X_recon_errs = 255.0 * torch.abs(self.images - X_recon).reshape(len(self.images), -1).mean(dim=1)
                        X_z_recon_errs = 255.0 * torch.abs(self.images - X_z_recon).reshape(len(self.images), -1).mean(dim=1)

                    nimgs = 8

                    disp_input = vis.add_pose_to_images(ds_utils.denormalized(self.images)[:nimgs], Y[0], color=(0, 0, 1.0))
                    if name == 'expression':
                        disp_input = vis.add_emotion_to_images(disp_input, to_numpy(emotion_gt))
                    elif name == 'id':
                        disp_input = vis.add_id_to_images(disp_input, to_numpy(Y[1]))

                    disp_recon = vis.add_pose_to_images(ds_utils.denormalized(X_recon)[:nimgs], fts[0])
                    disp_recon = vis.add_error_to_images(disp_recon, errors=X_recon_errs, format_string='{:.1f}')

                    disp_z_recon = vis.add_pose_to_images(ds_utils.denormalized(X_z_recon)[:nimgs], fts[0])
                    disp_z_recon = vis.add_emotion_to_images(disp_z_recon, to_numpy(emotion_preds),
                                                             gt_emotions=to_numpy(emotion_gt) if name=='expression' else None)
                    disp_z_recon = vis.add_error_to_images(disp_z_recon, errors=X_z_recon_errs, format_string='{:.1f}')

                    disp_input_shuffle = vis.add_pose_to_images(ds_utils.denormalized(self.images[shuffled_ids])[:nimgs], fts[0][shuffled_ids])
                    disp_input_shuffle = vis.add_emotion_to_images(disp_input_shuffle, to_numpy(emotion_gt_mod))
                    if name == 'id':
                        disp_input_shuffle = vis.add_id_to_images(disp_input_shuffle, to_numpy(Y[1][shuffled_ids]))

                    disp_recon_shuffle = vis.add_pose_to_images(ds_utils.denormalized(X_random_mod)[:nimgs], fts_mod[0], color=(0, 0, 1.0))
                    disp_recon_shuffle = vis.add_emotion_to_images(disp_recon_shuffle, to_numpy(emotion_mod))

                    disp_cycle = vis.add_pose_to_images(ds_utils.denormalized(X_random_mod_recon)[:nimgs], fts2[0])
                    disp_cycle = vis.add_emotion_to_images(disp_cycle, to_numpy(emotion_mod_pred))
                    disp_cycle = vis.add_error_to_images(disp_cycle, errors=l1_err, format_string='{:.3f}',
                                                         size=0.6, thickness=2, vmin=0, vmax=0.1)

                    rows = [
                        # original input images
                        vis.make_grid(disp_input, nCols=nimgs),

                        # reconstructions without disentanglement
                        vis.make_grid(disp_recon, nCols=nimgs),

                        # reconstructions with disentanglement
                        vis.make_grid(disp_z_recon, nCols=nimgs),

                        # source for feature transfer
                        vis.make_grid(disp_input_shuffle, nCols=nimgs),

                        # reconstructions with modified feature vector (direct)
                        vis.make_grid(disp_recon_shuffle, nCols=nimgs),

                        # reconstructions with modified feature vector (after one full cycle)
                        vis.make_grid(disp_cycle, nCols=nimgs)
                    ]
                    f = 1.0 / cfg.INPUT_SCALE_FACTOR
                    disp_img = vis.make_grid(rows, nCols=1, normalize=False, fx=f, fy=f)

                    wnd_title = name
                    if self.current_dataset is not None:
                        wnd_title += ' ' + self.current_dataset.__class__.__name__
                    cv2.imshow(wnd_title, cv2.cvtColor(disp_img, cv2.COLOR_RGB2BGR))
                    cv2.waitKey(10)

        loss_I *= cfg.W_DISENT

        iter_stats['loss_disent'] = loss_I.item()

        if train:
            loss_I.backward(retain_graph=True)

        return z_recon, iter_stats, loss_G[0]
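
The cycle branch above swaps one disentangled factor between samples before decoding: every factor tensor is detached, the chosen factor is re-indexed with a random permutation, and the modified list is passed back through the generator. A stripped-down sketch of just that shuffling step, on plain tensors rather than the project's encoder outputs (the factor dimensions are made up):

import torch

def shuffle_one_factor(factors, lvl):
    # detach everything so gradients do not flow back into the encoder,
    # then replace factor `lvl` with a randomly permuted copy of itself
    factors = [f.detach() for f in factors]
    perm = torch.randperm(len(factors[lvl]))
    mixed = list(factors)
    mixed[lvl] = factors[lvl][perm]
    return mixed, perm

# toy usage: four factors (pose, id, shape, expression) for a batch of 6
fts = [torch.randn(6, d) for d in (3, 64, 64, 16)]
fts_mod, perm = shuffle_one_factor(fts, lvl=3)   # shuffle only the expression code
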
Example #6
File: lmutils.py  Project: PaperID8601/3FR
def visualize_batch(images,
                    landmarks,
                    X_recon,
                    X_lm_hm,
                    lm_preds_max,
                    lm_heatmaps=None,
                    images_mod=None,
                    lm_preds_cnn=None,
                    ds=None,
                    wait=0,
                    ssim_maps=None,
                    landmarks_to_draw=lmcfg.ALL_LANDMARKS,
                    ocular_norm='outer',
                    horizontal=False,
                    f=1.0,
                    overlay_heatmaps_input=False,
                    overlay_heatmaps_recon=False,
                    clean=False):

    gt_color = (0, 255, 0)
    pred_color = (0, 0, 255)

    nimgs = min(10, len(images))
    images = nn.atleast4d(images)[:nimgs]
    nme_per_lm = None
    if landmarks is None:
        # print('num landmarks', lmcfg.NUM_LANDMARKS)
        lm_gt = np.zeros((nimgs, lmcfg.NUM_LANDMARKS, 2))
    else:
        lm_gt = nn.atleast3d(to_numpy(landmarks))[:nimgs]
        nme_per_lm = calc_landmark_nme(lm_gt,
                                       lm_preds_max[:nimgs],
                                       ocular_norm=ocular_norm)
        lm_ssim_errs = 1 - calc_landmark_ssim_score(images, X_recon[:nimgs],
                                                    lm_gt)

    lm_confs = None
    # show landmark heatmaps
    pred_heatmaps = None
    if X_lm_hm is not None:
        pred_heatmaps = to_single_channel_heatmap(to_numpy(X_lm_hm[:nimgs]))
        pred_heatmaps = [
            cv2.resize(im, None, fx=f, fy=f, interpolation=cv2.INTER_NEAREST)
            for im in pred_heatmaps
        ]
        gt_heatmaps = None
        if lm_heatmaps is not None:
            gt_heatmaps = to_single_channel_heatmap(
                to_numpy(lm_heatmaps[:nimgs]))
            gt_heatmaps = np.array([
                cv2.resize(im,
                           None,
                           fx=f,
                           fy=f,
                           interpolation=cv2.INTER_NEAREST)
                for im in gt_heatmaps
            ])
        show_landmark_heatmaps(pred_heatmaps, gt_heatmaps, nimgs, f=1)
        lm_confs = to_numpy(X_lm_hm).reshape(X_lm_hm.shape[0],
                                             X_lm_hm.shape[1], -1).max(axis=2)

    # resize images for display and scale landmarks accordingly
    lm_preds_max = lm_preds_max[:nimgs] * f
    if lm_preds_cnn is not None:
        lm_preds_cnn = lm_preds_cnn[:nimgs] * f
    lm_gt *= f

    input_images = vis._to_disp_images(images[:nimgs], denorm=True)
    if images_mod is not None:
        disp_images = vis._to_disp_images(images_mod[:nimgs], denorm=True)
    else:
        disp_images = vis._to_disp_images(images[:nimgs], denorm=True)
    disp_images = [
        cv2.resize(im, None, fx=f, fy=f, interpolation=cv2.INTER_NEAREST)
        for im in disp_images
    ]

    recon_images = vis._to_disp_images(X_recon[:nimgs], denorm=True)
    disp_X_recon = [
        cv2.resize(im, None, fx=f, fy=f, interpolation=cv2.INTER_NEAREST)
        for im in recon_images.copy()
    ]

    # overlay landmarks on input images
    if pred_heatmaps is not None and overlay_heatmaps_input:
        disp_images = [
            vis.overlay_heatmap(disp_images[i], pred_heatmaps[i])
            for i in range(len(pred_heatmaps))
        ]
    if pred_heatmaps is not None and overlay_heatmaps_recon:
        disp_X_recon = [
            vis.overlay_heatmap(disp_X_recon[i], pred_heatmaps[i])
            for i in range(len(pred_heatmaps))
        ]

    #
    # Show input images
    #
    disp_images = vis.add_landmarks_to_images(disp_images,
                                              lm_gt[:nimgs],
                                              color=gt_color)
    disp_images = vis.add_landmarks_to_images(disp_images,
                                              lm_preds_max[:nimgs],
                                              lm_errs=nme_per_lm,
                                              color=pred_color,
                                              draw_wireframe=False,
                                              gt_landmarks=lm_gt,
                                              draw_gt_offsets=True)

    # disp_images = vis.add_landmarks_to_images(disp_images, lm_gt[:nimgs], color=(1,1,1), radius=1,
    #                                           draw_dots=True, draw_wireframe=True, landmarks_to_draw=landmarks_to_draw)
    # disp_images = vis.add_landmarks_to_images(disp_images, lm_preds_max[:nimgs], lm_errs=nme_per_lm,
    #                                           color=(1.0, 0.0, 0.0),
    #                                           draw_dots=True, draw_wireframe=True, radius=1,
    #                                           gt_landmarks=lm_gt, draw_gt_offsets=False,
    #                                           landmarks_to_draw=landmarks_to_draw)

    #
    # Show reconstructions
    #
    X_recon_errs = 255.0 * torch.abs(images - X_recon[:nimgs]).reshape(
        len(images), -1).mean(dim=1)
    if not clean:
        disp_X_recon = vis.add_error_to_images(disp_X_recon[:nimgs],
                                               errors=X_recon_errs,
                                               format_string='{:>4.1f}')

    # modes of heatmaps
    # disp_X_recon = [overlay_heatmap(disp_X_recon[i], pred_heatmaps[i]) for i in range(len(pred_heatmaps))]
    if not clean:
        lm_errs_max = calc_landmark_nme_per_img(
            lm_gt,
            lm_preds_max,
            ocular_norm=ocular_norm,
            landmarks_to_eval=lmcfg.LANDMARKS_NO_OUTLINE)
        lm_errs_max_outline = calc_landmark_nme_per_img(
            lm_gt,
            lm_preds_max,
            ocular_norm=ocular_norm,
            landmarks_to_eval=lmcfg.LANDMARKS_ONLY_OUTLINE)
        lm_errs_max_all = calc_landmark_nme_per_img(
            lm_gt,
            lm_preds_max,
            ocular_norm=ocular_norm,
            landmarks_to_eval=lmcfg.ALL_LANDMARKS)
        disp_X_recon = vis.add_error_to_images(disp_X_recon,
                                               lm_errs_max,
                                               loc='br-2',
                                               format_string='{:>5.2f}',
                                               vmax=15)
        disp_X_recon = vis.add_error_to_images(disp_X_recon,
                                               lm_errs_max_outline,
                                               loc='br-1',
                                               format_string='{:>5.2f}',
                                               vmax=15)
        disp_X_recon = vis.add_error_to_images(disp_X_recon,
                                               lm_errs_max_all,
                                               loc='br',
                                               format_string='{:>5.2f}',
                                               vmax=15)
    disp_X_recon = vis.add_landmarks_to_images(disp_X_recon,
                                               lm_gt,
                                               color=gt_color,
                                               draw_wireframe=True)

    # disp_X_recon = vis.add_landmarks_to_images(disp_X_recon, lm_preds_max[:nimgs],
    #                                            color=pred_color, draw_wireframe=False,
    #                                            lm_errs=nme_per_lm, lm_confs=lm_confs,
    #                                            lm_rec_errs=lm_ssim_errs, gt_landmarks=lm_gt,
    #                                            draw_gt_offsets=True, draw_dots=True)

    disp_X_recon = vis.add_landmarks_to_images(disp_X_recon,
                                               lm_preds_max[:nimgs],
                                               color=pred_color,
                                               draw_wireframe=True,
                                               gt_landmarks=lm_gt,
                                               draw_gt_offsets=True,
                                               lm_errs=nme_per_lm,
                                               draw_dots=True,
                                               radius=2)

    def add_confs(disp_X_recon, lmids, loc):
        means = lm_confs[:, lmids].mean(axis=1)
        colors = vis.color_map(to_numpy(1 - means),
                               cmap=plt.cm.jet,
                               vmin=0.0,
                               vmax=0.4)
        return vis.add_error_to_images(disp_X_recon,
                                       means,
                                       loc=loc,
                                       format_string='{:>4.2f}',
                                       colors=colors)

    # disp_X_recon = add_confs(disp_X_recon, lmcfg.LANDMARKS_NO_OUTLINE, 'bm-2')
    # disp_X_recon = add_confs(disp_X_recon, lmcfg.LANDMARKS_ONLY_OUTLINE, 'bm-1')
    # disp_X_recon = add_confs(disp_X_recon, lmcfg.ALL_LANDMARKS, 'bm')

    # print ssim errors
    ssim = np.zeros(nimgs)
    for i in range(nimgs):
        ssim[i] = compare_ssim(input_images[i],
                               recon_images[i],
                               data_range=1.0,
                               multichannel=True)
    if not clean:
        disp_X_recon = vis.add_error_to_images(disp_X_recon,
                                               1 - ssim,
                                               loc='bl-1',
                                               format_string='{:>4.2f}',
                                               vmax=0.8,
                                               vmin=0.2)
    # print ssim torch errors
    if ssim_maps is not None and not clean:
        disp_X_recon = vis.add_error_to_images(disp_X_recon,
                                               ssim_maps.reshape(
                                                   len(ssim_maps),
                                                   -1).mean(axis=1),
                                               loc='bl-2',
                                               format_string='{:>4.2f}',
                                               vmin=0.0,
                                               vmax=0.4)

    rows = [vis.make_grid(disp_images, nCols=nimgs, normalize=False)]
    rows.append(vis.make_grid(disp_X_recon, nCols=nimgs))

    if ssim_maps is not None:
        disp_ssim_maps = to_numpy(
            ds_utils.denormalized(ssim_maps)[:nimgs].transpose(0, 2, 3, 1))
        for i in range(len(disp_ssim_maps)):
            disp_ssim_maps[i] = vis.color_map(disp_ssim_maps[i].mean(axis=2),
                                              vmin=0.0,
                                              vmax=2.0)
        grid_ssim_maps = vis.make_grid(disp_ssim_maps, nCols=nimgs, fx=f, fy=f)
        cv2.imshow('ssim errors',
                   cv2.cvtColor(grid_ssim_maps, cv2.COLOR_RGB2BGR))

    if horizontal:
        assert (nimgs == 1)
        disp_rows = vis.make_grid(rows, nCols=2)
    else:
        disp_rows = vis.make_grid(rows, nCols=1)
    wnd_title = 'Predicted Landmarks '
    if ds is not None:
        wnd_title += ds.__class__.__name__
    cv2.imshow(wnd_title, cv2.cvtColor(disp_rows, cv2.COLOR_RGB2BGR))
    cv2.waitKey(wait)
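
calc_landmark_nme normalizes the per-landmark error by an ocular distance ('outer' or 'pupil'). A generic numpy sketch of the outer-eye-corner variant for a 68-point layout, where the corners are usually landmarks 36 and 45, is below; the indices and the helper itself are assumptions for illustration, not the project's implementation:

import numpy as np

def landmark_nme(lm_gt, lm_pred, left_idx=36, right_idx=45):
    # per-landmark Euclidean error, normalized by the inter-ocular distance
    # lm_gt and lm_pred have shape (N, num_landmarks, 2)
    errs = np.linalg.norm(lm_gt - lm_pred, axis=2)                          # (N, L)
    iod = np.linalg.norm(lm_gt[:, left_idx] - lm_gt[:, right_idx], axis=1)  # (N,)
    return 100.0 * errs / iod[:, None]                                      # percent NME

# toy usage: two faces, 68 landmarks each, ~1px prediction noise
rng = np.random.default_rng(0)
gt = rng.random((2, 68, 2)) * 128
pred = gt + rng.normal(scale=1.0, size=gt.shape)
print(landmark_nme(gt, pred).mean())
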
Example #7
def visualize_batch(batch,
                    X_recon,
                    X_lm_hm,
                    lm_preds_max,
                    lm_preds_cnn=None,
                    ds=None,
                    wait=0,
                    ssim_maps=None,
                    landmarks_to_draw=lmcfg.LANDMARKS_TO_EVALUATE,
                    ocular_norm='pupil',
                    horizontal=False,
                    f=1.0):

    nimgs = min(10, len(batch))
    gt_color = (0, 1, 0)

    lm_confs = None
    # show landmark heatmaps
    pred_heatmaps = None
    if X_lm_hm is not None:
        pred_heatmaps = to_single_channel_heatmap(to_numpy(X_lm_hm[:nimgs]))
        pred_heatmaps = [
            cv2.resize(im, None, fx=f, fy=f, interpolation=cv2.INTER_NEAREST)
            for im in pred_heatmaps
        ]
        if batch.lm_heatmaps is not None:
            gt_heatmaps = to_single_channel_heatmap(
                to_numpy(batch.lm_heatmaps[:nimgs]))
            gt_heatmaps = np.array([
                cv2.resize(im,
                           None,
                           fx=f,
                           fy=f,
                           interpolation=cv2.INTER_NEAREST)
                for im in gt_heatmaps
            ])
            show_landmark_heatmaps(pred_heatmaps, gt_heatmaps, nimgs, f=1)
        lm_confs = to_numpy(X_lm_hm).reshape(X_lm_hm.shape[0],
                                             X_lm_hm.shape[1], -1).max(axis=2)

    # scale landmarks
    lm_preds_max = lm_preds_max[:nimgs] * f
    if lm_preds_cnn is not None:
        lm_preds_cnn = lm_preds_cnn[:nimgs] * f
    lm_gt = to_numpy(batch.landmarks[:nimgs]) * f
    if lm_gt.shape[1] == 98:
        lm_gt = convert_landmarks(lm_gt, LM98_TO_LM68)

    input_images = vis._to_disp_images(batch.images[:nimgs], denorm=True)
    if batch.images_mod is not None:
        disp_images = vis._to_disp_images(batch.images_mod[:nimgs],
                                          denorm=True)
    else:
        disp_images = vis._to_disp_images(batch.images[:nimgs], denorm=True)
    disp_images = [
        cv2.resize(im, None, fx=f, fy=f, interpolation=cv2.INTER_NEAREST)
        for im in disp_images
    ]

    recon_images = vis._to_disp_images(X_recon[:nimgs], denorm=True)
    disp_X_recon = [
        cv2.resize(im, None, fx=f, fy=f, interpolation=cv2.INTER_NEAREST)
        for im in recon_images.copy()
    ]

    # draw landmarks to input images
    if pred_heatmaps is not None:
        disp_images = [
            vis.overlay_heatmap(disp_images[i], pred_heatmaps[i])
            for i in range(len(pred_heatmaps))
        ]

    nme_per_lm = calc_landmark_nme(lm_gt,
                                   lm_preds_max,
                                   ocular_norm=ocular_norm)
    lm_ssim_errs = calc_landmark_ssim_error(batch.images[:nimgs],
                                            X_recon[:nimgs],
                                            batch.landmarks[:nimgs])

    #
    # Show input images
    #
    disp_images = vis.add_landmarks_to_images(
        disp_images,
        lm_gt[:nimgs],
        color=gt_color,
        draw_dots=True,
        draw_wireframe=False,
        landmarks_to_draw=landmarks_to_draw)
    disp_images = vis.add_landmarks_to_images(
        disp_images,
        lm_preds_max[:nimgs],
        lm_errs=nme_per_lm,
        color=(0.0, 0.0, 1.0),
        draw_dots=True,
        draw_wireframe=False,
        gt_landmarks=lm_gt,
        draw_gt_offsets=True,
        landmarks_to_draw=landmarks_to_draw)

    # if lm_preds_cnn is not None:
    #     disp_images = vis.add_landmarks_to_images(disp_images, lm_preds_cnn, color=(1, 1, 0),
    #                                               gt_landmarks=lm_gt, draw_gt_offsets=False,
    #                                               draw_wireframe=True, landmarks_to_draw=landmarks_to_draw)

    rows = [vis.make_grid(disp_images, nCols=nimgs, normalize=False)]

    #
    # Show reconstructions
    #
    X_recon_errs = 255.0 * torch.abs(batch.images - X_recon).reshape(
        len(batch.images), -1).mean(dim=1)
    disp_X_recon = vis.add_error_to_images(disp_X_recon[:nimgs],
                                           errors=X_recon_errs,
                                           format_string='{:>4.1f}')

    # modes of heatmaps
    # disp_X_recon = [overlay_heatmap(disp_X_recon[i], pred_heatmaps[i]) for i in range(len(pred_heatmaps))]
    lm_errs_max = calc_landmark_nme_per_img(
        lm_gt,
        lm_preds_max,
        ocular_norm=ocular_norm,
        landmarks_to_eval=lmcfg.LANDMARKS_NO_OUTLINE)
    lm_errs_max_outline = calc_landmark_nme_per_img(
        lm_gt,
        lm_preds_max,
        ocular_norm=ocular_norm,
        landmarks_to_eval=lmcfg.LANDMARKS_ONLY_OUTLINE)
    lm_errs_max_all = calc_landmark_nme_per_img(
        lm_gt,
        lm_preds_max,
        ocular_norm=ocular_norm,
        landmarks_to_eval=lmcfg.ALL_LANDMARKS)
    disp_X_recon = vis.add_error_to_images(disp_X_recon,
                                           lm_errs_max,
                                           loc='br-2',
                                           format_string='{:>5.2f}',
                                           vmax=15)
    disp_X_recon = vis.add_error_to_images(disp_X_recon,
                                           lm_errs_max_outline,
                                           loc='br-1',
                                           format_string='{:>5.2f}',
                                           vmax=15)
    disp_X_recon = vis.add_error_to_images(disp_X_recon,
                                           lm_errs_max_all,
                                           loc='br',
                                           format_string='{:>5.2f}',
                                           vmax=15)
    disp_X_recon = vis.add_landmarks_to_images(
        disp_X_recon,
        lm_preds_max[:nimgs],
        color=(0, 0, 1),
        landmarks_to_draw=landmarks_to_draw,
        draw_wireframe=False,
        lm_errs=nme_per_lm,
        # lm_confs=lm_confs,
        lm_confs=1 - lm_ssim_errs,
        gt_landmarks=lm_gt,
        draw_gt_offsets=True,
        draw_dots=True)
    disp_X_recon = vis.add_landmarks_to_images(
        disp_X_recon,
        lm_gt,
        color=gt_color,
        draw_wireframe=False,
        landmarks_to_draw=landmarks_to_draw)

    # landmarks from CNN prediction
    if lm_preds_cnn is not None:
        nme_per_lm = calc_landmark_nme(lm_gt,
                                       lm_preds_cnn,
                                       ocular_norm=ocular_norm)
        disp_X_recon = vis.add_landmarks_to_images(
            disp_X_recon,
            lm_preds_cnn,
            color=(1, 1, 0),
            landmarks_to_draw=lmcfg.ALL_LANDMARKS,
            draw_wireframe=False,
            lm_errs=nme_per_lm,
            gt_landmarks=lm_gt,
            draw_gt_offsets=True,
            draw_dots=True,
            offset_line_color=(1, 1, 1))
        lm_errs_cnn = calc_landmark_nme_per_img(
            lm_gt,
            lm_preds_cnn,
            ocular_norm=ocular_norm,
            landmarks_to_eval=landmarks_to_draw)
        lm_errs_cnn_outline = calc_landmark_nme_per_img(
            lm_gt,
            lm_preds_cnn,
            ocular_norm=ocular_norm,
            landmarks_to_eval=lmcfg.LANDMARKS_ONLY_OUTLINE)
        lm_errs_cnn_all = calc_landmark_nme_per_img(
            lm_gt,
            lm_preds_cnn,
            ocular_norm=ocular_norm,
            landmarks_to_eval=lmcfg.ALL_LANDMARKS)
        disp_X_recon = vis.add_error_to_images(disp_X_recon,
                                               lm_errs_cnn,
                                               loc='tr',
                                               format_string='{:>5.2f}',
                                               vmax=15)
        disp_X_recon = vis.add_error_to_images(disp_X_recon,
                                               lm_errs_cnn_outline,
                                               loc='tr+1',
                                               format_string='{:>5.2f}',
                                               vmax=15)
        disp_X_recon = vis.add_error_to_images(disp_X_recon,
                                               lm_errs_cnn_all,
                                               loc='tr+2',
                                               format_string='{:>5.2f}',
                                               vmax=15)

    # print ssim errors
    ssim = np.zeros(nimgs)
    for i in range(nimgs):
        ssim[i] = compare_ssim(input_images[i],
                               recon_images[i],
                               data_range=1.0,
                               multichannel=True)
    disp_X_recon = vis.add_error_to_images(disp_X_recon,
                                           1 - ssim,
                                           loc='bl-1',
                                           format_string='{:>4.2f}',
                                           vmax=0.8,
                                           vmin=0.2)
    # print ssim torch errors
    if ssim_maps is not None:
        disp_X_recon = vis.add_error_to_images(disp_X_recon,
                                               ssim_maps.reshape(
                                                   len(ssim_maps),
                                                   -1).mean(axis=1),
                                               loc='bl-2',
                                               format_string='{:>4.2f}',
                                               vmin=0.0,
                                               vmax=0.4)

    rows.append(vis.make_grid(disp_X_recon, nCols=nimgs))

    if ssim_maps is not None:
        disp_ssim_maps = to_numpy(
            ds_utils.denormalized(ssim_maps)[:nimgs].transpose(0, 2, 3, 1))
        for i in range(len(disp_ssim_maps)):
            disp_ssim_maps[i] = vis.color_map(disp_ssim_maps[i].mean(axis=2),
                                              vmin=0.0,
                                              vmax=2.0)
        grid_ssim_maps = vis.make_grid(disp_ssim_maps, nCols=nimgs, fx=f, fy=f)
        cv2.imshow('ssim errors',
                   cv2.cvtColor(grid_ssim_maps, cv2.COLOR_RGB2BGR))

    X_gen_lm_hm = None
    X_gen_vis = None
    show_random_faces = False
    if show_random_faces:
        with torch.no_grad():
            z_random = self.enc_rand(nimgs, self.saae.z_dim).cuda()
            outputs = self.saae.P(z_random)
            X_gen_vis = outputs[:, :3]
            if outputs.shape[1] > 3:
                X_gen_lm_hm = outputs[:, 3:]
        disp_X_gen = to_numpy(
            ds_utils.denormalized(X_gen_vis)[:nimgs].permute(0, 2, 3, 1))

        if X_gen_lm_hm is not None:
            if lmcfg.LANDMARK_TARGET == 'colored':
                gen_heatmaps = [to_image(X_gen_lm_hm[i]) for i in range(nimgs)]
            elif lmcfg.LANDMARK_TARGET == 'multi_channel':
                X_gen_lm_hm = X_gen_lm_hm.max(dim=1)[0]
                gen_heatmaps = [to_image(X_gen_lm_hm[i]) for i in range(nimgs)]
            else:
                gen_heatmaps = [
                    to_image(X_gen_lm_hm[i, 0]) for i in range(nimgs)
                ]

            disp_X_gen = [
                vis.overlay_heatmap(disp_X_gen[i], gen_heatmaps[i])
                for i in range(len(pred_heatmaps))
            ]

            # inputs = torch.cat([X_gen_vis, X_gen_lm_hm.detach()], dim=1)
            inputs = X_gen_lm_hm.detach()

            # disabled for multi_channel LM targets
            # lm_gen_preds = self.saae.lm_coords(inputs).reshape(len(inputs), -1, 2)
            # disp_X_gen = vis.add_landmarks_to_images(disp_X_gen, lm_gen_preds[:nimgs], color=(0,1,1))

            disp_gen_heatmaps = [
                vis.color_map(hm, vmin=0, vmax=1.0) for hm in gen_heatmaps
            ]
            img_gen_heatmaps = cv2.resize(vis.make_grid(disp_gen_heatmaps,
                                                        nCols=nimgs,
                                                        padval=0),
                                          None,
                                          fx=1.0,
                                          fy=1.0)
            cv2.imshow('generated landmarks',
                       cv2.cvtColor(img_gen_heatmaps, cv2.COLOR_RGB2BGR))

        rows.append(vis.make_grid(disp_X_gen, nCols=nimgs))

    # self.saae.D.train(train_state_D)
    # self.saae.Q.train(train_state_Q)
    # self.saae.P.train(train_state_P)

    if horizontal:
        assert (nimgs == 1)
        disp_rows = vis.make_grid(rows, nCols=2)
    else:
        disp_rows = vis.make_grid(rows, nCols=1)
    wnd_title = 'recon errors '
    if ds is not None:
        wnd_title += ds.__class__.__name__
    cv2.imshow(wnd_title, cv2.cvtColor(disp_rows, cv2.COLOR_RGB2BGR))
    cv2.waitKey(wait)
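
Both landmark visualizers resize the display images by a uniform factor f and multiply the landmark coordinates by the same factor so the overlays stay aligned. A small standalone sketch of that pairing (the synthetic image and points are placeholders):

import cv2
import numpy as np

def resize_image_and_landmarks(image, landmarks, f):
    # scale the image and the (x, y) landmark coordinates by the same factor
    resized = cv2.resize(image, None, fx=f, fy=f, interpolation=cv2.INTER_NEAREST)
    return resized, landmarks * f

# toy usage
img = np.zeros((128, 128, 3), dtype=np.float32)
lms = np.array([[32.0, 40.0], [96.0, 40.0], [64.0, 90.0]])
img2, lms2 = resize_image_and_landmarks(img, lms, f=2.0)   # img2 is 256x256
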
Example #8
def draw_results(X_resized,
                 X_recon,
                 levels_z=None,
                 landmarks=None,
                 landmarks_pred=None,
                 cs_errs=None,
                 ncols=15,
                 fx=0.5,
                 fy=0.5,
                 additional_status_text=''):

    clean_images = True
    if clean_images:
        landmarks = None

    nimgs = len(X_resized)
    ncols = min(ncols, nimgs)
    img_size = X_recon.shape[-1]

    disp_X = vis.to_disp_images(X_resized, denorm=True)
    disp_X_recon = vis.to_disp_images(X_recon, denorm=True)

    # reconstruction error in pixels
    l1_dists = 255.0 * to_numpy(
        (X_resized - X_recon).abs().reshape(len(disp_X), -1).mean(dim=1))

    # SSIM errors
    ssim = np.zeros(nimgs)
    for i in range(nimgs):
        ssim[i] = compare_ssim(disp_X[i],
                               disp_X_recon[i],
                               data_range=1.0,
                               multichannel=True)

    landmarks = to_numpy(landmarks)
    cs_errs = to_numpy(cs_errs)

    text_size = img_size / 256
    text_thickness = 2

    #
    # Visualise resized input images and reconstructed images
    #
    if landmarks is not None:
        disp_X = vis.add_landmarks_to_images(
            disp_X,
            landmarks,
            draw_wireframe=False,
            landmarks_to_draw=lmcfg.LANDMARKS_19)
        disp_X_recon = vis.add_landmarks_to_images(
            disp_X_recon,
            landmarks,
            draw_wireframe=False,
            landmarks_to_draw=lmcfg.LANDMARKS_19)

    if landmarks_pred is not None:
        disp_X = vis.add_landmarks_to_images(disp_X,
                                             landmarks_pred,
                                             color=(1, 0, 0))
        disp_X_recon = vis.add_landmarks_to_images(disp_X_recon,
                                                   landmarks_pred,
                                                   color=(1, 0, 0))

    if not clean_images:
        disp_X_recon = vis.add_error_to_images(disp_X_recon,
                                               l1_dists,
                                               format_string='{:.1f}',
                                               size=text_size,
                                               thickness=text_thickness)
        disp_X_recon = vis.add_error_to_images(disp_X_recon,
                                               1 - ssim,
                                               loc='bl-1',
                                               format_string='{:>4.2f}',
                                               vmax=0.8,
                                               vmin=0.2,
                                               size=text_size,
                                               thickness=text_thickness)
        if cs_errs is not None:
            disp_X_recon = vis.add_error_to_images(disp_X_recon,
                                                   cs_errs,
                                                   loc='bl-2',
                                                   format_string='{:>4.2f}',
                                                    vmax=0.4,
                                                    vmin=0.0,
                                                   size=text_size,
                                                   thickness=text_thickness)

    # landmark errors
    lm_errs = np.zeros(1)
    if landmarks is not None:
        try:
            from landmarks import lmutils
            lm_errs = lmutils.calc_landmark_nme_per_img(
                landmarks, landmarks_pred)
            disp_X_recon = vis.add_error_to_images(disp_X_recon,
                                                   lm_errs,
                                                   loc='br',
                                                   format_string='{:>5.2f}',
                                                   vmax=15,
                                                   size=img_size / 256,
                                                   thickness=2)
        except Exception:  # skip the landmark error overlay if lmutils or the predictions are unavailable
            pass

    img_input = vis.make_grid(disp_X, nCols=ncols, normalize=False)
    img_recon = vis.make_grid(disp_X_recon, nCols=ncols, normalize=False)
    img_input = cv2.resize(img_input,
                           None,
                           fx=fx,
                           fy=fy,
                           interpolation=cv2.INTER_CUBIC)
    img_recon = cv2.resize(img_recon,
                           None,
                           fx=fx,
                           fy=fy,
                           interpolation=cv2.INTER_CUBIC)

    img_stack = [img_input, img_recon]

    #
    # Visualise hidden layers
    #
    VIS_HIDDEN = False
    if VIS_HIDDEN:
        img_z = vis.draw_z_vecs(levels_z, size=(img_size, 30), ncols=ncols)
        img_z = cv2.resize(img_z,
                           dsize=(img_input.shape[1], img_z.shape[0]),
                           interpolation=cv2.INTER_NEAREST)
        img_stack.append(img_z)

    cs_errs_mean = np.mean(cs_errs) if cs_errs is not None else np.nan
    status_bar_text = ("l1 recon err: {:.2f}px  "
                       "ssim: {:.3f}({:.3f})  "
                       "lms err: {:2} {}").format(l1_dists.mean(),
                                                  cs_errs_mean,
                                                  1 - ssim.mean(),
                                                  lm_errs.mean(),
                                                  additional_status_text)

    img_status_bar = vis.draw_status_bar(status_bar_text,
                                         status_bar_width=img_input.shape[1],
                                         status_bar_height=30,
                                         dtype=img_input.dtype)
    img_stack.append(img_status_bar)

    return np.vstack(img_stack)
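
draw_results builds its final canvas by resizing the input and reconstruction grids and concatenating all rows with np.vstack, which only requires that every stacked array share the same width, channel count, and dtype. A minimal sketch of that stacking (the placeholder rows stand in for the grids produced by vis.make_grid):

import cv2
import numpy as np

def stack_rows(rows, width):
    # resize each row to a common width (keeping its aspect ratio), then stack vertically
    resized = []
    for row in rows:
        h = int(round(row.shape[0] * width / row.shape[1]))
        resized.append(cv2.resize(row, (width, h), interpolation=cv2.INTER_CUBIC))
    return np.vstack(resized)

# toy usage: two fake row grids of different sizes
row1 = np.random.rand(80, 300, 3).astype(np.float32)
row2 = np.random.rand(60, 240, 3).astype(np.float32)
canvas = stack_rows([row1, row2], width=300)
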
Example #9
File: nn.py  Project: browatbn2/MAD
def calc_triplet_loss(outputs,
                      c,
                      return_acc=False,
                      images=None,
                      feature_name=None,
                      wnd_title=None):

    margin = 0.2
    eps = 1e-8

    debug = False
    is_expressions = (not isinstance(
        c, list)) and len(c.shape) > 1 and c.shape[1] == 3

    pos_id, neg_id = make_triplets(outputs, c, debug=debug)

    X, P, N = outputs[:, :], outputs[pos_id, :], outputs[neg_id, :]
    dpos = torch.sqrt(torch.sum((X - P)**2, dim=1) + eps)
    dneg = torch.sqrt(torch.sum((X - N)**2, dim=1) + eps)
    loss = torch.mean(
        torch.clamp(dpos - dneg + margin, min=0.0, max=margin * 2.0))
    # show triplets
    if images is not None:
        from utils import vis
        from datasets import ds_utils
        if debug and is_expressions:
            for i in range(10):
                print(c[:, 0][i].item(), c[pos_id, 0][i].item(),
                      c[neg_id, 0][i].item())
        # ids, vids = c[0], c[1]
        # print(vids[:5])
        # print(vids[pos_id][:5])
        # print(vids[neg_id][:5])
        nimgs = 20
        losses = to_numpy(
            torch.clamp(dpos - dneg + margin, min=0.0, max=margin * 2.0))
        # print("Acc: ", 1 - sum(dpos[:nimgs] >= dneg[:nimgs]).item()/float(len(dpos[:nimgs])))
        # print("L  : ", losses.mean())
        images_ref = ds_utils.denormalized(images[:nimgs].clone())
        images_pos = ds_utils.denormalized(images[pos_id][:nimgs].clone())
        images_neg = ds_utils.denormalized(images[neg_id][:nimgs].clone())
        colors = [(0, 1, 0) if c else (1, 0, 0) for c in dpos < dneg]
        f = 0.75
        images_ref = vis.add_error_to_images(vis.add_cirle_to_images(
            images_ref, colors),
                                             losses,
                                             size=1.0,
                                             vmin=0,
                                             vmax=0.5,
                                             thickness=2,
                                             format_string='{:.2f}')
        images_pos = vis.add_error_to_images(images_pos,
                                             to_numpy(dpos),
                                             size=1.0,
                                             vmin=0.5,
                                             vmax=1.0,
                                             thickness=2,
                                             format_string='{:.2f}')
        images_neg = vis.add_error_to_images(images_neg,
                                             to_numpy(dneg),
                                             size=1.0,
                                             vmin=0.5,
                                             vmax=1.0,
                                             thickness=2,
                                             format_string='{:.2f}')
        if is_expressions:
            emotions = to_numpy(c[:, 0]).astype(int)
            images_ref = vis.add_emotion_to_images(images_ref, emotions)
            images_pos = vis.add_emotion_to_images(images_pos,
                                                   emotions[pos_id])
            images_neg = vis.add_emotion_to_images(images_neg,
                                                   emotions[neg_id])
        elif feature_name == 'id':
            ids = to_numpy(c).astype(int)
            images_ref = vis.add_id_to_images(images_ref, ids, loc='tr')
            images_pos = vis.add_id_to_images(images_pos,
                                              ids[pos_id],
                                              loc='tr')
            images_neg = vis.add_id_to_images(images_neg,
                                              ids[neg_id],
                                              loc='tr')

        img_ref = vis.make_grid(images_ref, nCols=nimgs, padsize=1, fx=f, fy=f)
        img_pos = vis.make_grid(images_pos, nCols=nimgs, padsize=1, fx=f, fy=f)
        img_neg = vis.make_grid(images_neg, nCols=nimgs, padsize=1, fx=f, fy=f)
        title = 'triplets'
        if feature_name is not None:
            title += " " + feature_name
        if wnd_title is not None:
            title += " " + wnd_title
        vis.vis_square([img_ref, img_pos, img_neg],
                       nCols=1,
                       padsize=1,
                       normalize=False,
                       wait=10,
                       title=title)

        # plt.plot(to_numpy((X[:nimgs]-P[:nimgs]).abs()), 'b')
        # plt.plot(to_numpy((X[:nimgs]-N[:nimgs]).abs()), 'r')
        # plt.show()
    if return_acc:
        return loss, sum(dpos >= dneg).item() / float(len(dpos))
    else:
        return loss
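
The clamp(dpos - dneg + margin) expression above is the standard margin-based triplet loss, here with an extra cap at 2 * margin before averaging. PyTorch provides the uncapped version as torch.nn.functional.triplet_margin_loss; a small side-by-side sketch on random embeddings (the helper is a rewrite for illustration, not the project's calc_triplet_loss):

import torch
import torch.nn.functional as F

def capped_triplet_loss(anchor, pos, neg, margin=0.2, eps=1e-8):
    # hinge on dpos - dneg + margin, additionally capped at 2 * margin
    dpos = torch.sqrt(torch.sum((anchor - pos) ** 2, dim=1) + eps)
    dneg = torch.sqrt(torch.sum((anchor - neg) ** 2, dim=1) + eps)
    return torch.clamp(dpos - dneg + margin, min=0.0, max=2.0 * margin).mean()

# toy usage
a = torch.randn(16, 64)
p = a + 0.05 * torch.randn(16, 64)
n = torch.randn(16, 64)
print(capped_triplet_loss(a, p, n).item())
print(F.triplet_margin_loss(a, p, n, margin=0.2).item())   # uncapped PyTorch reference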