コード例 #1
0
    def heatmaps_to_landmarks(self, hms):
        """Extract (x, y) landmark coordinates from predicted heatmaps.

        Two encodings are handled:
        - more than 3 channels: one heatmap per landmark; each landmark is
          placed at the argmax of its heatmap;
        - exactly 3 channels: color-coded heatmaps decoded by matching each
          channel against the reference colors in ``nn.lmcolors``.

        Returns an array of shape (num_images, num_landmarks, 2), rescaled
        from heatmap resolution to the network input resolution.
        """
        lms = np.zeros((len(hms), self.num_landmarks, 2), dtype=int)
        if hms.shape[1] > 3:
            # print(hms.max())
            for i in range(len(hms)):
                heatmaps = to_numpy(hms[i])
                for l in range(len(heatmaps)):
                    hm = heatmaps[self.landmark_id_to_heatmap_id(l)]
                    # unravel_index yields (row, col); reversed to (x, y)
                    lms[i, l, :] = np.unravel_index(np.argmax(hm, axis=None),
                                                    hm.shape)[::-1]
        elif hms.shape[1] == 3:
            hms = to_numpy(hms)

            def get_score_plane(h, lm_id, cn):
                # Zero out pixels whose channel value is far from the
                # landmark's reference color.
                # NOTE(review): this mutates h[cn] in place, so later
                # landmarks see already-masked planes -- confirm intended.
                v = nn.lmcolors[lm_id, cn]
                hcn = h[cn]
                hcn[hcn < v - 2] = 0
                hcn[hcn > v + 5] = 0
                return hcn

            # Scale to 0..255 to match the color table value range.
            hms *= 255
            for i in range(len(hms)):
                hm = hms[i]
                for l in landmarks.config.LANDMARKS:
                    # Product of the three channel scores peaks where all
                    # channels match the landmark color simultaneously.
                    lm_score_map = get_score_plane(hm, l, 0) * get_score_plane(
                        hm, l, 1) * get_score_plane(hm, l, 2)
                    lms[i, l, :] = np.unravel_index(
                        np.argmax(lm_score_map, axis=None),
                        lm_score_map.shape)[::-1]
        # Rescale from heatmap resolution to network input resolution.
        lm_scale = lmcfg.HEATMAP_SIZE / self.input_size
        return lms / lm_scale
コード例 #2
0
def show_images_in_batch(images, labels):
    """Display a batch of images in a grid window.

    If ``labels`` is given, each image is framed with its label's color
    (EmotiW palette) before display.  Side effects only (opens a window).
    """
    from datasets.emotiw import EmotiW
    disp_imgs = to_numpy(images)
    if labels is not None:
        labels = to_numpy(labels)
        # Fix: frame the converted numpy images (disp_imgs), not the raw
        # input tensors -- previously the to_numpy() result was discarded.
        disp_imgs = add_frames_to_images(disp_imgs, labels, label_colors=EmotiW.colors_rgb)
    vis_square(disp_imgs, fx=0.4, fy=0.4)
コード例 #3
0
def calc_landmark_recon_error(X, X_recon, lms, return_maps=False, reduction="mean"):
    """Mean absolute reconstruction error restricted to landmark regions.

    A binary mask containing a disc (radius = 5% of the image size) around
    every landmark limits the per-pixel L1 error (scaled to 0..255) to the
    facial regions of interest.

    reduction='mean' returns one scalar over the whole batch; 'none'
    returns one value per image.  With return_maps=True the masked error
    maps are returned alongside the error.
    """
    assert len(X.shape) == 4
    assert reduction in ["mean", "none"]
    X = to_numpy(X)
    X_recon = to_numpy(X_recon)

    n_imgs, _, h, w = X.shape
    mask = np.zeros((n_imgs, h, w), dtype=np.float32)
    radius = int(X.shape[-1] * 0.05)
    for img_id in range(n_imgs):
        for lm in lms[img_id]:
            center = (int(lm[0]), int(lm[1]))
            cv2.circle(mask[img_id], center, radius=radius, color=1, thickness=-1)

    err_maps = np.abs(X - X_recon).mean(axis=1) * 255.0
    masked_err_maps = err_maps * mask

    debug = False
    if debug:
        mask3 = mask[:, np.newaxis, :, :].repeat(3, axis=1)
        fig, ax = plt.subplots(1, 3)
        ax[0].imshow(vis.to_disp_image((X * mask3)[0], denorm=True))
        ax[1].imshow(vis.to_disp_image((X_recon * mask3)[0], denorm=True))
        ax[2].imshow(masked_err_maps[0])
        plt.show()

    if reduction == "mean":
        err = masked_err_maps.sum() / (mask.sum() * 3)
    else:
        # Per-image mean over the masked pixels (3 channels).
        pixel_sums = masked_err_maps.reshape(n_imgs, -1).sum(axis=1)
        mask_areas = mask.reshape(n_imgs, -1).sum(axis=1) * 3
        err = pixel_sums / mask_areas

    if return_maps:
        return err, masked_err_maps
    return err
コード例 #4
0
 def res_vec(self, exclude_fid):
     """Concatenate all parallel factor vectors except ``exclude_fid``.

     Returns the concatenation as a numpy array, or None when parallel
     disentanglement is disabled or the factors are not yet available.
     """
     if cfg.WITH_PARALLEL_DISENTANGLEMENT:
         try:
             h = torch.cat([self.f_parallel[i] for i in range(len(self.f_parallel)) if i != exclude_fid], dim=1)
             return to_numpy(h)
         except Exception:
             # Fix: narrowed from a bare ``except:`` which also swallowed
             # KeyboardInterrupt/SystemExit. f_parallel may be unset/None
             # before the first forward pass.
             return None
     return None
コード例 #5
0
def add_error_to_images(images, errors, loc='bl', size=0.65, vmin=0., vmax=30.0, thickness=1, format_string='{:.1f}'):
    """Overlay each image's error value as text, color-coded by magnitude.

    Errors are mapped through the jet colormap between ``vmin`` and
    ``vmax``; the text is placed at location code ``loc`` (e.g. 'bl').
    Returns the annotated display images.
    """
    annotated = _to_disp_images(images)
    err_colors = color_map(to_numpy(errors), cmap=plt.cm.jet, vmin=vmin, vmax=vmax)
    for img, value, col in zip(annotated, errors, err_colors):
        anchor = get_pos_in_image(loc, size, img.shape)
        text = format_string.format(value)
        cv2.putText(img, text, anchor, cv2.FONT_HERSHEY_DUPLEX,
                    size, col, thickness, cv2.LINE_AA)
    return annotated
コード例 #6
0
 def detect_landmarks(self, X):
     """Run the full pipeline on a batch of images.

     Returns (reconstructions, landmark predictions, smoothed heatmaps).
     """
     X_recon = self.forward(X)
     heatmaps = self.LMH(self.P)
     heatmaps = landmarks.lmutils.decode_heatmap_blob(heatmaps)
     heatmaps = landmarks.lmutils.smooth_heatmaps(heatmaps)
     lm_preds = to_numpy(self.heatmaps_to_landmarks(heatmaps))
     return X_recon, lm_preds, heatmaps
コード例 #7
0
 def emotions_pred(self):
     """Predicted emotion class per sample, or None when no probabilities
     have been computed yet."""
     probs = getattr(self, 'emotion_probs', None)
     if probs is None:
         return None
     return np.argmax(to_numpy(probs), axis=1)
コード例 #8
0
def add_error_to_images(
    images,
    errors,
    loc="bl",
    size=0.65,
    vmin=0.0,
    vmax=30.0,
    thickness=1,
    format_string="{:.1f}",
    colors=None,
):
    """Draw each image's error value onto it as colored text.

    When ``colors`` is None, errors are mapped through the jet colormap
    between ``vmin`` and ``vmax`` (scaled to 0-255 for uint8 images);
    otherwise the given colors are used directly.
    """
    annotated = to_disp_images(images)
    if colors is None:
        colors = color_map(to_numpy(errors), cmap=plt.cm.jet, vmin=vmin, vmax=vmax)
        # uint8 images need 0-255 colors rather than 0-1 floats.
        if images[0].dtype == np.uint8:
            colors *= 255
    for img, value, col in zip(annotated, errors, colors):
        anchor = get_pos_in_image(loc, size, img.shape)
        text = format_string.format(value)
        cv2.putText(img, text, anchor, cv2.FONT_HERSHEY_DUPLEX,
                    size, col, thickness, cv2.LINE_AA)
    return annotated
コード例 #9
0
ファイル: lmutils.py プロジェクト: PaperID8601/3FR
def visualize_random_faces(net, nimgs=10, wait=10, f=1.0):
    """Sample random latent codes and show the generated faces in a window.

    Side effects only: opens an OpenCV window ("random faces") and waits
    ``wait`` ms.  Requires a CUDA device (latents are allocated on GPU).
    ``f`` scales the heatmap display resolution.
    """
    z_random = torch.randn(nimgs, net.z_dim).cuda()
    with torch.no_grad():
        X_gen_vis = net.P(z_random)[:, :3]
        X_lm_hm = net.LMH(net.P)
    # Collapse per-landmark heatmaps to one channel for display.
    pred_heatmaps = to_single_channel_heatmap(to_numpy(X_lm_hm[:nimgs]))
    pred_heatmaps = [
        cv2.resize(im, None, fx=f, fy=f, interpolation=cv2.INTER_NEAREST)
        for im in pred_heatmaps
    ]
    # Denormalize to uint8 RGB (NHWC) for OpenCV display.
    disp_X_gen = to_numpy(ds_utils.denormalized(X_gen_vis).permute(0, 2, 3, 1))
    disp_X_gen = (disp_X_gen * 255).astype(np.uint8)
    # disp_X_gen = [vis.overlay_heatmap(disp_X_gen[i], pred_heatmaps[i]) for i in range(len(pred_heatmaps))]

    grid_img = vis.make_grid(disp_X_gen, nCols=nimgs // 2)
    cv2.imshow("random faces", cv2.cvtColor(grid_img, cv2.COLOR_RGB2BGR))
    cv2.waitKey(wait)
コード例 #10
0
def smooth_heatmaps(hms):
    """Box-blur every heatmap in a (N, L, H, W) stack with a 9x9 kernel.

    NOTE(review): operates on the array returned by to_numpy() in place --
    if that shares memory with the input, the caller's data is modified.
    """
    assert len(hms.shape) == 4
    hms = to_numpy(hms)
    n_imgs, n_lms = hms.shape[:2]
    for img_idx in range(n_imgs):
        for lm_idx in range(n_lms):
            hms[img_idx, lm_idx] = cv2.blur(
                hms[img_idx, lm_idx], (9, 9), borderType=cv2.BORDER_CONSTANT)
    return hms
コード例 #11
0
ファイル: lmutils.py プロジェクト: PaperID8601/3FR
def decode_heatmap_blob(hms):
    """Decode an encoded heatmap blob into one heatmap per 68 landmarks.

    Heatmaps that already have one channel per landmark are returned
    unchanged; otherwise each image's blob is decoded individually.
    """
    assert len(hms.shape) == 4
    if hms.shape[1] == lmcfg.NUM_LANDMARK_HEATMAPS:  # no decoding necessary
        return hms
    assert hms.shape[1] == len(hm_code_mat)
    n_imgs, _, h, w = hms.shape
    decoded = np.zeros((n_imgs, 68, h, w), dtype=np.float32)
    for img_idx in range(n_imgs):
        decoded[img_idx] = decode_heatmaps(to_numpy(hms[img_idx]))[0]
    return decoded
コード例 #12
0
ファイル: lmutils.py プロジェクト: PaperID8601/3FR
 def add_confs(disp_X_recon, lmids, loc):
     """Annotate images with the mean landmark confidence over ``lmids``.

     NOTE(review): relies on ``lm_confs`` from the enclosing scope
     (apparently per-image, per-landmark confidences) -- confirm against
     the surrounding function.
     """
     means = lm_confs[:, lmids].mean(axis=1)
     # Color by (1 - confidence) so low confidence maps to hot colors.
     colors = vis.color_map(to_numpy(1 - means),
                            cmap=plt.cm.jet,
                            vmin=0.0,
                            vmax=0.4)
     return vis.add_error_to_images(disp_X_recon,
                                    means,
                                    loc=loc,
                                    format_string='{:>4.2f}',
                                    colors=colors)
コード例 #13
0
    def __forward_disentanglement_parallel(self, z, Y=None):
        """Evaluation-only forward pass through the disentanglement stack.

        Encodes latent ``z`` into parallel factors, reconstructs ``z``,
        classifies emotion from the expression factor, and collects error
        and accuracy statistics.  Runs entirely under ``torch.no_grad``.

        Returns (z_recon, iter_stats, None).
        """
        iter_stats = {}

        with torch.no_grad():
            self.f_parallel = self.E(z)
            z_recon = self.G(*self.f_parallel)

            # Factor 3 holds expression labels; Y may be None (TypeError on
            # indexing), in which case no label-based stats are computed.
            ft_id = 3
            try:
                y = Y[ft_id]
            except TypeError:
                y = None

            y_f = self.f_parallel[3]

            try:
                # Pose error in degrees between predicted pose factor and GT.
                y_p = self.f_parallel[0]
                def calc_err(outputs, target):
                    return np.abs(np.rad2deg(F.l1_loss(outputs, target, reduction='none').detach().cpu().numpy()))
                iter_stats['err_pose'] = calc_err(y_p, Y[0])
            except TypeError:
                pass


            # Emotion classification from the expression factor.
            clprobs = self.znet(y_f)
            self.emotion_probs = clprobs

            if y is not None:
                emotion_labels = y[:, 0].long()
                loss_cls = self.cross_entropy_loss(clprobs, emotion_labels)
                acc_cls = calc_acc(clprobs, emotion_labels)
                iter_stats['loss_cls'] = loss_cls.item()
                iter_stats['acc_cls'] = acc_cls
                iter_stats['emotion_probs'] = to_numpy(clprobs)
                iter_stats['emotion_labels'] = to_numpy(emotion_labels)
                # Cycle consistency: decode to image, re-encode, and compare
                # the concatenated factor vectors (per-sample L1).
                f_parallel_recon = self.E(self.Q(self.P(z_recon)[:,:3]))
                l1_err = torch.abs(torch.cat(f_parallel_recon, dim=1) - torch.cat(self.f_parallel, dim=1)).mean(dim=1)
                iter_stats['l1_dis_cycle'] = to_numpy(l1_err)

        return z_recon, iter_stats, None
コード例 #14
0
    def heatmaps_to_landmarks(self, hms):
        """Convert predicted heatmaps to (x, y) landmark coordinates.

        Handles three heatmap encodings:
        - more than 3 channels with a known landmark count (19/68/98):
          argmax per landmark heatmap;
        - more than 3 channels otherwise: an encoded blob decoded via
          ``landmarks.lmutils.decode_heatmaps``;
        - exactly 3 channels: color-coded heatmaps decoded by matching
          against ``utils.nn.lmcolors``.

        Returns an array of shape (num_images, NUM_LANDMARKS, 2), rescaled
        from heatmap resolution to the network input size.
        """
        lms = np.zeros((len(hms), lmcfg.NUM_LANDMARKS, 2), dtype=int)
        if hms.shape[1] > 3:
            # print(hms.max())
            for i in range(len(hms)):
                if hms.shape[1] not in [19, 68, 98]:
                    # Encoded heatmap blob: decode to coordinates directly.
                    _, lm_coords = landmarks.lmutils.decode_heatmaps(
                        to_numpy(hms[i]))
                    lms[i] = lm_coords
                else:
                    heatmaps = to_numpy(hms[i])
                    for l in range(len(heatmaps)):
                        hm = heatmaps[lmcfg.LANDMARK_ID_TO_HEATMAP_ID[l]]
                        # hm = cv2.blur(hm, (9,9))
                        # hm = cv2.medianBlur(hm, 9,9)
                        # (row, col) argmax reversed to (x, y)
                        lms[i,
                            l, :] = np.unravel_index(np.argmax(hm, axis=None),
                                                     hm.shape)[::-1]
        elif hms.shape[1] == 3:
            hms = to_numpy(hms)

            def get_score_plane(h, lm_id, cn):
                # Zero out pixels far from the landmark's reference color.
                # NOTE(review): mutates h[cn] in place; later landmarks see
                # already-masked planes -- confirm intended.
                v = utils.nn.lmcolors[lm_id, cn]
                hcn = h[cn]
                hcn[hcn < v - 2] = 0
                hcn[hcn > v + 5] = 0
                return hcn

            # Scale to 0..255 to match the color table value range.
            hms *= 255
            for i in range(len(hms)):
                hm = hms[i]
                for l in landmarks.config.LANDMARKS:
                    # Product of the channel scores peaks where all three
                    # channels match the landmark color simultaneously.
                    lm_score_map = get_score_plane(hm, l, 0) * get_score_plane(
                        hm, l, 1) * get_score_plane(hm, l, 2)
                    lms[i, l, :] = np.unravel_index(
                        np.argmax(lm_score_map, axis=None),
                        lm_score_map.shape)[::-1]
        # Rescale from heatmap resolution to network input resolution.
        lm_scale = lmcfg.HEATMAP_SIZE / cfg.INPUT_SIZE
        return lms / lm_scale
コード例 #15
0
def draw_z(z_vecs):
    """Render latent vectors as color-mapped vertical strips in one grid.

    Each vector becomes a (len, 10, 3) strip colored via the jet-style
    color_map over [-1, 1]; the strips are tiled and transposed so the
    vectors run horizontally.
    """
    fy = 1
    width = 10
    strips = []
    for lvl, ft in enumerate(to_numpy(z_vecs)):
        # NOTE: vmin is computed but not used by the color_map call below
        # (kept to match the original behavior exactly).
        vmin = 0 if lvl == 0 else -1

        height = int(fy * len(ft))
        canvas = np.zeros((height, width, 3))
        resized = cv2.resize(ft.reshape(-1, 1), dsize=(width, height),
                             interpolation=cv2.INTER_NEAREST)
        canvas[:height, :] = color_map(resized, vmin=-1.0, vmax=1.0)
        strips.append(canvas)
    grid = make_grid(strips, nCols=len(z_vecs), padsize=1, padval=0)
    return grid.transpose((1, 0, 2))
コード例 #16
0
def loss_struct(X,
                X_recon,
                torch_ssim,
                calc_error_maps=False,
                reduction="mean"):
    """Structural (1 - SSIM) loss between images and reconstructions.

    SSIM is evaluated per image pair; the per-image errors are reduced by
    ``__reduce``.  When ``calc_error_maps`` is set, the contrast-structure
    error maps from the SSIM module are stacked and returned as well.
    """
    n_imgs = len(X)
    errs = torch.zeros(n_imgs, requires_grad=True).cuda()
    cs_error_maps = []
    for idx in range(n_imgs):
        errs[idx] = 1.0 - torch_ssim(X[idx].unsqueeze(0), X_recon[idx].unsqueeze(0))
        if calc_error_maps:
            cs_error_maps.append(1.0 - to_numpy(torch_ssim.cs_map))
    loss = __reduce(errs, reduction)
    maps = np.vstack(cs_error_maps) if calc_error_maps else None
    return loss, maps
コード例 #17
0
def visualize_images(X,
                     X_lm_hm,
                     landmarks=None,
                     show_recon=True,
                     show_landmarks=True,
                     show_heatmaps=False,
                     draw_wireframe=False,
                     smoothing_level=2,
                     heatmap_opacity=0.8,
                     f=1):
    """Build display images from inputs, heatmaps and landmark predictions.

    Optionally overlays smoothed single-channel heatmaps and landmark
    markers; with show_recon=False, heatmaps are drawn on black at full
    opacity.  Returns the list of display images.
    """
    if show_recon:
        disp_X = vis.to_disp_images(X, denorm=True)
    else:
        # Black background so the heatmaps stand on their own.
        disp_X = vis.to_disp_images(torch.zeros_like(X), denorm=False)
        heatmap_opacity = 1

    if X_lm_hm is not None:
        # Blur once or twice depending on the requested smoothing level.
        for _ in range(min(smoothing_level, 2)):
            X_lm_hm = smooth_heatmaps(X_lm_hm)

    if show_heatmaps:
        single_channel = to_single_channel_heatmap(to_numpy(X_lm_hm))
        single_channel = [
            cv2.resize(im, None, fx=f, fy=f, interpolation=cv2.INTER_CUBIC)
            for im in single_channel
        ]
        disp_X = [
            vis.overlay_heatmap(img, hm, heatmap_opacity)
            for img, hm in zip(disp_X, single_channel)
        ]

    if show_landmarks and landmarks is not None:
        disp_X = vis.add_landmarks_to_images(disp_X,
                                             landmarks,
                                             color=(0, 255, 255),
                                             draw_wireframe=draw_wireframe)

    return disp_X
コード例 #18
0
def heatmaps_to_landmarks(hms, target_size):
    """Convert a (N, L, H, W) heatmap stack into landmark coordinates.

    Each landmark is located at the argmax of its heatmap; the (row, col)
    index is reversed to (x, y) and rescaled from heatmap resolution to
    ``target_size``.

    Fix: the original rebuilt an identity lookup dict ({0: 0, 1: 1, ...})
    for every landmark of every image just to map landmark id to heatmap
    id; since the mapping is the identity, the heatmap is now indexed
    directly, removing the per-iteration allocation.
    """
    assert len(hms.shape) == 4
    num_images = hms.shape[0]
    num_landmarks = hms.shape[1]
    heatmap_size = hms.shape[-1]
    lms = np.zeros((num_images, num_landmarks, 2), dtype=int)
    if hms.shape[1] > 3:
        for i in range(num_images):
            heatmaps = to_numpy(hms[i])
            for l in range(len(heatmaps)):
                # heatmap id == landmark id (identity mapping)
                hm = heatmaps[l]
                # (row, col) argmax reversed to (x, y)
                lms[i, l, :] = np.unravel_index(np.argmax(hm, axis=None), hm.shape)[
                    ::-1
                ]
    lm_scale = heatmap_size / target_size
    return lms / lm_scale
コード例 #19
0
 def poses_pred(self):
     """Return the predicted pose factor (first parallel feature) as a
     numpy array, or None when it is not available."""
     try:
         return to_numpy(self.f_parallel[0])
     except (AttributeError, IndexError, TypeError):
         # Fix: narrowed from a bare ``except:`` which also swallowed
         # KeyboardInterrupt/SystemExit and hid unrelated bugs.
         # f_parallel is unset/None before the first forward pass.
         return None
コード例 #20
0
 def f_vec(self, i):
     """Return parallel factor ``i`` as a numpy array, or None when
     disentanglement is disabled or the factor is not available."""
     if cfg.WITH_PARALLEL_DISENTANGLEMENT:
         try:
             return to_numpy(self.f_parallel[i])
         except (AttributeError, IndexError, TypeError):
             # Fix: narrowed from a bare ``except:`` which also swallowed
             # KeyboardInterrupt/SystemExit. f_parallel may be unset/None
             # before the first forward pass.
             return None
     return None
コード例 #21
0
    def _run_batch(self, batch, eval=False, ds=None):
        """Process one mini-batch: encode, reconstruct, predict landmark
        heatmaps, optionally train the landmark head, and collect stats.

        ``eval`` disables landmark-head training; expensive landmark
        decoding and visualization only run on eval or printout iterations.
        Returns a dict with the batch, reconstructions, heatmaps and the
        (possibly None) max-location landmark predictions.
        """
        time_dataloading = time.time() - self.iter_starttime
        time_proc_start = time.time()
        iter_stats = {"time_dataloading": time_dataloading}

        self.saae.zero_grad()
        self.saae.eval()

        # Prefer the (augmented) target images as encoder input if present.
        input_images = (batch.target_images
                        if batch.target_images is not None else batch.images)

        with torch.set_grad_enabled(self.args.train_encoder):
            z_sample = self.saae.Q(input_images)

        iter_stats.update({"z_recon_mean": z_sample.mean().item()})

        #######################
        # Reconstruction phase
        #######################
        with torch.set_grad_enabled(self.args.train_encoder and not eval):
            X_recon = self.saae.P(z_sample)

        # calculate reconstruction error for debugging and reporting
        with torch.no_grad():
            iter_stats["loss_recon"] = aae_training.loss_recon(
                batch.images, X_recon)

        #######################
        # Landmark predictions
        #######################
        train_lmhead = not eval
        lm_preds_max = None
        with torch.set_grad_enabled(train_lmhead):
            self.saae.LMH.train(train_lmhead)
            X_lm_hm = self.saae.LMH(self.saae.P)
            # NOTE(review): loss_lms is only bound when lm_heatmaps is
            # present, but the backward() below assumes it exists when
            # training -- confirm callers always supply heatmaps then.
            if batch.lm_heatmaps is not None:
                loss_lms = F.mse_loss(batch.lm_heatmaps, X_lm_hm) * 100 * 3
                iter_stats.update({"loss_lms": loss_lms.item()})

            if eval or self._is_printout_iter(eval):
                # expensive, so only calculate when every N iterations
                # X_lm_hm = lmutils.decode_heatmap_blob(X_lm_hm)
                X_lm_hm = lmutils.smooth_heatmaps(X_lm_hm)
                lm_preds_max = self.saae.heatmaps_to_landmarks(X_lm_hm)

            if eval or self._is_printout_iter(eval):
                # Normalized mean error against ground-truth landmarks.
                lm_gt = to_numpy(batch.landmarks)
                nmes = lmutils.calc_landmark_nme(
                    lm_gt,
                    lm_preds_max,
                    ocular_norm=self.args.ocular_norm,
                    image_size=self.args.input_size,
                )
                # nccs = lmutils.calc_landmark_ncc(batch.images, X_recon, lm_gt)
                iter_stats.update({"nmes": nmes})

        if train_lmhead:
            # if self.args.train_encoder:
            #     loss_lms = loss_lms * 80.0
            loss_lms.backward()
            self.optimizer_lm_head.step()
            if self.args.train_encoder:
                self.optimizer_E.step()
                # self.optimizer_G.step()

        # statistics
        iter_stats.update({
            "epoch": self.epoch,
            "timestamp": time.time(),
            "iter_time": time.time() - self.iter_starttime,
            "time_processing": time.time() - time_proc_start,
            "iter": self.iter_in_epoch,
            "total_iter": self.total_iter,
            "batch_size": len(batch),
        })
        self.iter_starttime = time.time()

        self.epoch_stats.append(iter_stats)

        batch_samples = {
            "batch": batch,
            "X_recon": X_recon,
            "X_lm_hm": X_lm_hm,
            "lm_preds_max": lm_preds_max,
            "ds": ds,
        }

        # print stats every N mini-batches
        if self._is_printout_iter(eval):
            self._print_iter_stats(
                self.epoch_stats[-self._print_interval(eval):])

            out_dir = os.path.join(
                cfg.REPORT_DIR,
                "landmark_predictions",
                self.session_name,
                str(self.epoch + 1),
            )
            io.makedirs(out_dir)

            # Visual debugging dump for this iteration.
            lmvis.visualize_batch(
                batch.images,
                batch.landmarks,
                X_recon,
                X_lm_hm,
                lm_preds_max,
                self.all_landmarks,
                lm_heatmaps=batch.lm_heatmaps,
                target_images=batch.target_images,
                ds=ds,
                ocular_norm=self.args.ocular_norm,
                clean=False,
                overlay_heatmaps_input=False,
                overlay_heatmaps_recon=False,
                f=1.0,
                wait=self.wait,
                skeleton=self.skeleton,
            )
        return batch_samples
コード例 #22
0
 def z_vecs_pre(self):
     """Pre-activation latent vector as a single-element list of arrays."""
     z = to_numpy(self.z_pre)
     return [z]
コード例 #23
0
    def __train_disenglement_parallel(self, z, Y=None, train=True):
        """One training step of the parallel disentanglement stack.

        Encodes latent ``z`` into four factors (pose, id, shape,
        expression), reconstructs ``z``, applies per-factor supervision,
        an optional feature-shuffling cycle loss, an optional GAN term,
        and (periodically) shows a debug visualization grid.

        Returns (z_recon, iter_stats, loss_G[0]).
        """
        iter_stats = {}

        self.E.train(train)
        self.G.train(train)

        self.optimizer_E.zero_grad()
        self.optimizer_G.zero_grad()

        #
        # Autoencoding phase
        #

        fts = self.E(z)
        fp, fi, fs, fe = fts

        z_recon = self.G(fp, fi, fs, fe)

        loss_z_recon = F.l1_loss(z, z_recon) * cfg.W_Z_RECON
        if not cfg.WITH_Z_RECON_LOSS:
            loss_z_recon *= 0

        #
        # Info min/max phase
        #

        loss_I = loss_z_recon
        loss_G = torch.zeros(1, requires_grad=True).cuda()

        def calc_err(outputs, target):
            # L1 pose error converted to degrees, averaged over the batch.
            return np.abs(np.rad2deg(F.l1_loss(outputs, target, reduction='none').detach().cpu().numpy().mean(axis=0)))

        def cosine_loss(outputs, targets):
            return (1 - F.cosine_similarity(outputs, targets, dim=1)).mean()

        # Pick which factor to supervise this iteration, depending on
        # which labels the current dataset provides.
        if Y[3] is not None and Y[3].sum() > 0:  # Has expression -> AffectNet
            available_factors = [3,3,3]
            if cfg.WITH_POSE:
                available_factors = [0] + available_factors
        elif Y[2][1] is not None:  # has vids -> VoxCeleb
            available_factors = [2]
        elif Y[1] is not None:  # Has identities
            available_factors = [1,1,1]
            if cfg.WITH_POSE:
                available_factors = [0] + available_factors
        elif Y[0] is not None: # Any dataset with pose
            available_factors = [0,1,3]

        lvl = available_factors[self.iter % len(available_factors)]

        name = self.factors[lvl]
        try:
            y = Y[lvl]
        except TypeError:
            y = None

        # if y is not None and name != 'shape':
        def calc_feature_loss(name, y_f, y, show_triplets=False, wnd_title=None):
            # Triplet loss for categorical factors, MSE (in radians) for pose.
            if name == 'id' or name == 'shape' or name == 'expression':
                display_images = None
                if show_triplets:
                    display_images = self.images
                loss_I_f, err_f = calc_triplet_loss(y_f, y, return_acc=True, images=display_images, feature_name=name,
                                                    wnd_title=wnd_title)
                if name == 'expression':
                    loss_I_f *= 2.0
            elif name == 'pose':
                # loss_I_f, err_f = F.l1_loss(y_f, y), calc_err(y_f, y)
                loss_I_f, err_f = F.mse_loss(y_f, y)*1, calc_err(y_f, y)
                # loss_I_f, err_f = cosine_loss(y_f, y), calc_err(y_f, y)
            else:
                raise ValueError("Unknown feature name!")
            return loss_I_f, err_f


        if y is not None and cfg.WITH_FEATURE_LOSS:

            show_triplets = (self.iter + 1) % self.print_interval  == 0

            y_f = fts[lvl]
            loss_I_f, err_f = calc_feature_loss(name, y_f, y, show_triplets=show_triplets)

            loss_I += cfg.W_FEAT * loss_I_f

            iter_stats[name+'_loss_f'] = loss_I_f.item()
            iter_stats[name+'_err_f'] = np.mean(err_f)

            # train expression classifier
            if name == 'expression':
                self.znet.zero_grad()
                emotion_labels = y[:,0].long()
                clprobs = self.znet(y_f.detach())  # train only znet
                # clprobs = self.znet(y_f)  # train enoder and znet
                # loss_cls = self.cross_entropy_loss(clprobs, emotion_labels)
                loss_cls = self.weighted_CE_loss(clprobs, emotion_labels)

                acc_cls = calc_acc(clprobs, emotion_labels)
                if train:
                    loss_cls.backward(retain_graph=False)
                self.optimizer_znet.step()
                iter_stats['loss_cls'] = loss_cls.item()
                iter_stats['acc_cls'] = acc_cls
                iter_stats['expression_y_probs'] = to_numpy(clprobs)
                iter_stats['expression_y'] = to_numpy(y)


        # cycle loss
        # other_levels = [0,1,2,3]
        # other_levels.remove(lvl)
        # shuffle_lvl = np.random.permutation(other_levels)[0]
        shuffle_lvl = lvl
        # print("shuffling level {}...".format(shuffle_lvl))
        if cfg.WITH_DISENT_CYCLE_LOSS:
            # z_random = torch.rand_like(z).cuda()
            # fts_random = self.E(z_random)

            # create modified feature vectors
            fts[0] = fts[0].detach()
            fts[1] = fts[1].detach()
            fts[2] = fts[2].detach()
            fts[3] = fts[3].detach()
            fts_mod = fts.copy()
            # Shuffle one factor across the batch to transfer features.
            shuffled_ids = torch.randperm(len(fts[shuffle_lvl]))
            y_mod = None
            if y is not None:
                if name == 'shape':
                    y_mod = [y[0][shuffled_ids], y[1][shuffled_ids]]
                else:
                    y_mod = y[shuffled_ids]

            fts_mod[shuffle_lvl] = fts[shuffle_lvl][shuffled_ids]

            # predict full cycle
            z_random_mod = self.G(*fts_mod)
            X_random_mod = self.P(z_random_mod)[:,:3]
            z_random_mod_recon = self.Q(X_random_mod)
            fts2 = self.E(z_random_mod_recon)

            # recon error in unmodified part
            # h = torch.cat([fts_mod[i] for i in range(len(fts_mod)) if i != lvl], dim=1)
            # h2 = torch.cat([fts2[i] for i in range(len(fts2)) if i != lvl], dim=1)
            # l1_err_h = torch.abs(h - h2).mean(dim=1)
            # l1_err_h = torch.abs(torch.cat(fts_mod, dim=1) - torch.cat(fts2, dim=1)).mean(dim=1)

            # recon error in modified part
            # l1_err_f = np.rad2deg(to_numpy(torch.abs(fts_mod[lvl] - fts2[lvl]).mean(dim=1)))

            # recon error in entire vector
            l1_err = torch.abs(torch.cat(fts_mod, dim=1)[:,3:] - torch.cat(fts2, dim=1)[:,3:]).mean(dim=1)
            loss_dis_cycle = F.l1_loss(torch.cat(fts_mod, dim=1)[:,3:], torch.cat(fts2, dim=1)[:,3:]) * cfg.W_CYCLE
            iter_stats['loss_dis_cycle'] = loss_dis_cycle.item()

            loss_I += loss_dis_cycle

            # cycle augmentation loss
            if cfg.WITH_AUGMENTATION_LOSS and y_mod is not None:
                y_f_2 = fts2[lvl]
                loss_I_f_2, err_f_2 = calc_feature_loss(name, y_f_2, y_mod, show_triplets=show_triplets, wnd_title='aug')
                loss_I += loss_I_f_2 * cfg.W_AUG
                iter_stats[name+'_loss_f_2'] = loss_I_f_2.item()
                iter_stats[name+'_err_f_2'] = np.mean(err_f_2)

            #
            # Adversarial loss of modified generations
            #

            GAN = False
            if GAN and train:
                eps = 0.00001

                # #######################
                # # GAN discriminator phase
                # #######################
                update_D = False
                if update_D:
                    self.D.zero_grad()
                    err_real = self.D(self.images)
                    err_fake = self.D(X_random_mod.detach())
                    # err_fake = self.D(X_z_recon.detach())
                    loss_D = -torch.mean(torch.log(err_real + eps) + torch.log(1.0 - err_fake + eps)) * 0.1
                    loss_D.backward()
                    self.optimizer_D.step()
                    iter_stats.update({'loss_D': loss_D.item()})

                #######################
                # Generator loss
                #######################
                self.D.zero_grad()
                err_fake = self.D(X_random_mod)
                # err_fake = self.D(X_z_recon)
                loss_G += -torch.mean(torch.log(err_fake + eps))

                iter_stats.update({'loss_G': loss_G.item()})
                # iter_stats.update({'err_real': err_real.mean().item(), 'err_fake': loss_G.mean().item()})

            # debug visualization
            show = True
            if show:
                if (self.iter+1) % self.print_interval in [0,1]:
                    if Y[3] is None:
                        emotion_gt = np.zeros(len(z), dtype=int)
                        emotion_gt_mod = np.zeros(len(z), dtype=int)
                    else:
                        emotion_gt = Y[3][:,0].long()
                        emotion_gt_mod = Y[3][shuffled_ids,0].long()
                    with torch.no_grad():
                        self.znet.eval()
                        self.G.eval()
                        emotion_preds = torch.max(self.znet(fe.detach()), 1)[1]
                        emotion_mod = torch.max(self.znet(fts_mod[3].detach()), 1)[1]
                        emotion_mod_pred = torch.max(self.znet(fts2[3].detach()), 1)[1]
                        X_recon = self.P(z)[:,:3]
                        X_z_recon = self.P(z_recon)[:,:3]
                        X_random_mod_recon = self.P(self.G(*fts2))[:,:3]
                        self.znet.train(train)
                        self.G.train(train)
                        X_recon_errs = 255.0 * torch.abs(self.images - X_recon).reshape(len(self.images), -1).mean(dim=1)
                        X_z_recon_errs = 255.0 * torch.abs(self.images - X_z_recon).reshape(len(self.images), -1).mean(dim=1)

                    nimgs = 8

                    disp_input = vis.add_pose_to_images(ds_utils.denormalized(self.images)[:nimgs], Y[0], color=(0, 0, 1.0))
                    if name == 'expression':
                        disp_input = vis.add_emotion_to_images(disp_input, to_numpy(emotion_gt))
                    elif name == 'id':
                        disp_input = vis.add_id_to_images(disp_input, to_numpy(Y[1]))

                    disp_recon = vis.add_pose_to_images(ds_utils.denormalized(X_recon)[:nimgs], fts[0])
                    disp_recon = vis.add_error_to_images(disp_recon, errors=X_recon_errs, format_string='{:.1f}')

                    disp_z_recon = vis.add_pose_to_images(ds_utils.denormalized(X_z_recon)[:nimgs], fts[0])
                    disp_z_recon = vis.add_emotion_to_images(disp_z_recon, to_numpy(emotion_preds),
                                                             gt_emotions=to_numpy(emotion_gt) if name=='expression' else None)
                    disp_z_recon = vis.add_error_to_images(disp_z_recon, errors=X_z_recon_errs, format_string='{:.1f}')

                    disp_input_shuffle = vis.add_pose_to_images(ds_utils.denormalized(self.images[shuffled_ids])[:nimgs], fts[0][shuffled_ids])
                    disp_input_shuffle = vis.add_emotion_to_images(disp_input_shuffle, to_numpy(emotion_gt_mod))
                    if name == 'id':
                        disp_input_shuffle = vis.add_id_to_images(disp_input_shuffle, to_numpy(Y[1][shuffled_ids]))

                    disp_recon_shuffle = vis.add_pose_to_images(ds_utils.denormalized(X_random_mod)[:nimgs], fts_mod[0], color=(0, 0, 1.0))
                    disp_recon_shuffle = vis.add_emotion_to_images(disp_recon_shuffle, to_numpy(emotion_mod))

                    disp_cycle = vis.add_pose_to_images(ds_utils.denormalized(X_random_mod_recon)[:nimgs], fts2[0])
                    disp_cycle = vis.add_emotion_to_images(disp_cycle, to_numpy(emotion_mod_pred))
                    disp_cycle = vis.add_error_to_images(disp_cycle, errors=l1_err, format_string='{:.3f}',
                                                         size=0.6, thickness=2, vmin=0, vmax=0.1)

                    rows = [
                        # original input images
                        vis.make_grid(disp_input, nCols=nimgs),

                        # reconstructions without disentanglement
                        vis.make_grid(disp_recon, nCols=nimgs),

                        # reconstructions with disentanglement
                        vis.make_grid(disp_z_recon, nCols=nimgs),

                        # source for feature transfer
                        vis.make_grid(disp_input_shuffle, nCols=nimgs),

                        # reconstructions with modified feature vector (direkt)
                        vis.make_grid(disp_recon_shuffle, nCols=nimgs),

                        # reconstructions with modified feature vector (1 iters)
                        vis.make_grid(disp_cycle, nCols=nimgs)
                    ]
                    f = 1.0 / cfg.INPUT_SCALE_FACTOR
                    disp_img = vis.make_grid(rows, nCols=1, normalize=False, fx=f, fy=f)

                    wnd_title = name
                    if self.current_dataset is not None:
                        wnd_title += ' ' + self.current_dataset.__class__.__name__
                    cv2.imshow(wnd_title, cv2.cvtColor(disp_img, cv2.COLOR_RGB2BGR))
                    cv2.waitKey(10)

        loss_I *= cfg.W_DISENT

        iter_stats['loss_disent'] = loss_I.item()

        if train:
            loss_I.backward(retain_graph=True)

        return z_recon, iter_stats, loss_G[0]
コード例 #24
0
ファイル: lmutils.py プロジェクト: PaperID8601/3FR
def visualize_batch(images,
                    landmarks,
                    X_recon,
                    X_lm_hm,
                    lm_preds_max,
                    lm_heatmaps=None,
                    images_mod=None,
                    lm_preds_cnn=None,
                    ds=None,
                    wait=0,
                    ssim_maps=None,
                    landmarks_to_draw=lmcfg.ALL_LANDMARKS,
                    ocular_norm='outer',
                    horizontal=False,
                    f=1.0,
                    overlay_heatmaps_input=False,
                    overlay_heatmaps_recon=False,
                    clean=False):
    """Show a two-row debug visualization for a batch of landmark predictions.

    Row 1: input images with ground-truth (green) and predicted (red)
    landmarks; row 2: reconstructions annotated with per-image error
    metrics (L1 recon error, NME per landmark group, SSIM).  At most the
    first 10 images of the batch are shown.  Displays via ``cv2.imshow``
    and blocks for ``wait`` ms; returns nothing.

    Args:
        images: input batch (tensor or array); made 4D via ``nn.atleast4d``.
        landmarks: ground-truth landmarks or None (zeros are used then and
            no NME is computed).
        X_recon: reconstructed images, same layout as ``images``.
        X_lm_hm: predicted landmark heatmaps or None (skips heatmap window).
        lm_preds_max: landmark predictions taken from heatmap maxima.
        lm_heatmaps: optional ground-truth heatmaps for the heatmap window.
        images_mod: optional modified inputs to display instead of ``images``.
        lm_preds_cnn: optional second set of predictions (only rescaled here,
            not drawn in the current code).
        ds: dataset whose class name is appended to the window title.
        wait: ``cv2.waitKey`` timeout in ms (0 blocks until keypress).
        ssim_maps: optional per-pixel SSIM error maps; shown in an extra window.
        landmarks_to_draw: kept for interface compatibility; only used by the
            commented-out drawing variants below.
        ocular_norm: normalization mode passed to the NME functions.
        horizontal: if True (requires a single image), place the two rows
            side by side instead of stacked.
        f: display scale factor applied to images and landmark coordinates.
        overlay_heatmaps_input / overlay_heatmaps_recon: blend predicted
            heatmaps onto the input / reconstruction row.
        clean: suppress all numeric error overlays.
    """

    gt_color = (0, 255, 0)   # ground truth: green
    pred_color = (0, 0, 255)  # predictions: red

    nimgs = min(10, len(images))
    images = nn.atleast4d(images)[:nimgs]
    nme_per_lm = None
    if landmarks is None:
        # print('num landmarks', lmcfg.NUM_LANDMARKS)
        # no ground truth available -> draw zeros, skip NME computation
        lm_gt = np.zeros((nimgs, lmcfg.NUM_LANDMARKS, 2))
    else:
        lm_gt = nn.atleast3d(to_numpy(landmarks))[:nimgs]
        nme_per_lm = calc_landmark_nme(lm_gt,
                                       lm_preds_max[:nimgs],
                                       ocular_norm=ocular_norm)
        # NOTE(review): lm_ssim_errs is computed but only referenced by the
        # commented-out drawing call further down; currently unused.
        lm_ssim_errs = 1 - calc_landmark_ssim_score(images, X_recon[:nimgs],
                                                    lm_gt)

    lm_confs = None
    # show landmark heatmaps in a separate window (if provided)
    pred_heatmaps = None
    if X_lm_hm is not None:
        pred_heatmaps = to_single_channel_heatmap(to_numpy(X_lm_hm[:nimgs]))
        pred_heatmaps = [
            cv2.resize(im, None, fx=f, fy=f, interpolation=cv2.INTER_NEAREST)
            for im in pred_heatmaps
        ]
        gt_heatmaps = None
        if lm_heatmaps is not None:
            gt_heatmaps = to_single_channel_heatmap(
                to_numpy(lm_heatmaps[:nimgs]))
            gt_heatmaps = np.array([
                cv2.resize(im,
                           None,
                           fx=f,
                           fy=f,
                           interpolation=cv2.INTER_NEAREST)
                for im in gt_heatmaps
            ])
        show_landmark_heatmaps(pred_heatmaps, gt_heatmaps, nimgs, f=1)
        # confidence per landmark = peak value of its heatmap
        lm_confs = to_numpy(X_lm_hm).reshape(X_lm_hm.shape[0],
                                             X_lm_hm.shape[1], -1).max(axis=2)

    # resize images for display and scale landmarks accordingly
    lm_preds_max = lm_preds_max[:nimgs] * f
    if lm_preds_cnn is not None:
        lm_preds_cnn = lm_preds_cnn[:nimgs] * f
    lm_gt *= f

    # keep un-annotated copies of input/recon for the SSIM computation below
    input_images = vis._to_disp_images(images[:nimgs], denorm=True)
    if images_mod is not None:
        disp_images = vis._to_disp_images(images_mod[:nimgs], denorm=True)
    else:
        disp_images = vis._to_disp_images(images[:nimgs], denorm=True)
    disp_images = [
        cv2.resize(im, None, fx=f, fy=f, interpolation=cv2.INTER_NEAREST)
        for im in disp_images
    ]

    recon_images = vis._to_disp_images(X_recon[:nimgs], denorm=True)
    disp_X_recon = [
        cv2.resize(im, None, fx=f, fy=f, interpolation=cv2.INTER_NEAREST)
        for im in recon_images.copy()
    ]

    # optionally blend predicted heatmaps onto input and/or recon rows
    if pred_heatmaps is not None and overlay_heatmaps_input:
        disp_images = [
            vis.overlay_heatmap(disp_images[i], pred_heatmaps[i])
            for i in range(len(pred_heatmaps))
        ]
    if pred_heatmaps is not None and overlay_heatmaps_recon:
        disp_X_recon = [
            vis.overlay_heatmap(disp_X_recon[i], pred_heatmaps[i])
            for i in range(len(pred_heatmaps))
        ]

    #
    # Show input images
    #
    disp_images = vis.add_landmarks_to_images(disp_images,
                                              lm_gt[:nimgs],
                                              color=gt_color)
    disp_images = vis.add_landmarks_to_images(disp_images,
                                              lm_preds_max[:nimgs],
                                              lm_errs=nme_per_lm,
                                              color=pred_color,
                                              draw_wireframe=False,
                                              gt_landmarks=lm_gt,
                                              draw_gt_offsets=True)

    # disp_images = vis.add_landmarks_to_images(disp_images, lm_gt[:nimgs], color=(1,1,1), radius=1,
    #                                           draw_dots=True, draw_wireframe=True, landmarks_to_draw=landmarks_to_draw)
    # disp_images = vis.add_landmarks_to_images(disp_images, lm_preds_max[:nimgs], lm_errs=nme_per_lm,
    #                                           color=(1.0, 0.0, 0.0),
    #                                           draw_dots=True, draw_wireframe=True, radius=1,
    #                                           gt_landmarks=lm_gt, draw_gt_offsets=False,
    #                                           landmarks_to_draw=landmarks_to_draw)

    #
    # Show reconstructions
    #
    # mean absolute pixel error per image, scaled to 0..255
    X_recon_errs = 255.0 * torch.abs(images - X_recon[:nimgs]).reshape(
        len(images), -1).mean(dim=1)
    if not clean:
        disp_X_recon = vis.add_error_to_images(disp_X_recon[:nimgs],
                                               errors=X_recon_errs,
                                               format_string='{:>4.1f}')

    # modes of heatmaps
    # disp_X_recon = [overlay_heatmap(disp_X_recon[i], pred_heatmaps[i]) for i in range(len(pred_heatmaps))]
    if not clean:
        # NME split by landmark group: inner face, outline, and all combined;
        # printed bottom-right in three stacked locations (br-2, br-1, br)
        lm_errs_max = calc_landmark_nme_per_img(
            lm_gt,
            lm_preds_max,
            ocular_norm=ocular_norm,
            landmarks_to_eval=lmcfg.LANDMARKS_NO_OUTLINE)
        lm_errs_max_outline = calc_landmark_nme_per_img(
            lm_gt,
            lm_preds_max,
            ocular_norm=ocular_norm,
            landmarks_to_eval=lmcfg.LANDMARKS_ONLY_OUTLINE)
        lm_errs_max_all = calc_landmark_nme_per_img(
            lm_gt,
            lm_preds_max,
            ocular_norm=ocular_norm,
            landmarks_to_eval=lmcfg.ALL_LANDMARKS)
        disp_X_recon = vis.add_error_to_images(disp_X_recon,
                                               lm_errs_max,
                                               loc='br-2',
                                               format_string='{:>5.2f}',
                                               vmax=15)
        disp_X_recon = vis.add_error_to_images(disp_X_recon,
                                               lm_errs_max_outline,
                                               loc='br-1',
                                               format_string='{:>5.2f}',
                                               vmax=15)
        disp_X_recon = vis.add_error_to_images(disp_X_recon,
                                               lm_errs_max_all,
                                               loc='br',
                                               format_string='{:>5.2f}',
                                               vmax=15)
    disp_X_recon = vis.add_landmarks_to_images(disp_X_recon,
                                               lm_gt,
                                               color=gt_color,
                                               draw_wireframe=True)

    # disp_X_recon = vis.add_landmarks_to_images(disp_X_recon, lm_preds_max[:nimgs],
    #                                            color=pred_color, draw_wireframe=False,
    #                                            lm_errs=nme_per_lm, lm_confs=lm_confs,
    #                                            lm_rec_errs=lm_ssim_errs, gt_landmarks=lm_gt,
    #                                            draw_gt_offsets=True, draw_dots=True)

    disp_X_recon = vis.add_landmarks_to_images(disp_X_recon,
                                               lm_preds_max[:nimgs],
                                               color=pred_color,
                                               draw_wireframe=True,
                                               gt_landmarks=lm_gt,
                                               draw_gt_offsets=True,
                                               lm_errs=nme_per_lm,
                                               draw_dots=True,
                                               radius=2)

    def add_confs(disp_X_recon, lmids, loc):
        # Print mean heatmap confidence for a landmark subset, colored by
        # (1 - confidence) so low confidence stands out.
        means = lm_confs[:, lmids].mean(axis=1)
        colors = vis.color_map(to_numpy(1 - means),
                               cmap=plt.cm.jet,
                               vmin=0.0,
                               vmax=0.4)
        return vis.add_error_to_images(disp_X_recon,
                                       means,
                                       loc=loc,
                                       format_string='{:>4.2f}',
                                       colors=colors)

    # disp_X_recon = add_confs(disp_X_recon, lmcfg.LANDMARKS_NO_OUTLINE, 'bm-2')
    # disp_X_recon = add_confs(disp_X_recon, lmcfg.LANDMARKS_ONLY_OUTLINE, 'bm-1')
    # disp_X_recon = add_confs(disp_X_recon, lmcfg.ALL_LANDMARKS, 'bm')

    # print ssim errors (computed on the clean, un-annotated images)
    ssim = np.zeros(nimgs)
    for i in range(nimgs):
        ssim[i] = compare_ssim(input_images[i],
                               recon_images[i],
                               data_range=1.0,
                               multichannel=True)
    if not clean:
        disp_X_recon = vis.add_error_to_images(disp_X_recon,
                                               1 - ssim,
                                               loc='bl-1',
                                               format_string='{:>4.2f}',
                                               vmax=0.8,
                                               vmin=0.2)
    # print ssim torch errors
    if ssim_maps is not None and not clean:
        disp_X_recon = vis.add_error_to_images(disp_X_recon,
                                               ssim_maps.reshape(
                                                   len(ssim_maps),
                                                   -1).mean(axis=1),
                                               loc='bl-2',
                                               format_string='{:>4.2f}',
                                               vmin=0.0,
                                               vmax=0.4)

    rows = [vis.make_grid(disp_images, nCols=nimgs, normalize=False)]
    rows.append(vis.make_grid(disp_X_recon, nCols=nimgs))

    # extra window with colorized per-pixel SSIM error maps
    if ssim_maps is not None:
        disp_ssim_maps = to_numpy(
            ds_utils.denormalized(ssim_maps)[:nimgs].transpose(0, 2, 3, 1))
        for i in range(len(disp_ssim_maps)):
            disp_ssim_maps[i] = vis.color_map(disp_ssim_maps[i].mean(axis=2),
                                              vmin=0.0,
                                              vmax=2.0)
        grid_ssim_maps = vis.make_grid(disp_ssim_maps, nCols=nimgs, fx=f, fy=f)
        cv2.imshow('ssim errors',
                   cv2.cvtColor(grid_ssim_maps, cv2.COLOR_RGB2BGR))

    if horizontal:
        assert (nimgs == 1)
        disp_rows = vis.make_grid(rows, nCols=2)
    else:
        disp_rows = vis.make_grid(rows, nCols=1)
    wnd_title = 'Predicted Landmarks '
    if ds is not None:
        wnd_title += ds.__class__.__name__
    cv2.imshow(wnd_title, cv2.cvtColor(disp_rows, cv2.COLOR_RGB2BGR))
    cv2.waitKey(wait)
コード例 #25
0
 def z_vecs(self):
     """Return the model's latent code ``self.z`` as a one-element list of numpy arrays."""
     latent = to_numpy(self.z)
     return [latent]
コード例 #26
0
ファイル: lmutils.py プロジェクト: PaperID8601/3FR
def visualize_batch_CVPR(images,
                         landmarks,
                         X_recon,
                         X_lm_hm,
                         lm_preds,
                         lm_heatmaps=None,
                         ds=None,
                         wait=0,
                         horizontal=False,
                         f=1.0,
                         radius=2):
    """Show a paper-figure style visualization grid for a batch.

    Rows (top to bottom): inputs, plain reconstructions, reconstructions
    with predicted landmarks, reconstructions with ground-truth landmarks,
    reconstructions with heatmap overlay (only when ``X_lm_hm`` is given),
    and inputs with predicted landmarks.  At most 10 images are shown.
    Displays via ``cv2.imshow`` and blocks for ``wait`` ms; returns nothing.

    Args:
        images: input batch; made 4D via ``nn.atleast4d``.
        landmarks: ground-truth landmarks or None (zeros are drawn then).
        X_recon: reconstructed images.
        X_lm_hm: predicted landmark heatmaps or None.
        lm_heatmaps: optional ground-truth heatmaps for the heatmap window.
        ds: dataset whose class name is appended to the window title.
        wait: ``cv2.waitKey`` timeout in ms.
        horizontal: if True (requires a single image), arrange rows in
            two columns instead of one.
        f: display scale factor applied to images and landmark coordinates.
        radius: landmark dot radius in pixels.
    """

    gt_color = (0, 255, 0)      # ground truth: green
    pred_color = (0, 255, 255)  # predictions: yellow

    nimgs = min(10, len(images))
    images = nn.atleast4d(images)[:nimgs]
    if landmarks is None:
        print('num landmarks', lmcfg.NUM_LANDMARKS)
        lm_gt = np.zeros((nimgs, lmcfg.NUM_LANDMARKS, 2))
    else:
        lm_gt = nn.atleast3d(to_numpy(landmarks))[:nimgs]

    # show landmark heatmaps in a separate window (if provided)
    pred_heatmaps = None
    if X_lm_hm is not None:
        pred_heatmaps = to_single_channel_heatmap(to_numpy(X_lm_hm[:nimgs]))
        pred_heatmaps = [
            cv2.resize(im, None, fx=f, fy=f, interpolation=cv2.INTER_NEAREST)
            for im in pred_heatmaps
        ]
        gt_heatmaps = None
        if lm_heatmaps is not None:
            gt_heatmaps = to_single_channel_heatmap(
                to_numpy(lm_heatmaps[:nimgs]))
            gt_heatmaps = np.array([
                cv2.resize(im,
                           None,
                           fx=f,
                           fy=f,
                           interpolation=cv2.INTER_NEAREST)
                for im in gt_heatmaps
            ])
        show_landmark_heatmaps(pred_heatmaps, gt_heatmaps, nimgs, f=1)

    # resize images for display and scale landmarks accordingly
    lm_preds = lm_preds[:nimgs] * f
    lm_gt *= f

    rows = []

    disp_images = vis._to_disp_images(images[:nimgs], denorm=True)
    disp_images = [
        cv2.resize(im, None, fx=f, fy=f, interpolation=cv2.INTER_NEAREST)
        for im in disp_images
    ]
    rows.append(vis.make_grid(disp_images, nCols=nimgs, normalize=False))

    recon_images = vis._to_disp_images(X_recon[:nimgs], denorm=True)
    disp_X_recon = [
        cv2.resize(im, None, fx=f, fy=f, interpolation=cv2.INTER_NEAREST)
        for im in recon_images.copy()
    ]
    rows.append(vis.make_grid(disp_X_recon, nCols=nimgs))

    # reconstructions with predicted landmarks
    disp_X_recon_pred = [
        cv2.resize(im, None, fx=f, fy=f, interpolation=cv2.INTER_NEAREST)
        for im in recon_images.copy()
    ]
    disp_X_recon_pred = vis.add_landmarks_to_images(disp_X_recon_pred,
                                                    lm_preds,
                                                    color=pred_color,
                                                    radius=radius)
    rows.append(vis.make_grid(disp_X_recon_pred, nCols=nimgs))

    # reconstructions with ground-truth landmarks
    disp_X_recon_gt = [
        cv2.resize(im, None, fx=f, fy=f, interpolation=cv2.INTER_NEAREST)
        for im in recon_images.copy()
    ]
    disp_X_recon_gt = vis.add_landmarks_to_images(disp_X_recon_gt,
                                                  lm_gt,
                                                  color=gt_color,
                                                  radius=radius)
    rows.append(vis.make_grid(disp_X_recon_gt, nCols=nimgs))

    # overlay heatmaps on reconstructions.
    # FIX: the original iterated pred_heatmaps unconditionally, raising a
    # TypeError when X_lm_hm is None; skip this row in that case.
    if pred_heatmaps is not None:
        disp_X_recon_hm = [
            cv2.resize(im, None, fx=f, fy=f, interpolation=cv2.INTER_NEAREST)
            for im in recon_images.copy()
        ]
        disp_X_recon_hm = [
            vis.overlay_heatmap(disp_X_recon_hm[i], pred_heatmaps[i])
            for i in range(len(pred_heatmaps))
        ]
        rows.append(vis.make_grid(disp_X_recon_hm, nCols=nimgs))

    # input images with prediction (and ground truth)
    disp_images_pred = vis._to_disp_images(images[:nimgs], denorm=True)
    disp_images_pred = [
        cv2.resize(im, None, fx=f, fy=f, interpolation=cv2.INTER_NEAREST)
        for im in disp_images_pred
    ]
    # disp_images_pred = vis.add_landmarks_to_images(disp_images_pred, lm_gt, color=gt_color, radius=radius)
    disp_images_pred = vis.add_landmarks_to_images(disp_images_pred,
                                                   lm_preds,
                                                   color=pred_color,
                                                   radius=radius)
    rows.append(vis.make_grid(disp_images_pred, nCols=nimgs))

    if horizontal:
        assert (nimgs == 1)
        disp_rows = vis.make_grid(rows, nCols=2)
    else:
        disp_rows = vis.make_grid(rows, nCols=1)
    wnd_title = 'recon errors '
    if ds is not None:
        wnd_title += ds.__class__.__name__
    cv2.imshow(wnd_title, cv2.cvtColor(disp_rows, cv2.COLOR_RGB2BGR))
    cv2.waitKey(wait)
コード例 #27
0
    def visualize_batch(self,
                        batch,
                        X_recon,
                        ssim_maps,
                        nimgs=8,
                        ds=None,
                        wait=0):
        """Write a reconstruction-error visualization of a batch to disk.

        Builds two stacked grids (target images and reconstructions
        annotated with error metrics), saves them as
        ``"<title>.jpg"`` via ``cv2.imwrite`` and, if ``ssim_maps`` is
        given, saves a colorized SSIM-error grid as ``"ssim errors.jpg"``.
        The D/Q/P sub-networks are switched to eval mode for the forward
        passes and restored afterwards.  Returns nothing.

        Args:
            batch: project batch object providing ``images``, ``filenames``,
                ``target_images`` and (optionally) ``landmarks``.
            X_recon: reconstructed images for the batch.
            ssim_maps: per-pixel SSIM error maps or None.
            nimgs: maximum number of images to visualize.
            ds: dataset whose class name is appended to the output filename.
            wait: ``cv2.waitKey`` timeout in ms.
        """

        nimgs = min(nimgs, len(batch))
        # remember train/eval state so we can restore it at the end
        train_state_D = self.saae.D.training
        train_state_Q = self.saae.Q.training
        train_state_P = self.saae.P.training
        self.saae.D.eval()
        self.saae.Q.eval()
        self.saae.P.eval()

        loc_err_gan = "tr"
        text_size_errors = 0.65

        input_images = vis.reconstruct_images(batch.images[:nimgs])
        # NOTE(review): show_filenames is currently unused in this method.
        show_filenames = batch.filenames[:nimgs]
        target_images = (batch.target_images
                         if batch.target_images is not None else batch.images)
        disp_images = vis.reconstruct_images(target_images[:nimgs])

        # draw GAN score (1 - D output) on the input row
        if self.args.with_gan:
            with torch.no_grad():
                err_gan_inputs = self.saae.D(batch.images[:nimgs])
            disp_images = vis.add_error_to_images(
                disp_images,
                errors=1 - err_gan_inputs,
                loc=loc_err_gan,
                format_string="{:>5.2f}",
                vmax=1.0,
            )

        # disp_images = vis.add_landmarks_to_images(disp_images, batch.landmarks[:nimgs], color=(0,1,0), radius=1,
        #                                           draw_wireframe=False)
        rows = [vis.make_grid(disp_images, nCols=nimgs, normalize=False)]

        recon_images = vis.reconstruct_images(X_recon[:nimgs])
        disp_X_recon = recon_images.copy()

        print_stats = True
        if print_stats:
            # lm_ssim_errs = None
            # if batch.landmarks is not None:
            #     lm_recon_errs = lmutils.calc_landmark_recon_error(batch.images[:nimgs], X_recon[:nimgs], batch.landmarks[:nimgs], reduction='none')
            #     disp_X_recon = vis.add_error_to_images(disp_X_recon, lm_recon_errs, size=text_size_errors, loc='bm',
            #                                            format_string='({:>3.1f})', vmin=0, vmax=10)
            #     lm_ssim_errs = lmutils.calc_landmark_ssim_error(batch.images[:nimgs], X_recon[:nimgs], batch.landmarks[:nimgs])
            #     disp_X_recon = vis.add_error_to_images(disp_X_recon, lm_ssim_errs.mean(axis=1), size=text_size_errors, loc='bm-1',
            #                                            format_string='({:>3.2f})', vmin=0.2, vmax=0.8)

            # mean absolute pixel error per image, scaled to 0..255
            # NOTE(review): computed over the full batch, not just nimgs;
            # only the first nimgs values are displayed below.
            X_recon_errs = 255.0 * torch.abs(batch.images - X_recon).reshape(
                len(batch.images), -1).mean(dim=1)
            # disp_X_recon = vis.add_landmarks_to_images(disp_X_recon, batch.landmarks[:nimgs], radius=1, color=None,
            #                                            lm_errs=lm_ssim_errs, draw_wireframe=False)
            disp_X_recon = vis.add_error_to_images(
                disp_X_recon[:nimgs],
                errors=X_recon_errs,
                size=text_size_errors,
                format_string="{:>4.1f}",
            )
            if self.args.with_gan:
                with torch.no_grad():
                    err_gan = self.saae.D(X_recon[:nimgs])
                disp_X_recon = vis.add_error_to_images(
                    disp_X_recon,
                    errors=1 - err_gan,
                    loc=loc_err_gan,
                    format_string="{:>5.2f}",
                    vmax=1.0,
                )

            # per-image SSIM between inputs and reconstructions
            ssim = np.zeros(nimgs)
            for i in range(nimgs):
                data_range = 255.0 if input_images[0].dtype == np.uint8 else 1.0
                ssim[i] = compare_ssim(
                    input_images[i],
                    recon_images[i],
                    data_range=data_range,
                    multichannel=True,
                )
            disp_X_recon = vis.add_error_to_images(
                disp_X_recon,
                1 - ssim,
                loc="bl-1",
                size=text_size_errors,
                format_string="{:>4.2f}",
                vmin=0.2,
                vmax=0.8,
            )

            if ssim_maps is not None:
                disp_X_recon = vis.add_error_to_images(
                    disp_X_recon,
                    ssim_maps.reshape(len(ssim_maps), -1).mean(axis=1),
                    size=text_size_errors,
                    loc="bl-2",
                    format_string="{:>4.2f}",
                    vmin=0.0,
                    vmax=0.4,
                )

        rows.append(vis.make_grid(disp_X_recon, nCols=nimgs))

        # colorized per-pixel SSIM error maps, saved to a separate file
        if ssim_maps is not None:
            disp_ssim_maps = to_numpy(
                nn.denormalized(ssim_maps)[:nimgs].transpose(0, 2, 3, 1))
            if disp_ssim_maps.shape[3] == 1:
                # grayscale maps -> replicate channel so color_map input is RGB-shaped
                disp_ssim_maps = disp_ssim_maps.repeat(3, axis=3)
            for i in range(len(disp_ssim_maps)):
                disp_ssim_maps[i] = vis.color_map(
                    disp_ssim_maps[i].mean(axis=2), vmin=0.0, vmax=2.0)
            grid_ssim_maps = vis.make_grid(disp_ssim_maps, nCols=nimgs)
            cv2.imwrite("ssim errors.jpg",
                        cv2.cvtColor(grid_ssim_maps, cv2.COLOR_RGB2BGR))

        # restore the original train/eval state of the sub-networks
        self.saae.D.train(train_state_D)
        self.saae.Q.train(train_state_Q)
        self.saae.P.train(train_state_P)

        f = 1
        disp_rows = vis.make_grid(rows, nCols=1, normalize=False, fx=f, fy=f)
        wnd_title = "recon errors "
        if ds is not None:
            wnd_title += ds.__class__.__name__
        # NOTE(review): this writes a file (no imshow), so cv2.waitKey here
        # has no window to service — presumably a leftover from a display
        # variant; confirm before removing.
        cv2.imwrite(wnd_title + ".jpg",
                    cv2.cvtColor(disp_rows, cv2.COLOR_RGB2BGR))
        cv2.waitKey(wait)
コード例 #28
0
def get_landmark_confs(X_lm_hm):
    """Per-landmark confidence: the peak value of each heatmap, clipped to [0, 1].

    Returns an array of shape (batch, num_landmarks).
    """
    hms = to_numpy(X_lm_hm)
    flat = hms.reshape(hms.shape[0], hms.shape[1], -1)
    peaks = flat.max(axis=2)
    return np.clip(peaks, a_min=0, a_max=1)
コード例 #29
0
def eval_affectnet(net,
                   n=2000,
                   feat_type=3,
                   eval_notf=True,
                   only_good_images=True,
                   show=False):
    """Evaluate emotion classification of ``net`` on the AffectNet test set.

    Runs the network over up to ``n`` deterministic test samples, collects
    class probabilities and labels, prints accuracy/AUC/confusion-matrix
    statistics and plots the normalized confusion matrix.

    Args:
        net: model exposing ``__call__(images, Y=None)``, ``emotion_probs``
            and (when ``show``) ``z_vecs()`` / ``emotions_pred()``.
        n: maximum number of test samples.
        feat_type, eval_notf, only_good_images: currently unused; kept for
            interface compatibility with callers.
        show: if True, display reconstructions per batch (blocks on keypress)
            and use a smaller batch size.
    """
    print("Evaluating AffectNet...")

    batch_size = 20 if show else 100
    dataset = affectnet.AffectNet(train=False,
                                  max_samples=n,
                                  deterministic=True,
                                  use_cache=True)
    dataloader = td.DataLoader(dataset,
                               batch_size=batch_size,
                               shuffle=False,
                               num_workers=6)
    print(dataset)

    labels = []
    clprobs = []

    # FIX: renamed loop variable from 'iter' — it shadowed the builtin.
    for batch_idx, data in enumerate(dataloader):
        batch = nn.Batch(data)

        with torch.no_grad():
            # keep only the first 3 (RGB) output channels
            X_recon = net(batch.images, Y=None)[:, :3]
        if show:
            nimgs = 25
            f = 1.0
            img = saae.draw_results(batch.images,
                                    X_recon,
                                    net.z_vecs(),
                                    emotions=batch.emotions,
                                    emotions_pred=net.emotions_pred(),
                                    fx=f,
                                    fy=f,
                                    ncols=10)
            # NOTE(review): COLOR_BGR2RGB swaps the same two channels as
            # COLOR_RGB2BGR, so this matches the RGB2BGR calls elsewhere.
            cv2.imshow('reconst', cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            cv2.waitKey()

        clprobs.append(nn.to_numpy(net.emotion_probs))
        labels.append(nn.to_numpy(batch.emotions))

        if (batch_idx % 10) == 0:
            print(batch_idx)

    clprobs = np.vstack(clprobs)
    labels = np.concatenate(labels).astype(int)

    accuracy, auc, auc_micro, conf_matrix = evaluate(clprobs, labels)
    print('Accuracy  F: %2.5f+-%2.5f' % (np.mean(accuracy), np.std(accuracy)))
    log.info("\nAUCs: {} ({})".format(auc, np.mean(list(auc.values()))))
    log.info("\nAUC micro: {} ".format(auc_micro))

    print(conf_matrix)
    vis.plot_confusion_matrix(conf_matrix,
                              classes=affectnet.AffectNet.classes,
                              normalize=True)
    plt.show()
コード例 #30
0
 def reformat(lms):
     """Coerce landmarks to a batched (N, num_landmarks, 2) numpy array.

     A single 2D landmark set is promoted to a batch of one.
     """
     arr = to_numpy(lms)
     if len(arr.shape) == 2:
         return arr.reshape((1, -1, 2))
     return arr