示例#1
0
文件: fabrec.py 项目: walaa5/Face-X
    def heatmaps_to_landmarks(self, hms):
        """Convert predicted heatmaps to landmark coordinates.

        Two channel layouts are handled:

        * more than 3 channels per image: one heatmap per landmark; each
          landmark is placed at the arg-max of the heatmap selected by
          ``self.landmark_id_to_heatmap_id``.
        * exactly 3 channels: a colour-coded map; each landmark in
          ``landmarks.config.LANDMARKS`` is placed at the arg-max of the
          product of the three channels, after zeroing values outside
          ``[colour - 2, colour + 5]`` of that landmark's colour in
          ``utils.nn.lmcolors`` (channels are scaled by 255 first).

        With any other channel count every landmark stays at (0, 0).

        Args:
            hms: batch of maps indexed as ``hms[image, channel, ...]``
                (anything accepted by ``to_numpy``). It is not modified.

        Returns:
            Array of shape ``(len(hms), self.num_landmarks, 2)`` holding
            (x, y) positions divided by ``lmcfg.HEATMAP_SIZE / self.input_size``.
        """
        lms = np.zeros((len(hms), self.num_landmarks, 2), dtype=int)
        if hms.shape[1] > 3:
            for i in range(len(hms)):
                heatmaps = to_numpy(hms[i])
                for l in range(len(heatmaps)):
                    hm = heatmaps[self.landmark_id_to_heatmap_id(l)]
                    # unravel_index yields (row, col); reverse to (x, y).
                    lms[i, l, :] = np.unravel_index(np.argmax(hm, axis=None),
                                                    hm.shape)[::-1]
        elif hms.shape[1] == 3:
            # Scale into a new array: an in-place ``*=`` would alter the
            # caller's data whenever to_numpy returns a shared buffer.
            hms = to_numpy(hms) * 255

            def get_score_plane(h, lm_id, cn):
                """Return a copy of channel ``cn`` of ``h`` with values outside
                the colour window of landmark ``lm_id`` set to zero."""
                v = utils.nn.lmcolors[lm_id, cn]
                # Work on a copy: zeroing the shared plane in place would
                # wipe out the signal for every landmark processed later.
                hcn = h[cn].copy()
                hcn[hcn < v - 2] = 0
                hcn[hcn > v + 5] = 0
                return hcn

            for i in range(len(hms)):
                hm = hms[i]
                for l in landmarks.config.LANDMARKS:
                    lm_score_map = (get_score_plane(hm, l, 0) *
                                    get_score_plane(hm, l, 1) *
                                    get_score_plane(hm, l, 2))
                    lms[i, l, :] = np.unravel_index(
                        np.argmax(lm_score_map, axis=None),
                        lm_score_map.shape)[::-1]
        lm_scale = lmcfg.HEATMAP_SIZE / self.input_size
        return lms / lm_scale
示例#2
0
def _predict_center_crop(net, image, crop_size=544, gpu=True):
    """Run ``net`` on a centred square crop of ``image``.

    Args:
        net: model whose ``forward`` is called on a batch holding one crop.
        image: array indexed as ``image[row, col, ...]``.
        crop_size: side length in pixels of the centred crop.
        gpu: move the input tensor to CUDA before the forward pass.

    Returns:
        ``(image, probs)``: the unchanged input and an ``(h, w)`` numpy array
        containing the network output inside the crop and zeros elsewhere.

    Raises:
        ValueError: if ``crop_size`` exceeds the image height or width.
    """
    h, w = image.shape[:2]
    if crop_size > h or crop_size > w:
        # Negative offsets would make the slices below wrap around silently
        # and place the prediction in the wrong region.
        raise ValueError(f'crop_size={crop_size} exceeds image size {h}x{w}')

    image_probs = torch.zeros((h, w))

    x = (w - crop_size) // 2
    y = (h - crop_size) // 2
    image_crop = image[y:y + crop_size, x:x + crop_size]

    inp = _crop_to_tensor(image=image_crop)['image']
    if gpu:
        inp = inp.cuda()

    with torch.no_grad():
        t = time.time()
        crop_probs = net.forward(inp.unsqueeze(0))
        print(f'time forward: {int(1000 * (time.time() - t))}ms')

    # squeeze() drops all singleton (batch/channel) dimensions at once.
    image_probs[y:y + crop_size, x:x + crop_size] = crop_probs.squeeze()
    return image, to_numpy(image_probs)
示例#3
0
    def predict_sequential():
        """Assemble a probability map for the padded image one tile at a time.

        Tiles of side ``s`` start every ``inner_size`` pixels. Each tile gets
        its own forward pass on CUDA, and only the ``inner_size`` square at
        offset ``d`` inside the tile's output is written back. ``h_pad``,
        ``w_pad``, ``npx``, ``npy``, ``inner_size``, ``s``, ``d``,
        ``image_pad`` and ``net`` come from the enclosing scope.

        Returns:
            The ``(h_pad, w_pad)`` probability map as a numpy array.
        """
        prob_map = torch.zeros((h_pad, w_pad))
        for tile_col in range(npx):
            left = tile_col * inner_size
            for tile_row in range(npy):
                top = tile_row * inner_size
                tile = image_pad[top:top + s, left:left + s]
                tile_tensor = _crop_to_tensor(image=tile)['image'].cuda()

                with torch.no_grad():
                    tile_probs = net.forward(tile_tensor.unsqueeze(0))

                kept = tile_probs.squeeze().squeeze()[d:d + inner_size,
                                                      d:d + inner_size]
                prob_map[top + d:top + d + inner_size,
                         left + d:left + d + inner_size] = kept

        return to_numpy(prob_map)
示例#4
0
def show_segmentation_results(orig_image,
                              recon,
                              preds,
                              gt_mask=None,
                              foreground_mask=None,
                              threshold=0.5,
                              modelname=None,
                              idx=None):
    """Save result images for one image and plot them in a 2x3 figure.

    Three PNGs are written to ``./outputs/results/<modelname>/`` as
    ``hrf_NN_probs.png`` (probabilities x 255), ``hrf_NN_diff.png``
    (difference map) and ``hrf_NN_orig.png`` (input image). ``NN`` is
    ``idx + 1`` zero-padded to two digits. The figure shows the input,
    ground truth and difference map on the top row, and the reconstruction,
    probabilities and thresholded prediction on the bottom row. It is laid
    out with ``plt.tight_layout`` but not shown.

    Args:
        orig_image: RGB input image (converted to BGR before saving).
        recon: reconstruction, displayed via ``vis.to_disp_image``.
        preds: predicted probabilities, tensor or array; squeezed.
        gt_mask: optional ground-truth mask; enables the difference map.
        foreground_mask: mask passed to ``difference_map``; all ones if None.
        threshold: probability threshold for the binary prediction.
        modelname: output sub-directory name.
        idx: zero-based image index used in the file names.
            When ``modelname`` or ``idx`` is omitted, a module-level variable
            of the same name is used, for backward compatibility.

    Raises:
        NameError: if ``modelname``/``idx`` is neither passed nor defined
            at module level.
    """
    if modelname is None or idx is None:
        # Backward compatibility: older callers provided these as module
        # globals instead of arguments.
        module_globals = globals()
        try:
            if modelname is None:
                modelname = module_globals['modelname']
            if idx is None:
                idx = module_globals['idx']
        except KeyError as err:
            raise NameError(
                f"show_segmentation_results: {err.args[0]!r} was not passed "
                "and is not defined at module level") from err

    # squeeze() removes every singleton dimension for tensors and arrays.
    preds = preds.squeeze()
    # numpy view of the predictions for saving/plotting; ``preds`` itself is
    # still what difference_map receives.
    preds_np = to_numpy(preds)

    if foreground_mask is None:
        foreground_mask = np.ones_like(preds_np).astype(np.uint8)

    if gt_mask is not None:
        diff_map, _ = difference_map(gt_mask, preds, foreground_mask)
    else:
        # Three channels so the RGB->BGR conversion below is valid.
        diff_map = np.zeros(preds_np.shape + (3,), dtype=np.uint8)
        gt_mask = np.zeros_like(preds_np).astype(np.uint8)

    pred_mask = preds_np > threshold
    gt_mask = to_numpy(gt_mask)

    out_prefix = f'./outputs/results/{modelname}/hrf_{idx + 1:02d}'
    io_utils.makedirs(f'{out_prefix}_probs.png')
    cv2.imwrite(f'{out_prefix}_probs.png', (preds_np * 255).astype(np.uint8))
    cv2.imwrite(f'{out_prefix}_diff.png',
                cv2.cvtColor(diff_map.astype(np.uint8), cv2.COLOR_RGB2BGR))
    cv2.imwrite(f'{out_prefix}_orig.png',
                cv2.cvtColor(orig_image.astype(np.uint8), cv2.COLOR_RGB2BGR))

    fig, ax = plt.subplots(2, 3, sharex=True, sharey=True)
    ax[0, 0].imshow(orig_image)
    ax[0, 1].imshow(gt_mask)
    ax[0, 2].imshow(diff_map.astype(np.uint8))

    ax[1, 0].imshow(vis.to_disp_image(recon.squeeze(), denorm=True))
    ax[1, 1].imshow(preds_np, vmax=1)
    ax[1, 2].imshow(pred_mask.astype(np.uint8))
    plt.tight_layout()
示例#5
0
    def batch_predict():
        """Predict every tile of the padded image in one batched forward pass.

        Tiles of side ``s`` start every ``inner_size`` pixels (column-major
        order: all rows of the first column first). Channel 0 of each tile's
        output is cropped to the ``inner_size`` square at offset ``d`` and
        written into the result. ``h_pad``, ``w_pad``, ``npx``, ``npy``,
        ``inner_size``, ``s``, ``d``, ``image_pad`` and ``net`` come from the
        enclosing scope.

        Returns:
            The ``(h_pad, w_pad)`` probability map as a numpy array.
        """
        origins = [(col * inner_size, row * inner_size)
                   for col in range(npx)
                   for row in range(npy)]

        batch = torch.stack([
            _crop_to_tensor(image=image_pad[top:top + s, left:left + s])['image']
            for left, top in origins
        ])

        with torch.no_grad():
            tile_probs = net.forward(batch.cuda())
        tile_probs = to_numpy(tile_probs)

        prob_map = np.zeros((h_pad, w_pad))
        for tile_idx, (left, top) in enumerate(origins):
            prob_map[top + d:top + d + inner_size,
                     left + d:left + d + inner_size] = \
                tile_probs[tile_idx, 0, d:d + inner_size, d:d + inner_size]
        return prob_map
示例#6
0
def overlay_vessels_heatmap(imgs, pred_vessel_hm):
    """Overlay channel 0 of each predicted vessel heatmap on its image.

    Returns a list with one ``vis.overlay_heatmap`` result (opacity 1.0)
    per entry of ``pred_vessel_hm``.
    """
    heatmaps = to_numpy(pred_vessel_hm)
    overlays = []
    for img_idx, heatmap in enumerate(heatmaps):
        overlays.append(vis.overlay_heatmap(imgs[img_idx], heatmap[0], 1.0))
    return overlays
示例#7
0
文件: lmutils.py 项目: oxyai/3FabRec
def calc_landmark_recon_error(X,
                              X_recon,
                              lms,
                              return_maps=False,
                              reduction='mean'):
    """Reconstruction error measured only inside discs around landmarks.

    For each image a mask is drawn with a filled circle of radius
    ``int(0.05 * X.shape[-1])`` around every landmark. The absolute
    difference between ``X`` and ``X_recon`` is averaged over axis 1,
    scaled by 255 and multiplied by that mask.

    Args:
        X, X_recon: 4-D batches (converted with ``to_numpy``). Axis 1 is
            averaged out; axes 2 and 3 index mask rows and columns.
        lms: per-image landmark sequences; ``lm[0]``/``lm[1]`` give the
            circle centre passed to ``cv2.circle``.
        return_maps: also return the masked error maps.
        reduction: ``'mean'`` for one scalar over the batch, ``'none'`` for
            one value per image.

    Returns:
        ``err``, or ``(err, masked_err_maps)`` when ``return_maps`` is true.
    """
    assert len(X.shape) == 4
    assert reduction in ['mean', 'none']
    X = to_numpy(X)
    X_recon = to_numpy(X_recon)

    disc_radius = int(X.shape[-1] * 0.05)
    mask = np.zeros((X.shape[0], X.shape[2], X.shape[3]), dtype=np.float32)
    for img_id, img_mask in enumerate(mask):
        # img_mask is a view into ``mask``; cv2.circle fills it in place.
        for lm in lms[img_id]:
            cv2.circle(img_mask, (int(lm[0]), int(lm[1])),
                       radius=disc_radius,
                       color=1,
                       thickness=-1)

    masked_err_maps = np.abs(X - X_recon).mean(axis=1) * 255.0 * mask

    # NOTE(review): the maps are already averaged over axis 1 (presumably
    # the 3 colour channels), yet the denominators below still multiply the
    # pixel count by 3, so the result is a third of the mean masked error.
    # Confirm this scaling is intended.
    if reduction == 'mean':
        err = masked_err_maps.sum() / (mask.sum() * 3)
    else:
        per_img_sum = masked_err_maps.sum(axis=2).sum(axis=1)
        per_img_pixels = mask.reshape(len(mask), -1).sum(axis=1)
        err = per_img_sum / (per_img_pixels * 3)

    if return_maps:
        return err, masked_err_maps
    return err
示例#8
0
文件: lmutils.py 项目: oxyai/3FabRec
def smooth_heatmaps(hms, ksize=9):
    """Box-blur every 2-D heatmap in a 4-D batch.

    Args:
        hms: batch indexed as ``hms[image, channel]`` (anything accepted by
            ``to_numpy``). It is not modified.
        ksize: side length of the square box filter (default 9).

    Returns:
        A new numpy array of the same shape. Each ``hms[i, l]`` has been
        filtered with ``cv2.blur`` using ``cv2.BORDER_CONSTANT``.
    """
    assert len(hms.shape) == 4
    # Copy so the caller's array (or a CPU tensor sharing memory with the
    # to_numpy result) is not overwritten in place.
    hms = to_numpy(hms).copy()
    for i in range(hms.shape[0]):
        for l in range(hms.shape[1]):
            hms[i, l] = cv2.blur(hms[i, l], (ksize, ksize),
                                 borderType=cv2.BORDER_CONSTANT)
    return hms
示例#9
0
文件: vis.py 项目: walaa5/Face-X
def add_error_to_images(images, errors, loc='bl', size=0.65, vmin=0., vmax=30.0, thickness=1,
                        format_string='{:.1f}', colors=None):
    """Draw one formatted error value onto each display image.

    Args:
        images: images converted with ``to_disp_images``.
        errors: one value per image, rendered with ``format_string``.
        loc, size: text anchor (resolved by ``get_pos_in_image``) and font scale.
        vmin, vmax: jet colormap range used when ``colors`` is None.
        thickness: text stroke thickness.
        format_string: format applied to each error value.
        colors: per-image text colours. When None they are derived from
            ``errors`` via ``color_map`` and multiplied by 255 if
            ``images[0]`` is uint8.

    Returns:
        The annotated display images.
    """
    annotated = to_disp_images(images)
    if colors is None:
        colors = color_map(to_numpy(errors), cmap=plt.cm.jet, vmin=vmin, vmax=vmax)
        if images[0].dtype == np.uint8:
            colors *= 255

    font = cv2.FONT_HERSHEY_DUPLEX
    for canvas, value, text_color in zip(annotated, errors, colors):
        origin = get_pos_in_image(loc, size, canvas.shape)
        label = format_string.format(value)
        cv2.putText(canvas, label, origin, font, size, text_color, thickness, cv2.LINE_AA)
    return annotated
示例#10
0
 def add_confidences(disp_X_recon, lmids, loc):
     """Label images with their mean landmark confidence over ``lmids``.

     The printed value is the mean of ``lm_confs[:, lmids]`` per image; its
     colour comes from ``1 - mean`` on the jet colormap over [0.0, 0.4].
     """
     mean_conf = lm_confs[:, lmids].mean(axis=1)
     text_colors = vis.color_map(to_numpy(1 - mean_conf),
                                 cmap=plt.cm.jet,
                                 vmin=0.0,
                                 vmax=0.4)
     labelled = vis.add_error_to_images(disp_X_recon,
                                        mean_conf,
                                        loc=loc,
                                        format_string='{:>4.2f}',
                                        colors=text_colors)
     return labelled
示例#11
0
文件: vis.py 项目: walaa5/Face-X
def draw_z(z_vecs):
    """Render latent vectors as colour-coded strips arranged in a grid.

    Each vector becomes a ``len(z) x 10`` strip: the vector is stretched
    horizontally with nearest-neighbour resizing and coloured by
    ``color_map`` over [-1, 1]. The strips are combined with ``make_grid``
    (one column per vector, 1-pixel zero padding), and the first two axes
    of the grid are swapped before returning.
    """
    strip_width = 10
    strips = []
    for z in to_numpy(z_vecs):
        n = len(z)
        strip = np.zeros((n, strip_width, 3))
        # NOTE(review): every level uses the range [-1, 1]; confirm that
        # level 0 should not use [0, 1] instead.
        strip[:] = color_map(cv2.resize(z.reshape(-1, 1),
                                        dsize=(strip_width, n),
                                        interpolation=cv2.INTER_NEAREST),
                             vmin=-1.0, vmax=1.0)
        strips.append(strip)
    return make_grid(strips, nCols=len(z_vecs), padsize=1, padval=0).transpose((1, 0, 2))
示例#12
0
def loss_struct(X,
                X_recon,
                torch_ssim,
                calc_error_maps=False,
                reduction='mean'):
    """Structural (SSIM-based) reconstruction loss.

    Computes ``1 - torch_ssim(x, x_recon)`` separately for each image pair
    and reduces the per-image errors with ``__reduce(errs, reduction)``.

    Args:
        X, X_recon: batches of images; each ``X[i]`` is given a leading
            batch dimension of one before being passed to ``torch_ssim``.
        torch_ssim: callable returning one SSIM value per call; its
            ``cs_map`` attribute is read after each call when
            ``calc_error_maps`` is true.
        calc_error_maps: also collect ``1 - torch_ssim.cs_map`` per image.
        reduction: forwarded to ``__reduce``.

    Returns:
        ``(loss, maps)``, where ``maps`` is the ``np.vstack`` of the
        collected error maps, or None when ``calc_error_maps`` is false.
    """
    cs_error_maps = []
    per_image_errs = []
    for i in range(len(X)):
        err = 1.0 - torch_ssim(X[i].unsqueeze(0), X_recon[i].unsqueeze(0))
        # Each result must hold exactly one element, as with the former
        # ``errs[i] = ...`` assignment.
        per_image_errs.append(err.reshape(()))
        if calc_error_maps:
            cs_error_maps.append(1.0 - to_numpy(torch_ssim.cs_map))

    if per_image_errs:
        # Stacking keeps the autograd graph and follows the inputs' device,
        # instead of forcing CUDA and writing into a preallocated tensor.
        errs = torch.stack(per_image_errs)
    else:
        errs = torch.zeros(0, device=getattr(X, 'device', None))

    loss = __reduce(errs, reduction)
    if calc_error_maps:
        return loss, np.vstack(cs_error_maps)
    return loss, None
示例#13
0
文件: lmutils.py 项目: oxyai/3FabRec
def heatmaps_to_landmarks(hms, target_size):
    """Locate each landmark at the arg-max of its heatmap.

    Args:
        hms: 4-D batch ``(num_images, num_landmarks, H, W)`` (anything
            accepted by ``to_numpy``).
        target_size: size the coordinates are rescaled to. The scale factor
            is ``hms.shape[-1] / target_size``, which assumes square heatmaps.

    Returns:
        Float array ``(num_images, num_landmarks, 2)`` of (x, y) positions
        divided by that scale factor. With 3 or fewer channels all positions
        are zero.
    """
    assert len(hms.shape) == 4
    num_images, num_landmarks = hms.shape[0], hms.shape[1]
    heatmap_size = hms.shape[-1]
    lms = np.zeros((num_images, num_landmarks, 2), dtype=int)
    if num_landmarks > 3:
        hms_np = to_numpy(hms)
        # Arg-max over each flattened heatmap. Ties resolve to the first
        # maximum, exactly like np.argmax(hm, axis=None) per heatmap.
        flat_idx = hms_np.reshape(num_images, num_landmarks, -1).argmax(axis=2)
        rows, cols = np.unravel_index(flat_idx, hms_np.shape[-2:])
        lms[..., 0] = cols
        lms[..., 1] = rows
    lm_scale = heatmap_size / target_size
    return lms / lm_scale
示例#14
0
def visualize_images(X,
                     X_lm_hm,
                     landmarks=None,
                     show_recon=True,
                     show_landmarks=True,
                     show_heatmaps=False,
                     draw_wireframe=False,
                     smoothing_level=2,
                     heatmap_opacity=0.8,
                     f=1):
    """Render images with optional landmark heatmaps and landmarks.

    Args:
        X: image batch, converted with ``vis.to_disp_images``.
        X_lm_hm: landmark heatmaps, or None.
        landmarks: landmark coordinates to draw, or None.
        show_recon: draw ``X``. If False, an all-zero background is used and
            heatmaps are overlaid at opacity 1.
        show_landmarks: draw ``landmarks`` (when given) in colour (0, 255, 255).
        show_heatmaps: overlay the heatmaps (ignored when ``X_lm_hm`` is None).
        draw_wireframe: forwarded to ``vis.add_landmarks_to_images``.
        smoothing_level: number of ``smooth_heatmaps`` passes, at most 2.
        heatmap_opacity: opacity for ``vis.overlay_heatmap``.
        f: scale factor applied to the heatmaps before overlaying.

    Returns:
        The display images.
    """
    if show_recon:
        disp_X = vis.to_disp_images(X, denorm=True)
    else:
        disp_X = vis.to_disp_images(torch.zeros_like(X), denorm=False)
        heatmap_opacity = 1

    if X_lm_hm is not None:
        if smoothing_level > 0:
            X_lm_hm = smooth_heatmaps(X_lm_hm)
        if smoothing_level > 1:
            X_lm_hm = smooth_heatmaps(X_lm_hm)

    # Overlaying requires heatmaps; previously None crashed in to_numpy.
    if show_heatmaps and X_lm_hm is not None:
        pred_heatmaps = to_single_channel_heatmap(to_numpy(X_lm_hm))
        pred_heatmaps = [
            cv2.resize(im, None, fx=f, fy=f, interpolation=cv2.INTER_CUBIC)
            for im in pred_heatmaps
        ]
        disp_X = [
            vis.overlay_heatmap(disp_img, hm, heatmap_opacity)
            for disp_img, hm in zip(disp_X, pred_heatmaps)
        ]

    if show_landmarks and landmarks is not None:
        pred_color = (0, 255, 255)
        disp_X = vis.add_landmarks_to_images(disp_X,
                                             landmarks,
                                             color=pred_color,
                                             draw_wireframe=draw_wireframe)

    return disp_X
示例#15
0
def draw_results(X_resized,
                 X_recon,
                 levels_z=None,
                 landmarks=None,
                 landmarks_pred=None,
                 cs_errs=None,
                 ncols=15,
                 fx=0.5,
                 fy=0.5,
                 additional_status_text=''):
    """Compose inputs, reconstructions and a status bar into one image.

    The result stacks, from top to bottom:
    - a grid of the input images, resized by ``fx``/``fy``;
    - a grid of the reconstructions, resized the same way;
    - optional latent visualisations (disabled via ``VIS_HIDDEN``);
    - a 30-pixel status bar with L1, SSIM-based and landmark errors.

    ``clean_images`` is hard-wired to True. This discards ``landmarks`` and
    skips the per-image error labels, so only ``landmarks_pred`` is drawn.

    Args:
        X_resized: input images as a tensor (``.abs()``, ``.reshape`` and
            ``.mean(dim=...)`` are used on ``X_resized - X_recon``).
        X_recon: reconstructions matching ``X_resized``.
        levels_z: latent codes, only used when ``VIS_HIDDEN`` is enabled.
        landmarks: ground-truth landmarks (currently discarded, see above).
        landmarks_pred: predicted landmarks, drawn in colour (1, 0, 0).
        cs_errs: optional per-image structural errors.
        ncols: maximum number of images per grid row.
        fx, fy: resize factors for the image grids.
        additional_status_text: appended to the status bar text.

    Returns:
        A single image produced by ``np.vstack`` of the rows.
    """
    clean_images = True
    if clean_images:
        landmarks = None

    nimgs = len(X_resized)
    ncols = min(ncols, nimgs)
    img_size = X_recon.shape[-1]

    disp_X = vis.to_disp_images(X_resized, denorm=True)
    disp_X_recon = vis.to_disp_images(X_recon, denorm=True)

    # mean absolute reconstruction error per image, in 0..255 units
    l1_dists = 255.0 * to_numpy(
        (X_resized - X_recon).abs().reshape(len(disp_X), -1).mean(dim=1))

    # SSIM errors. data_range must match the display dtype (uint8 spans
    # 0..255), consistent with visualize_batch.
    data_range = 255.0 if nimgs > 0 and disp_X[0].dtype == np.uint8 else 1.0
    ssim = np.zeros(nimgs)
    for i in range(nimgs):
        ssim[i] = compare_ssim(disp_X[i],
                               disp_X_recon[i],
                               data_range=data_range,
                               multichannel=True)

    landmarks = to_numpy(landmarks)
    cs_errs = to_numpy(cs_errs)

    text_size = img_size / 256
    text_thickness = 2

    #
    # Visualise resized input images and reconstructed images
    #
    if landmarks is not None:
        disp_X = vis.add_landmarks_to_images(
            disp_X,
            landmarks,
            draw_wireframe=False,
            landmarks_to_draw=lmcfg.LANDMARKS_19)
        disp_X_recon = vis.add_landmarks_to_images(
            disp_X_recon,
            landmarks,
            draw_wireframe=False,
            landmarks_to_draw=lmcfg.LANDMARKS_19)

    if landmarks_pred is not None:
        disp_X = vis.add_landmarks_to_images(disp_X,
                                             landmarks_pred,
                                             color=(1, 0, 0))
        disp_X_recon = vis.add_landmarks_to_images(disp_X_recon,
                                                   landmarks_pred,
                                                   color=(1, 0, 0))

    if not clean_images:
        disp_X_recon = vis.add_error_to_images(disp_X_recon,
                                               l1_dists,
                                               format_string='{:.1f}',
                                               size=text_size,
                                               thickness=text_thickness)
        disp_X_recon = vis.add_error_to_images(disp_X_recon,
                                               1 - ssim,
                                               loc='bl-1',
                                               format_string='{:>4.2f}',
                                               vmax=0.8,
                                               vmin=0.2,
                                               size=text_size,
                                               thickness=text_thickness)
        if cs_errs is not None:
            # Same colour range as the cs error label in visualize_batch.
            disp_X_recon = vis.add_error_to_images(disp_X_recon,
                                                   cs_errs,
                                                   loc='bl-2',
                                                   format_string='{:>4.2f}',
                                                   vmin=0.0,
                                                   vmax=0.4,
                                                   size=text_size,
                                                   thickness=text_thickness)

    # landmark errors (best effort: they only feed the status display)
    lm_errs = np.zeros(1)
    if landmarks is not None and landmarks_pred is not None:
        try:
            from landmarks import lmutils
            lm_errs = lmutils.calc_landmark_nme_per_img(
                landmarks, landmarks_pred)
            disp_X_recon = vis.add_error_to_images(disp_X_recon,
                                                   lm_errs,
                                                   loc='br',
                                                   format_string='{:>5.2f}',
                                                   vmax=15,
                                                   size=text_size,
                                                   thickness=text_thickness)
        except Exception as err:
            print(f'draw_results: landmark errors unavailable: {err}')

    img_input = vis.make_grid(disp_X, nCols=ncols, normalize=False)
    img_recon = vis.make_grid(disp_X_recon, nCols=ncols, normalize=False)
    img_input = cv2.resize(img_input,
                           None,
                           fx=fx,
                           fy=fy,
                           interpolation=cv2.INTER_CUBIC)
    img_recon = cv2.resize(img_recon,
                           None,
                           fx=fx,
                           fy=fy,
                           interpolation=cv2.INTER_CUBIC)

    img_stack = [img_input, img_recon]

    #
    # Visualise hidden layers
    #
    VIS_HIDDEN = False
    if VIS_HIDDEN:
        img_z = vis.draw_z_vecs(levels_z, size=(img_size, 30), ncols=ncols)
        img_z = cv2.resize(img_z,
                           dsize=(img_input.shape[1], img_z.shape[0]),
                           interpolation=cv2.INTER_NEAREST)
        img_stack.append(img_z)

    cs_errs_mean = np.mean(cs_errs) if cs_errs is not None else np.nan
    # NOTE(review): the "ssim" field shows mean cs_errs first and the mean
    # of (1 - SSIM) in parentheses; confirm this label order is intended.
    status_bar_text = ("l1 recon err: {:.2f}px  "
                       "ssim: {:.3f}({:.3f})  "
                       "lms err: {:2} {}").format(l1_dists.mean(),
                                                  cs_errs_mean,
                                                  1 - ssim.mean(),
                                                  lm_errs.mean(),
                                                  additional_status_text)

    img_status_bar = vis.draw_status_bar(status_bar_text,
                                         status_bar_width=img_input.shape[1],
                                         status_bar_height=30,
                                         dtype=img_input.dtype)
    img_stack.append(img_status_bar)

    return np.vstack(img_stack)
示例#16
0
    def visualize_batch(self,
                        batch,
                        X_recon,
                        ssim_maps,
                        nimgs=8,
                        ds=None,
                        wait=0):
        """Show targets and reconstructions with error annotations.

        Builds two rows (targets / reconstructions) and shows them in an
        OpenCV window titled ``'recon errors '`` plus the class name of
        ``ds``. The annotations are:
        - GAN scores, if ``self.args.with_gan``;
        - the per-image L1 error (0..255 scale);
        - ``1 - SSIM``;
        - the per-image mean of ``ssim_maps``.
        When ``ssim_maps`` is given, a second window shows them colour-mapped.

        The D, Q and P networks are put in eval mode while drawing. Their
        previous ``training`` flags are restored afterwards, even if drawing
        raises.

        Args:
            batch: object with ``images`` and ``target_images`` (None falls
                back to ``images``).
            X_recon: reconstructions of ``batch.images``.
            ssim_maps: optional SSIM error maps, or None.
            nimgs: maximum number of images to show.
            ds: optional dataset, used only for the window title.
            wait: passed to ``cv2.waitKey``.
        """
        nimgs = min(nimgs, len(batch))
        train_state_D = self.saae.D.training
        train_state_Q = self.saae.Q.training
        train_state_P = self.saae.P.training
        self.saae.D.eval()
        self.saae.Q.eval()
        self.saae.P.eval()

        try:
            loc_err_gan = 'tr'
            text_size_errors = 0.65

            input_images = vis.to_disp_images(batch.images[:nimgs], denorm=True)
            target_images = batch.target_images if batch.target_images is not None else batch.images
            disp_images = vis.to_disp_images(target_images[:nimgs], denorm=True)

            # draw GAN score
            if self.args.with_gan:
                with torch.no_grad():
                    err_gan_inputs = self.saae.D(batch.images[:nimgs])
                disp_images = vis.add_error_to_images(disp_images,
                                                      errors=1 - err_gan_inputs,
                                                      loc=loc_err_gan,
                                                      format_string='{:>5.2f}',
                                                      vmax=1.0)

            rows = [vis.make_grid(disp_images, nCols=nimgs, normalize=False)]

            recon_images = vis.to_disp_images(X_recon[:nimgs], denorm=True)
            disp_X_recon = recon_images.copy()

            print_stats = True
            if print_stats:
                # Mean absolute error per displayed image, 0..255 scale. Only
                # the shown images are needed, so the rest of the batch is
                # skipped.
                abs_diff = torch.abs(batch.images[:nimgs] - X_recon[:nimgs])
                X_recon_errs = 255.0 * abs_diff.reshape(len(abs_diff), -1).mean(dim=1)
                disp_X_recon = vis.add_error_to_images(disp_X_recon[:nimgs],
                                                       errors=X_recon_errs,
                                                       size=text_size_errors,
                                                       format_string='{:>4.1f}')
                if self.args.with_gan:
                    with torch.no_grad():
                        err_gan = self.saae.D(X_recon[:nimgs])
                    disp_X_recon = vis.add_error_to_images(
                        disp_X_recon,
                        errors=1 - err_gan,
                        loc=loc_err_gan,
                        format_string='{:>5.2f}',
                        vmax=1.0)

                # uint8 display images span 0..255, float ones 0..1
                data_range = 255.0 if nimgs > 0 and input_images[0].dtype == np.uint8 else 1.0
                ssim = np.zeros(nimgs)
                for i in range(nimgs):
                    ssim[i] = compare_ssim(input_images[i],
                                           recon_images[i],
                                           data_range=data_range,
                                           multichannel=True)
                disp_X_recon = vis.add_error_to_images(disp_X_recon,
                                                       1 - ssim,
                                                       loc='bl-1',
                                                       size=text_size_errors,
                                                       format_string='{:>4.2f}',
                                                       vmin=0.2,
                                                       vmax=0.8)

                if ssim_maps is not None:
                    disp_X_recon = vis.add_error_to_images(
                        disp_X_recon,
                        ssim_maps.reshape(len(ssim_maps), -1).mean(axis=1),
                        size=text_size_errors,
                        loc='bl-2',
                        format_string='{:>4.2f}',
                        vmin=0.0,
                        vmax=0.4)

            rows.append(vis.make_grid(disp_X_recon, nCols=nimgs))

            if ssim_maps is not None:
                disp_ssim_maps = to_numpy(
                    nn.denormalized(ssim_maps)[:nimgs].transpose(0, 2, 3, 1))
                for i in range(len(disp_ssim_maps)):
                    disp_ssim_maps[i] = vis.color_map(
                        disp_ssim_maps[i].mean(axis=2), vmin=0.0, vmax=2.0)
                grid_ssim_maps = vis.make_grid(disp_ssim_maps, nCols=nimgs)
                cv2.imshow('ssim errors',
                           cv2.cvtColor(grid_ssim_maps, cv2.COLOR_RGB2BGR))
        finally:
            # Restore the networks' previous train/eval modes.
            self.saae.D.train(train_state_D)
            self.saae.Q.train(train_state_Q)
            self.saae.P.train(train_state_P)

        f = 1
        disp_rows = vis.make_grid(rows, nCols=1, normalize=False, fx=f, fy=f)
        wnd_title = 'recon errors '
        if ds is not None:
            wnd_title += ds.__class__.__name__
        cv2.imshow(wnd_title, cv2.cvtColor(disp_rows, cv2.COLOR_RGB2BGR))
        cv2.waitKey(wait)
示例#17
0
 def z_vecs(self):
     """Return ``self.z``, converted with ``to_numpy``, as a one-element list."""
     z_np = to_numpy(self.z)
     return [z_np]
示例#18
0
 def detect_landmarks(self, X):
     """Reconstruct ``X`` and predict landmarks from the landmark head.

     Runs ``self.forward(X)``, passes ``self.P`` through ``self.LMH``,
     smooths the resulting heatmaps and converts them to coordinates.

     Returns:
         ``(X_recon, lm_preds, X_lm_hm)``: the reconstruction, landmark
         positions converted with ``to_numpy``, and the smoothed heatmaps.
     """
     recon = self.forward(X)
     # self.P is read after forward(); presumably forward() sets it - TODO confirm
     heatmaps = landmarks.lmutils.smooth_heatmaps(self.LMH(self.P))
     predicted = to_numpy(self.heatmaps_to_landmarks(heatmaps))
     return recon, predicted, heatmaps
示例#19
0
def calculate_metrics(preds,
                      gt_vessels,
                      fov_masks=None,
                      full_eval=False,
                      verbose=False):
    """Compute pixel-wise vessel segmentation metrics.

    Args:
        preds: predicted vessel scores (a single map or a batch).
        gt_vessels: ground-truth maps; pixels with value >= 1 count as vessel.
        fov_masks: optional field-of-view masks. When given, only the values
            returned by ``pixel_values_in_mask`` are evaluated.
        full_eval: also compute best F1, ROC AUC, and metrics at the Otsu
            threshold and at the fixed threshold 0.5.
        verbose: print a summary (only when ``full_eval`` is set).

    Returns:
        dict with ``'PR'`` (area under the precision-recall curve). If
        ``full_eval`` is set, it also holds ``'F1'``, ``'F1_th'``, ``'ROC'``,
        ``'otsu_th'``, ``'otsu_SE'``, ``'otsu_SP'``, ``'otsu_ACC'``,
        ``'otsu_F1'``, ``'th_SE'``, ``'th_SP'``, ``'th_ACC'`` and ``'th_F1'``.
    """
    assert len(preds) == len(gt_vessels)

    if not isinstance(preds, np.ndarray):
        preds = to_numpy(preds)

    if not isinstance(gt_vessels, np.ndarray):
        gt_vessels = to_numpy(gt_vessels)

    assert isinstance(gt_vessels, np.ndarray)

    if len(gt_vessels.shape) == 2:
        gt_vessels, preds = gt_vessels[np.newaxis], preds[np.newaxis]

    if fov_masks is not None:
        if not isinstance(fov_masks, np.ndarray):
            fov_masks = np.array(fov_masks)
        if len(fov_masks.shape) == 2:
            fov_masks = fov_masks[np.newaxis]
        gt_vessels_in_mask, pred_vessels_in_mask = pixel_values_in_mask(
            gt_vessels, preds, fov_masks)
    else:
        gt_vessels_in_mask, pred_vessels_in_mask = gt_vessels, preds

    # Flat, aligned label/score vectors used for every metric below.
    y_true = to_numpy(gt_vessels_in_mask).ravel() >= 1
    y_score = to_numpy(pred_vessels_in_mask).ravel()

    precision, recall, thresholds = precision_recall_curve(y_true, y_score)

    # Reverse so recall is increasing (otherwise the area comes out negative).
    precision = precision[::-1]
    recall = recall[::-1]
    # NOTE(review): thresholds has one element fewer than precision/recall;
    # after reversal thresholds[j] pairs with precision[j + 1]. Confirm
    # best_f1_threshold accounts for this offset.
    thresholds = thresholds[::-1]

    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid.
    trapezoid = getattr(np, 'trapezoid', None) or np.trapz
    average_precision = trapezoid(precision, recall)

    results = {'PR': average_precision}

    if full_eval:
        best_f1, best_f1_th = best_f1_threshold(precision, recall, thresholds)
        results['F1'] = best_f1
        results['F1_th'] = best_f1_th

        fpr, tpr, _ = roc_curve(y_true, y_score)
        roc = auc(fpr, tpr)
        results['ROC'] = roc

        # Binarise the flattened scores so predictions stay element-aligned
        # with y_true (the unflattened maps do not match its shape when no
        # FOV mask is used).
        otsu_threshold = filters.threshold_otsu(y_score)
        y_pred_bin = y_score >= otsu_threshold
        acc, se, sp, f1 = misc_measures_evaluation(y_true, y_pred_bin)
        results['otsu_th'] = otsu_threshold
        results['otsu_SE'] = se
        results['otsu_SP'] = sp
        results['otsu_ACC'] = acc
        results['otsu_F1'] = f1

        fixed_threshold = 0.5
        y_pred_bin = y_score >= fixed_threshold
        acc, se, sp, f1 = misc_measures_evaluation(y_true, y_pred_bin)
        results['th_SE'] = se
        results['th_SP'] = sp
        results['th_ACC'] = acc
        results['th_F1'] = f1

        if verbose:
            print(f"F1 score : {best_f1:.4f} (th={best_f1_th:.3f})")
            print(f"F1 score : {f1:.4f} (th={fixed_threshold:.3f})")
            print(
                f"SE/SP/ACC: {se:.4f}, {sp:.4f}, {acc:.4f} (th={fixed_threshold:.3f})"
            )
            print('AUC PR: {0:0.4f}'.format(average_precision))
            print('AUC ROC: {0:0.4f}'.format(roc))

    return results
示例#20
0
def visualize_batch_CVPR(images,
                         landmarks,
                         X_recon,
                         X_lm_hm,
                         lm_preds,
                         show_recon=True,
                         lm_heatmaps=None,
                         ds=None,
                         wait=0,
                         horizontal=False,
                         f=1.0,
                         radius=2,
                         draw_wireframes=False):
    """Show inputs, reconstructions and landmark predictions in an OpenCV window.

    Rows (top to bottom, or side by side when ``horizontal``):
      1. input images
      2. reconstructions (an all-ones image when ``show_recon`` is False)
      3. reconstructions overlaid with predicted heatmaps (only if
         ``X_lm_hm`` is given)
      4. reconstructions with predicted landmarks
      5. reconstructions with ground-truth landmarks (only if ``landmarks``
         is given)
      6. input images with predicted landmarks

    Args:
        images: batch of input images; at most the first 10 are shown.
        landmarks: ground-truth landmark coordinates, or None.
        X_recon: reconstructed images, in the same order as ``images``.
        X_lm_hm: predicted landmark heatmaps, or None.
        lm_preds: predicted landmark coordinates.
        show_recon: draw the reconstructions; otherwise draw an all-ones
            image built with ``torch.ones_like``.
        lm_heatmaps: ground-truth heatmaps passed to
            ``show_landmark_heatmaps``, or None.
        ds: dataset; its class name is appended to the window title.
        wait: delay passed to ``cv2.waitKey``.
        horizontal: lay the rows out side by side (single image only).
        f: display scale factor applied to images and landmark coordinates.
        radius: landmark dot radius.
        draw_wireframes: connect the predicted landmarks on the last row.
    """
    gt_color = (0, 255, 0)
    pred_color = (0, 255, 255)

    nimgs = min(10, len(images))
    images = nn.atleast4d(images)[:nimgs]

    def resize(im):
        # Nearest-neighbour scaling avoids blending neighbouring pixels.
        return cv2.resize(im, None, fx=f, fy=f, interpolation=cv2.INTER_NEAREST)

    # show landmark heatmaps
    pred_heatmaps = None
    if X_lm_hm is not None:
        pred_heatmaps = [
            resize(im)
            for im in to_single_channel_heatmap(to_numpy(X_lm_hm[:nimgs]))
        ]
        gt_heatmaps = None
        if lm_heatmaps is not None:
            gt_heatmaps = np.array([
                resize(im) for im in to_single_channel_heatmap(
                    to_numpy(lm_heatmaps[:nimgs]))
            ])
        show_landmark_heatmaps(pred_heatmaps, gt_heatmaps, nimgs, f=1)

    # scale landmarks to match the resized display images
    lm_preds = lm_preds[:nimgs] * f

    rows = []

    # 1. input images
    disp_images = [
        resize(im) for im in vis.to_disp_images(images, denorm=True)
    ]
    rows.append(vis.make_grid(disp_images, nCols=nimgs, normalize=False))

    # 2. reconstructions (or an all-ones placeholder)
    if show_recon:
        recon_images = vis.to_disp_images(X_recon[:nimgs], denorm=True)
    else:
        recon_images = vis.to_disp_images(torch.ones_like(X_recon[:nimgs]),
                                          denorm=False)

    def fresh_recons():
        # Every row draws on its own resized copy of the reconstructions.
        return [resize(im) for im in recon_images.copy()]

    rows.append(vis.make_grid(fresh_recons(), nCols=nimgs))

    # 3. reconstructions overlaid with the predicted heatmaps
    if pred_heatmaps is not None:
        heatmap_opacity = 1.0
        disp_X_recon_hm = [
            vis.overlay_heatmap(im, hm, heatmap_opacity)
            for im, hm in zip(fresh_recons(), pred_heatmaps)
        ]
        rows.append(vis.make_grid(disp_X_recon_hm, nCols=nimgs))

    # 4. reconstructions with predicted landmarks
    disp_X_recon_pred = vis.add_landmarks_to_images(fresh_recons(),
                                                    lm_preds,
                                                    color=pred_color,
                                                    radius=radius)
    rows.append(vis.make_grid(disp_X_recon_pred, nCols=nimgs))

    # 5. reconstructions with ground truth (if available)
    if landmarks is not None:
        lm_gt = nn.atleast3d(to_numpy(landmarks))[:nimgs] * f
        disp_X_recon_gt = vis.add_landmarks_to_images(fresh_recons(),
                                                      lm_gt,
                                                      color=gt_color,
                                                      radius=radius)
        rows.append(vis.make_grid(disp_X_recon_gt, nCols=nimgs))

    # 6. input images with predicted landmarks
    disp_images_pred = [
        resize(im) for im in vis.to_disp_images(images, denorm=True)
    ]
    disp_images_pred = vis.add_landmarks_to_images(
        disp_images_pred,
        lm_preds,
        color=pred_color,
        radius=radius,
        draw_wireframe=draw_wireframes)
    rows.append(vis.make_grid(disp_images_pred, nCols=nimgs))

    if horizontal:
        assert (nimgs == 1)
        disp_rows = vis.make_grid(rows, nCols=len(rows))
    else:
        disp_rows = vis.make_grid(rows, nCols=1)
    wnd_title = 'recon errors '
    if ds is not None:
        wnd_title += ds.__class__.__name__
    cv2.imshow(wnd_title, cv2.cvtColor(disp_rows, cv2.COLOR_RGB2BGR))
    cv2.waitKey(wait)
示例#21
0
        raise ValueError

    net.eval()

    results_probs = []
    gt_masks = []
    fov_masks = []

    results = []
    t_tot = 0

    for idx in range(len(dataset))[:]:

        data = dataset[idx]
        image_id = data['fname']
        full_image = to_numpy(data['image'])
        gt_mask = to_numpy(data['mask']) // 255
        fov_mask = dataset.fov_masks[image_id]

        print(f'\n---- Testing image {idx+1}: {image_id} ---- ')

        t = time.perf_counter()
        recon, probs = predict_vessels.segment_image(
            net, full_image, patch_size=args.patch_size,
            scales=scales[dsname])  #, gpu=args.gpu)

        # probs_lr = np.fliplr(predict_vessels.segment_image(net, np.fliplr(full_image), scales=scales[dsname], patch_size=args.patch_size)[1])
        # probs_ud = np.flipud(predict_vessels.segment_image(net, np.flipud(full_image), scales=scales[dsname], patch_size=args.patch_size)[1])
        # probs = (probs + probs_lr + probs_ud) / 3

        t_image = time.perf_counter() - t
示例#22
0
    torch.manual_seed(0)
    torch.cuda.manual_seed_all(0)

    dirs = config.get_dataset_paths('affectnet')
    train = True
    ds = AffectNet(root=dirs[0],
                   image_size=256,
                   cache_root=dirs[1],
                   train=train,
                   use_cache=False,
                   transform=ds_utils.build_transform(deterministic=not train,
                                                      daug=0),
                   crop_source='lm_ground_truth')
    dl = td.DataLoader(ds, batch_size=10, shuffle=False, num_workers=0)
    # print(ds)

    for data in dl:
        batch = Batch(data, gpu=False)

        gt = to_numpy(batch.landmarks)
        ocular_dists_inner = np.sqrt(np.sum((gt[:, 42] - gt[:, 39])**2,
                                            axis=1))
        ocular_dists_outer = np.sqrt(np.sum((gt[:, 45] - gt[:, 36])**2,
                                            axis=1))
        ocular_dists = np.vstack(
            (ocular_dists_inner, ocular_dists_outer)).mean(axis=0)
        print(ocular_dists)

        images = vis.to_disp_images(batch.images, denorm=True)
        imgs = vis.add_landmarks_to_images(images, batch.landmarks.numpy())
        vis.vis_square(imgs, nCols=10, fx=1.0, fy=1.0, normalize=False)
示例#23
0
def visualize_batch(images,
                    landmarks,
                    X_recon,
                    X_lm_hm,
                    lm_preds_max,
                    lm_heatmaps=None,
                    target_images=None,
                    lm_preds_cnn=None,
                    ds=None,
                    wait=0,
                    ssim_maps=None,
                    landmarks_to_draw=None,
                    ocular_norm='outer',
                    horizontal=False,
                    f=1.0,
                    overlay_heatmaps_input=False,
                    overlay_heatmaps_recon=False,
                    clean=False,
                    landmarks_only_outline=range(17),
                    landmarks_no_outline=range(17, 68)):
    """Show inputs and reconstructions annotated with landmarks and error metrics.

    Row 1 shows the inputs (or ``target_images``) with ground-truth and
    predicted landmarks plus prediction-to-ground-truth offset lines. Row 2
    shows the reconstructions with landmarks and, unless ``clean`` is set,
    the reconstruction error, three per-image landmark NMEs (no-outline,
    outline only, all) and 1 - SSIM. If ``ssim_maps`` is given, the maps are
    also shown in a separate 'ssim errors' window.

    Args:
        images: batch of input images (torch tensor); at most the first 10
            are shown. ``images.shape[3]`` must be 128 or 256. A single
            image without a batch dimension is accepted.
        landmarks: ground-truth landmarks, or None. If None, zeros are
            drawn and no NME is computed.
        X_recon: reconstructed images.
        X_lm_hm: predicted landmark heatmaps, or None.
        lm_preds_max: predicted landmark coordinates; ``shape[1]`` is the
            number of landmarks.
        lm_heatmaps: ground-truth heatmaps passed to
            ``show_landmark_heatmaps``, or None.
        target_images: if given, drawn in row 1 instead of ``images``.
        lm_preds_cnn: accepted for API compatibility; not drawn.
        ds: dataset; its class name is appended to the window title.
        wait: delay passed to ``cv2.waitKey``.
        ssim_maps: optional SSIM error maps.
        landmarks_to_draw: ids of landmarks to draw (default: all).
        ocular_norm: normalisation passed to the NME helpers.
        horizontal: place the two rows side by side (single image only).
        f: display scale factor for images and landmark coordinates.
        overlay_heatmaps_input, overlay_heatmaps_recon: overlay the predicted
            heatmaps on row 1 / row 2.
        clean: suppress all numeric read-outs; they are then not computed.
        landmarks_only_outline, landmarks_no_outline: landmark id groups
            used for the per-image NME read-outs.
    """
    gt_color = (0, 255, 0)
    pred_color = (0, 0, 255)

    # Add the batch dimension before reading sizes so a single image works.
    images = nn.atleast4d(images)
    nimgs = min(10, len(images))
    images = images[:nimgs]
    image_size = images.shape[3]
    assert image_size in [128, 256]

    num_landmarks = lm_preds_max.shape[1]
    if landmarks_to_draw is None:
        landmarks_to_draw = range(num_landmarks)

    def resize(im):
        # Nearest-neighbour scaling avoids blending neighbouring pixels.
        return cv2.resize(im, None, fx=f, fy=f, interpolation=cv2.INTER_NEAREST)

    nme_per_lm = None
    if landmarks is None:
        lm_gt = np.zeros((nimgs, num_landmarks, 2))
    else:
        lm_gt = nn.atleast3d(to_numpy(landmarks))[:nimgs]
        nme_per_lm = calc_landmark_nme(lm_gt,
                                       lm_preds_max[:nimgs],
                                       ocular_norm=ocular_norm,
                                       image_size=image_size)

    # show landmark heatmaps
    pred_heatmaps = None
    if X_lm_hm is not None:
        pred_heatmaps = [
            resize(im)
            for im in to_single_channel_heatmap(to_numpy(X_lm_hm[:nimgs]))
        ]
        gt_heatmaps = None
        if lm_heatmaps is not None:
            gt_heatmaps = np.array([
                resize(im) for im in to_single_channel_heatmap(
                    to_numpy(lm_heatmaps[:nimgs]))
            ])
        show_landmark_heatmaps(pred_heatmaps, gt_heatmaps, nimgs, f=1)

    # Scale landmarks to the resized display images. Not done in place:
    # lm_gt may be a view of the caller's `landmarks` (slicing / to_numpy
    # can share memory), and numpy refuses `int_array *= float`.
    lm_preds_max = lm_preds_max[:nimgs] * f
    lm_gt = lm_gt * f

    display_source = target_images if target_images is not None else images
    disp_images = [
        resize(im)
        for im in vis.to_disp_images(display_source[:nimgs], denorm=True)
    ]
    recon_images = vis.to_disp_images(X_recon[:nimgs], denorm=True)
    disp_X_recon = [resize(im) for im in recon_images.copy()]

    # overlay heatmaps on input images / reconstructions
    if pred_heatmaps is not None and overlay_heatmaps_input:
        disp_images = [
            vis.overlay_heatmap(im, hm)
            for im, hm in zip(disp_images, pred_heatmaps)
        ]
    if pred_heatmaps is not None and overlay_heatmaps_recon:
        disp_X_recon = [
            vis.overlay_heatmap(im, hm)
            for im, hm in zip(disp_X_recon, pred_heatmaps)
        ]

    #
    # Row 1: input images with ground truth, predictions and offsets
    #
    disp_images = vis.add_landmarks_to_images(disp_images,
                                              lm_gt[:nimgs],
                                              color=gt_color,
                                              landmarks_to_draw=landmarks_to_draw)
    disp_images = vis.add_landmarks_to_images(disp_images,
                                              lm_preds_max[:nimgs],
                                              lm_errs=nme_per_lm,
                                              color=pred_color,
                                              draw_wireframe=False,
                                              gt_landmarks=lm_gt,
                                              draw_gt_offsets=True,
                                              landmarks_to_draw=landmarks_to_draw)

    #
    # Row 2: reconstructions
    #
    if not clean:
        # mean absolute reconstruction error, on a 0..255 scale
        X_recon_errs = 255.0 * torch.abs(images - X_recon[:nimgs]).reshape(
            len(images), -1).mean(dim=1)
        disp_X_recon = vis.add_error_to_images(disp_X_recon[:nimgs],
                                               errors=X_recon_errs,
                                               format_string='{:>4.1f}')

        # per-image landmark NME for three landmark groups
        nme_groups = (
            (landmarks_no_outline, 'br-2'),
            (landmarks_only_outline, 'br-1'),
            (list(landmarks_only_outline) + list(landmarks_no_outline), 'br'),
        )
        for lm_ids, loc in nme_groups:
            lm_errs = calc_landmark_nme_per_img(lm_gt,
                                                lm_preds_max,
                                                ocular_norm,
                                                lm_ids,
                                                image_size=image_size)
            disp_X_recon = vis.add_error_to_images(disp_X_recon,
                                                   lm_errs,
                                                   loc=loc,
                                                   format_string='{:>5.2f}',
                                                   vmax=15)

    disp_X_recon = vis.add_landmarks_to_images(disp_X_recon,
                                               lm_gt,
                                               color=gt_color,
                                               draw_wireframe=True,
                                               landmarks_to_draw=landmarks_to_draw)
    disp_X_recon = vis.add_landmarks_to_images(disp_X_recon,
                                               lm_preds_max[:nimgs],
                                               color=pred_color,
                                               draw_wireframe=True,
                                               gt_landmarks=lm_gt,
                                               draw_gt_offsets=True,
                                               lm_errs=nme_per_lm,
                                               draw_dots=True,
                                               radius=2,
                                               landmarks_to_draw=landmarks_to_draw)

    if not clean:
        # 1 - SSIM between each input image and its reconstruction
        input_images = vis.to_disp_images(images, denorm=True)
        ssim = np.zeros(nimgs)
        for i in range(nimgs):
            ssim[i] = compare_ssim(input_images[i],
                                   recon_images[i],
                                   data_range=1.0,
                                   multichannel=True)
        disp_X_recon = vis.add_error_to_images(disp_X_recon,
                                               1 - ssim,
                                               loc='bl-1',
                                               format_string='{:>4.2f}',
                                               vmax=0.8,
                                               vmin=0.2)
        # mean of the provided SSIM error maps
        if ssim_maps is not None:
            disp_X_recon = vis.add_error_to_images(disp_X_recon,
                                                   ssim_maps.reshape(
                                                       len(ssim_maps),
                                                       -1).mean(axis=1),
                                                   loc='bl-2',
                                                   format_string='{:>4.2f}',
                                                   vmin=0.0,
                                                   vmax=0.4)

    rows = [vis.make_grid(disp_images, nCols=nimgs, normalize=False)]
    rows.append(vis.make_grid(disp_X_recon, nCols=nimgs))

    if ssim_maps is not None:
        disp_ssim_maps = to_numpy(
            nn.denormalized(ssim_maps)[:nimgs].transpose(0, 2, 3, 1))
        for i in range(len(disp_ssim_maps)):
            disp_ssim_maps[i] = vis.color_map(disp_ssim_maps[i].mean(axis=2),
                                              vmin=0.0,
                                              vmax=2.0)
        grid_ssim_maps = vis.make_grid(disp_ssim_maps, nCols=nimgs, fx=f, fy=f)
        cv2.imshow('ssim errors',
                   cv2.cvtColor(grid_ssim_maps, cv2.COLOR_RGB2BGR))

    if horizontal:
        assert (nimgs == 1)
        disp_rows = vis.make_grid(rows, nCols=2)
    else:
        disp_rows = vis.make_grid(rows, nCols=1)
    wnd_title = 'Predicted Landmarks '
    if ds is not None:
        wnd_title += ds.__class__.__name__
    cv2.imshow(wnd_title, cv2.cvtColor(disp_rows, cv2.COLOR_RGB2BGR))
    cv2.waitKey(wait)
示例#24
0
def visualize_vessels(images,
                      X_recon,
                      vessel_hm,
                      pred_vessel_hm=None,
                      ds=None,
                      wait=0,
                      horizontal=False,
                      f=1.0,
                      overlay_heatmaps_input=True,
                      overlay_heatmaps_recon=True,
                      scores=None,
                      nimgs=5):
    """Show inputs, vessel heatmaps and reconstructions in an OpenCV window.

    Rows (top to bottom, or side by side when ``horizontal``):
      1. input images
      2. input images overlaid with ``vessel_hm`` (the plain inputs again if
         ``vessel_hm`` is None or ``overlay_heatmaps_input`` is False)
      3. reconstructions overlaid with ``pred_vessel_hm``, annotated with
         ``scores`` if given (only when ``pred_vessel_hm`` is given and
         ``overlay_heatmaps_recon`` is True)
      4. plain reconstructions

    Args:
        images: batch of input images.
        X_recon: reconstructed images.
        vessel_hm: vessel heatmaps, indexed ``[i, 0]`` per image, or None.
        pred_vessel_hm: predicted heatmaps indexed the same way, or None.
        ds: dataset; its class name is appended to the window title.
        wait: delay passed to ``cv2.waitKey``.
        horizontal: lay the rows out side by side (single image only).
        f: display scale factor.
        overlay_heatmaps_input, overlay_heatmaps_recon: enable the overlays.
        scores: per-image values printed on the overlay row.
        nimgs: maximum number of images to show.
    """
    nimgs = min(nimgs, len(images))
    images = images[:nimgs]

    def resize(im):
        # Nearest-neighbour scaling avoids blending neighbouring pixels.
        return cv2.resize(im, None, fx=f, fy=f, interpolation=cv2.INTER_NEAREST)

    rows = []

    disp_images = [
        resize(im) for im in vis.to_disp_images(images[:nimgs], denorm=True)
    ]
    rows.append(vis.make_grid(disp_images, nCols=nimgs, normalize=False))

    recon_images = vis.to_disp_images(X_recon[:nimgs], denorm=True)
    disp_X_recon = [resize(im) for im in recon_images.copy()]

    if vessel_hm is not None and overlay_heatmaps_input:
        vessel_hm = to_numpy(vessel_hm[:nimgs])
        disp_images = [
            vis.overlay_heatmap(disp_images[i], vessel_hm[i, 0], 0.5)
            for i in range(len(vessel_hm))
        ]

    rows.append(vis.make_grid(disp_images, nCols=nimgs, normalize=False))

    if pred_vessel_hm is not None and overlay_heatmaps_recon:
        pred_vessel_hm = to_numpy(pred_vessel_hm[:nimgs])
        disp_X_recon_overlay = [
            vis.overlay_heatmap(disp_X_recon[i], pred_vessel_hm[i, 0], 1.0)
            for i in range(len(pred_vessel_hm))
        ]
        if scores is not None:
            disp_X_recon_overlay = vis.add_error_to_images(
                disp_X_recon_overlay, scores, loc='tr', format_string='{:.3f}')
        rows.append(vis.make_grid(disp_X_recon_overlay, nCols=nimgs))

    rows.append(vis.make_grid(disp_X_recon, nCols=nimgs))

    if horizontal:
        assert (nimgs == 1)
        # 3 or 4 rows depending on whether the prediction overlay was added
        disp_rows = vis.make_grid(rows, nCols=len(rows))
    else:
        disp_rows = vis.make_grid(rows, nCols=1)

    wnd_title = 'Predicted vessels '
    if ds is not None:
        wnd_title += ds.__class__.__name__
    cv2.imshow(wnd_title, cv2.cvtColor(disp_rows, cv2.COLOR_RGB2BGR))
    cv2.waitKey(wait)
示例#25
0
文件: vis.py 项目: walaa5/Face-X
def add_landmarks_to_images(images, landmarks, color=None, radius=2, gt_landmarks=None,
                            lm_errs=None, lm_confs=None, lm_rec_errs=None,
                            draw_dots=True, draw_wireframe=False, draw_gt_offsets=False, landmarks_to_draw=None,
                            offset_line_color=None):
    """Draw landmarks (and optional extras) onto display versions of *images*.

    Args:
        images: images accepted by ``to_disp_images``.
        landmarks: per-image (x, y) landmark coordinates, indexed
            ``landmarks[img_id][lm_id]``.
        color: colour for dots and wireframes. If None, dots are coloured by
            ``lm_errs`` (jet colormap, range 0..1) or drawn white
            (255, 255, 255), and wireframes are drawn white.
        radius: dot radius in pixels.
        gt_landmarks: ground-truth landmarks, used for offset lines.
        lm_errs: per-landmark errors ``lm_errs[img_id, lm_id]``. They colour
            the dots (when ``color`` is None) and the offset lines (jet,
            range 0..15), and turn a confidence circle red when > 10.
        lm_confs: per-landmark confidences. When given, a circle whose
            radius shrinks as confidence grows is drawn around each dot.
        lm_rec_errs: unused; kept for API compatibility.
        draw_dots: draw a filled circle per landmark. Only sets with
            68, 21, 19, 98, 8, 5 or 38 points get dots; other sizes are
            drawn with a fallback palette.
        draw_wireframe: connect 68- or 98-point sets with polylines (counted
            by non-zero points).
        draw_gt_offsets: draw a line from each prediction to its ground truth.
        landmarks_to_draw: optional ids to draw. For dots, the filter only
            applies to 68-point sets; it always applies to offset lines.
        offset_line_color: colour of offset lines when ``lm_errs`` is None,
            in 0..1 (scaled by 255 for uint8 images). Defaults to (1, 1, 1).

    Returns:
        The images produced by ``to_disp_images`` with annotations drawn.
    """
    # (start, end, is_closed) index ranges of the 68-point layout.
    wireframe_68 = [
        (0, 17, False),   # head outline
        (17, 22, False),  # left eyebrow
        (22, 27, False),  # right eyebrow
        (27, 31, False),  # nose vert
        (31, 36, False),  # nose hor
        (36, 42, True),   # left eye
        (42, 48, True),   # right eye
        (48, 60, True),   # outer mouth
        (60, 68, True),   # inner mouth
    ]
    # (start, end, is_closed) index ranges of the 98-point layout.
    # NOTE(review): labels assume the WFLW 98-point ordering -- confirm.
    wireframe_98 = [
        (0, 33, False),   # head outline
        (33, 42, True),   # left eyebrow
        (42, 51, True),   # right eyebrow
        (51, 55, False),  # nose bridge
        (55, 60, False),  # nose bottom
        (60, 68, True),   # left eye
        (68, 76, True),   # right eye
        (76, 88, True),   # outer mouth
        (88, 96, True),   # inner mouth
    ]

    def as_point(p):
        # OpenCV drawing functions expect plain Python ints for coordinates.
        return tuple(int(v) for v in p)

    def draw_wireframe_lines(img, lms, segments):
        pts = lms.reshape((-1, 1, 2)).astype(np.int32)
        line_color = color if color is not None else default_color
        for start, end, closed in segments:
            cv2.polylines(img, [pts[start:end]], isClosed=closed, color=line_color, lineType=cv2.LINE_AA)

    def draw_offset_lines(img, lms, gt_lms, errs):
        if gt_lms.sum() == 0:
            return
        if errs is None:
            base = offset_line_color if offset_line_color is not None else (1, 1, 1)
            colors = np.tile(np.asarray(base, dtype=float), (len(lms), 1))
        else:
            colors = np.asarray(color_map(errs, cmap=plt.cm.jet, vmin=0, vmax=15.0), dtype=float)
        if img.dtype == np.uint8:
            # scale colour values (not list repetition) for 0..255 images
            colors = colors * 255
        for i, (p1, p2) in enumerate(zip(lms, gt_lms)):
            if landmarks_to_draw is not None and i not in landmarks_to_draw:
                continue
            if p1.min() > 0:
                cv2.line(img, as_point(p1.astype(int)), as_point(p2.astype(int)),
                         tuple(float(c) for c in colors[i]), thickness=1, lineType=cv2.LINE_AA)

    new_images = to_disp_images(images)
    landmarks = to_numpy(landmarks)
    gt_landmarks = to_numpy(gt_landmarks)
    lm_errs = to_numpy(lm_errs)
    img_size = new_images[0].shape[0]
    default_color = (255, 255, 255)

    if gt_landmarks is not None and draw_gt_offsets:
        for img_id in range(len(new_images)):
            if gt_landmarks[img_id].sum() == 0:
                continue
            dists = None if lm_errs is None else lm_errs[img_id]
            draw_offset_lines(new_images[img_id], landmarks[img_id], gt_landmarks[img_id], dists)

    for img_id, (disp, lm) in enumerate(zip(new_images, landmarks)):
        if len(lm) in [68, 21, 19, 98, 8, 5, 38]:
            if draw_dots:
                height, width = disp.shape[:2]
                for lm_id in range(len(lm)):
                    if not (landmarks_to_draw is None or lm_id in landmarks_to_draw or len(lm) != 68):
                        continue
                    lm_color = color
                    if lm_color is None:
                        if lm_errs is not None:
                            lm_color = color_map(lm_errs[img_id, lm_id], cmap=plt.cm.jet, vmin=0, vmax=1.0)
                        else:
                            lm_color = default_color
                    # keep the dot inside the image: x by width, y by height
                    center = lm[lm_id].astype(int).clip(0, [width - 1, height - 1])
                    cv2.circle(disp, as_point(center), radius=radius, color=lm_color, thickness=-1, lineType=cv2.LINE_AA)
                    if lm_confs is not None:
                        max_radius = img_size * 0.05
                        try:
                            conf_radius = max(2, int((1 - lm_confs[img_id, lm_id]) * max_radius))
                        except ValueError:
                            # e.g. a NaN confidence cannot be converted to int
                            conf_radius = 2
                        circle_color = (0, 0, 255)
                        if lm_errs is not None and lm_errs[img_id, lm_id] > 10.0:
                            circle_color = (255, 0, 0)
                        cv2.circle(disp, as_point(lm[lm_id].astype(int)), conf_radius, circle_color, 1, lineType=cv2.LINE_AA)

            # Draw outline if we actually have 68 (or 98) valid landmarks.
            # Landmarks can be zeros for UMD landmark format (21 points).
            if draw_wireframe:
                nlms = np.count_nonzero(lm.sum(axis=1))
                if nlms == 68:
                    draw_wireframe_lines(disp, lm, wireframe_68)
                elif nlms == 98:
                    draw_wireframe_lines(disp, lm, wireframe_98)
        else:
            # NOTE(review): palette values are in 0..1, so on uint8 images these
            # render almost black -- confirm intended.
            colors = sns.color_palette("Set1", n_colors=14)
            for i in range(len(lm)):
                cv2.circle(disp, as_point(lm[i].astype(int)), radius=radius,
                           color=colors[i % len(colors)], thickness=2, lineType=cv2.LINE_AA)
    return new_images
示例#26
0
文件: lmutils.py 项目: oxyai/3FabRec
def get_landmark_confs(X_lm_hm):
    """Return a confidence score in [0, 1] for every landmark heatmap.

    Each heatmap is flattened and its peak (maximum) activation is used as
    the confidence for that landmark; the peaks are then clipped to [0, 1].

    Args:
        X_lm_hm: Batch of landmark heatmaps (tensor or array accepted by
            ``to_numpy``). Indexed as (batch, landmark, ...); every axis
            after the second is flattened before taking the maximum.
            Presumably (batch, num_landmarks, H, W) — TODO confirm.

    Returns:
        numpy array of shape (X_lm_hm.shape[0], X_lm_hm.shape[1]) holding
        the clipped per-heatmap maxima.
    """
    heatmaps = to_numpy(X_lm_hm)
    num_images = X_lm_hm.shape[0]
    num_maps = X_lm_hm.shape[1]
    # Collapse all spatial axes into one so a single max() finds each peak.
    flattened = heatmaps.reshape(num_images, num_maps, -1)
    peaks = flattened.max(axis=2)
    return np.clip(peaks, a_min=0, a_max=1)
示例#27
0
文件: lmutils.py 项目: oxyai/3FabRec
 def reformat(lms):
     """Return landmarks as a numpy array with a leading batch axis.

     A 2-D input is reshaped to ``(1, -1, 2)``, i.e. treated as a single
     set of (x, y) landmarks and wrapped into a batch of one. Inputs of
     any other rank are returned unchanged after conversion to numpy.

     Args:
         lms: Landmarks (tensor or array accepted by ``to_numpy``).
             Presumably (num_landmarks, 2) or (batch, num_landmarks, 2)
             — TODO confirm against callers.

     Returns:
         numpy array; 3-D when the input was 2-D.
     """
     arr = to_numpy(lms)
     # Anything that is not a single 2-D landmark set passes through as-is.
     if len(arr.shape) != 2:
         return arr
     return arr.reshape((1, -1, 2))
示例#28
0
    def _run_batch(self, data, eval=False, ds=None):
        time_dataloading = time.time() - self.iter_starttime
        time_proc_start = time.time()
        iter_stats = {'time_dataloading': time_dataloading}

        batch = Batch(data, eval=eval)

        self.saae.zero_grad()
        self.saae.eval()

        input_images = batch.target_images if batch.target_images is not None else batch.images

        with torch.set_grad_enabled(self.args.train_encoder):
            z_sample = self.saae.Q(input_images)

        iter_stats.update({'z_recon_mean': z_sample.mean().item()})

        #######################
        # Reconstruction phase
        #######################
        with torch.set_grad_enabled(self.args.train_encoder and not eval):
            X_recon = self.saae.P(z_sample)

        # calculate reconstruction error for debugging and reporting
        with torch.no_grad():
            iter_stats['loss_recon'] = aae_training.loss_recon(
                batch.images, X_recon)

        #######################
        # Landmark predictions
        #######################
        train_lmhead = not eval
        lm_preds_max = None
        with torch.set_grad_enabled(train_lmhead):
            self.saae.LMH.train(train_lmhead)
            X_lm_hm = self.saae.LMH(self.saae.P)
            if batch.lm_heatmaps is not None:
                loss_lms = F.mse_loss(batch.lm_heatmaps, X_lm_hm) * 100 * 3
                iter_stats.update({'loss_lms': loss_lms.item()})

            if eval or self._is_printout_iter(eval):
                # expensive, so only calculate when every N iterations
                # X_lm_hm = lmutils.decode_heatmap_blob(X_lm_hm)
                X_lm_hm = lmutils.smooth_heatmaps(X_lm_hm)
                lm_preds_max = self.saae.heatmaps_to_landmarks(X_lm_hm)

            if eval or self._is_printout_iter(eval):
                lm_gt = to_numpy(batch.landmarks)
                nmes = lmutils.calc_landmark_nme(
                    lm_gt,
                    lm_preds_max,
                    ocular_norm=self.args.ocular_norm,
                    image_size=self.args.input_size)
                # nccs = lmutils.calc_landmark_ncc(batch.images, X_recon, lm_gt)
                iter_stats.update({'nmes': nmes})

        if train_lmhead:
            # if self.args.train_encoder:
            #     loss_lms = loss_lms * 80.0
            loss_lms.backward()
            self.optimizer_lm_head.step()
            if self.args.train_encoder:
                self.optimizer_E.step()
                # self.optimizer_G.step()

        # statistics
        iter_stats.update({
            'epoch': self.epoch,
            'timestamp': time.time(),
            'iter_time': time.time() - self.iter_starttime,
            'time_processing': time.time() - time_proc_start,
            'iter': self.iter_in_epoch,
            'total_iter': self.total_iter,
            'batch_size': len(batch)
        })
        self.iter_starttime = time.time()

        self.epoch_stats.append(iter_stats)

        # print stats every N mini-batches
        if self._is_printout_iter(eval):
            self._print_iter_stats(
                self.epoch_stats[-self._print_interval(eval):])

            lmvis.visualize_batch(
                batch.images,
                batch.landmarks,
                X_recon,
                X_lm_hm,
                lm_preds_max,
                lm_heatmaps=batch.lm_heatmaps,
                target_images=batch.target_images,
                ds=ds,
                ocular_norm=self.args.ocular_norm,
                clean=False,
                overlay_heatmaps_input=False,
                overlay_heatmaps_recon=False,
                landmarks_only_outline=self.landmarks_only_outline,
                landmarks_no_outline=self.landmarks_no_outline,
                f=1.0,
                wait=self.wait)