Code example #1
File: 6_preview_outliers.py  Project: jphdotam/T1T2
    # Outlier IDs are formatted "<seq_dir>__<npy_name>"
    seq_dir, npy_name = outlier.split('__')
    date_dir = dates_for_studies[seq_dir]

    t1w, t2w, pd, t1_raw, t2_raw = load_npy_file(
        os.path.join(TEST_DICOM_DIR, date_dir, seq_dir, npy_name))

    t1_pre, t1_post, t2, t1w, t2w, pd, t1_t2 = get_normalized_channel_stack(
        t1_raw, t2_raw, t1w, t2w, pd, data_stack_format='all')

    x = prep_normalized_stack_for_inference(t1_t2,
                                            FOV,
                                            as_tensor=True,
                                            tensor_device=DEVICE)

    # Landmark detection
    t2w_landmark, _top_left_landmark = center_crop(
        pad_if_needed(t2w, min_height=FOV, min_width=FOV),
        crop_height=FOV, crop_width=FOV)
    landmark_points, landmark_probs = perform_cmr_landmark_detection(
        t2w_landmark, model=landmark_model)

    if np.any(landmark_points == -1):
        print(f"Skipping {npy_name} - unable to identify all landmarks")
        continue

    landmark_points = extend_landmarks(landmark_points, FOV)

    with torch.no_grad():
        pred = tta(model, x).cpu().numpy()[0]

    (xs_epi, ys_epi), (xs_end, ys_end) = get_paths(pred, landmark_points)
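The tta() call above is not reproduced on this page. A minimal sketch of what flip-based test-time augmentation could look like, assuming the model outputs an (N, C, H, W) heatmap and that horizontal flips are label-preserving (both assumptions, not the project's verified implementation):

import torch

def tta(model, x):
    # Hypothetical flip-based TTA: average the prediction on the input
    # with the un-flipped prediction on its horizontal mirror.
    pred = model(x)
    pred_flip = torch.flip(model(torch.flip(x, dims=[-1])), dims=[-1])
    return (pred + pred_flip) / 2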
Code example #2
        mask_lvwall = np.zeros((h, w), dtype=np.uint8)
        mask_lvcav = np.zeros((h, w), dtype=np.uint8)
        color = np.uint8(np.ones(3)).tolist()  # [1, 1, 1]: label value for fillPoly
        # Flip y: label coordinates have a bottom-left origin, images a top-left one
        points_end = [[x, h - y] for x, y in npy_label['endo']]
        points_epi = [[x, h - y] for x, y in npy_label['epi']]
        # cv2.fillPoly requires int32 point arrays
        points_end = np.rint([points_end]).astype(np.int32)
        points_epi = np.rint([points_epi]).astype(np.int32)
        if points_end.size == 0 or points_epi.size == 0:
            print(f"Skipping {label_path} as missing points for >= 1 surface")
            continue
        cv2.fillPoly(mask_lvcav, points_end, color)
        cv2.fillPoly(mask_lvwall, points_epi, color)

        # crop mask
        mask_lvcav, _top_left = center_crop(
            pad_if_needed(mask_lvcav, min_height=FOV, min_width=FOV),
            crop_height=FOV, crop_width=FOV)
        mask_lvwall, _ = center_crop(
            pad_if_needed(mask_lvwall, min_height=FOV, min_width=FOV),
            crop_height=FOV, crop_width=FOV)

        # Find RV points so we can do the segmentation
        # Path layout is .../<date_dir>/<seq_dir>/<label_filename>
        label_filename, seq_dir, date_dir = label_path.split(os.sep)[::-1][:3]
        image_path = os.path.join(TEST_DICOM_DIR, date_dir, seq_dir,
                                  label_filename.split('_HUMAN', 1)[0])
        t1w, t2w, pd, t1, t2 = np.load(image_path,
                                       allow_pickle=True).transpose((2, 0, 1))
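Both examples lean on the pad_if_needed/center_crop pair to bring every array to an FOV x FOV grid. The project's own helpers are not shown on this page; a minimal NumPy sketch consistent with how they are called above (2D or HxWxC input, crop returned together with its top-left offset) might be:

import numpy as np

def pad_if_needed(img, min_height, min_width):
    # Zero-pad symmetrically so the image is at least min_height x min_width
    h, w = img.shape[:2]
    pad_h, pad_w = max(min_height - h, 0), max(min_width - w, 0)
    pads = [(pad_h // 2, pad_h - pad_h // 2), (pad_w // 2, pad_w - pad_w // 2)]
    pads += [(0, 0)] * (img.ndim - 2)  # leave any channel axis untouched
    return np.pad(img, pads)

def center_crop(img, crop_height, crop_width):
    # Central crop_height x crop_width window, plus its top-left corner
    h, w = img.shape[:2]
    top, left = (h - crop_height) // 2, (w - crop_width) // 2
    return img[top:top + crop_height, left:left + crop_width], (top, left)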
Code example #3
File: vis.py  Project: jphdotam/T1T2
def vis_pose(dataloader, model, epoch, cfg):
    def resize_for_vis(img, vis_res, is_mask):
        return skimage.transform.resize(img,
                                        vis_res,
                                        order=0 if is_mask else 1)

    def stick_posemap_on_frame(frame, posemap):
        # posemap arrives channel-first; convert to HWC to match the frame
        posemap = posemap.transpose((1, 2, 0))
        if posemap.shape[:2] != frame.shape[:2]:
            posemap = resize_for_vis(posemap, frame.shape[:2], is_mask=True)
        # Grey frame -> RGB, then overlay the first two heatmap channels in red/green
        img = np.dstack((frame, frame, frame))
        img[:, :, 0] = img[:, :, 0] + posemap[:, :, 0]
        img[:, :, 1] = img[:, :, 1] + posemap[:, :, 1]
        img = np.clip(img, 0, 1)
        return img

    # Only run visualization every `vis_every` epochs
    if epoch % cfg['output']['vis_every']:
        return

    vis_n = cfg['output']['vis_n']
    vis_res = cfg['output']['vis_res']
    device = cfg['training']['device']
    landmark_model_path = cfg['export']['landmark_model_path']
    mask_classes = cfg['output']['mask_classes']

    batch_x, batch_y_true, batch_filepaths = next(iter(dataloader))
    with torch.no_grad():
        batch_y_pred = model(batch_x.to(device))
        if isinstance(batch_y_pred, OrderedDict):  # e.g. torchvision segmentation models
            batch_y_pred = batch_y_pred['out']

    images = []
    masks = []

    landmark_model = load_landmark_model(landmark_model_path)
    for i, (frame, _y_true, y_pred, filepath) in enumerate(
            zip(batch_x, batch_y_true, batch_y_pred, batch_filepaths)):

        pred_np = y_pred.cpu().numpy()

        frame_t1_pre = resize_for_vis(frame[0], vis_res, False)
        frame_t1_post = resize_for_vis(frame[1], vis_res, False)
        frame_t2 = resize_for_vis(frame[2], vis_res, False)

        # heatmaps
        img_t1pre_pred = stick_posemap_on_frame(frame_t1_pre, pred_np)
        img_t1post_pred = stick_posemap_on_frame(frame_t1_post, pred_np)
        img_t2_pred = stick_posemap_on_frame(frame_t2, pred_np)
        img_plain = cv2.cvtColor(
            np.concatenate((frame_t1_pre, frame_t1_post, frame_t2), axis=1),
            cv2.COLOR_GRAY2RGB)
        img_heatmaps = np.concatenate(
            (img_t1pre_pred, img_t1post_pred, img_t2_pred), axis=1)
        img_withoutmask = np.concatenate((img_plain, img_heatmaps), axis=0)

        # landmark detection
        orig_npy_path = get_original_npy_path_from_exported_npz_path(
            filepath, cfg['export']['dicom_path_trainval'])
        t1w, t2w, pd, t1, t2 = np.transpose(np.load(orig_npy_path), (2, 0, 1))
        t2w_landmark, _top_left_landmark = center_crop(
            pad_if_needed(t2w, min_height=256, min_width=256),
            crop_height=256, crop_width=256)
        landmark_points, landmark_probs = perform_cmr_landmark_detection(
            t2w_landmark, model=landmark_model)

        # rv masks
        rvi1_xy, rvi2_xy, lv_xy = landmark_points
        rvimid_xy = 0.5 * (rvi1_xy + rvi2_xy)  # midpoint of the two RV insertion points
        rv_xy = lv_xy + 2 * (rvimid_xy - lv_xy)  # reflect the LV centre through that midpoint into the RV
        mask_rvi1 = np.zeros_like(t2w_landmark)
        mask_rvi1[int(round(rvi1_xy[1])), int(round(rvi1_xy[0]))] = 1  # index as (row, col) = (y, x)
        mask_rvmid = np.zeros_like(t2w_landmark)
        mask_rvmid[int(round(rv_xy[1])), int(round(rv_xy[0]))] = 1

        # LV ridge tracing using landmarks
        if np.all(landmark_points == -1):
            print(f"Was unable to find landmarks on sample {i}")
            # Note: unpacking [[], []] as in the original would raise a ValueError
            (xs_epi, ys_epi), (xs_end, ys_end) = ([], []), ([], [])
        else:
            (xs_epi, ys_epi), (xs_end, ys_end) = get_paths(pred_np, landmark_points)

        # ridges to masks
        mask_lvcav = np.zeros_like(t2w_landmark, dtype=np.uint8)
        mask_lvwall = np.zeros_like(t2w_landmark, dtype=np.uint8)
        # cv2.fillPoly requires int32 point arrays
        points_end = np.rint([list(zip(xs_end, ys_end))]).astype(np.int32)
        points_epi = np.rint([list(zip(xs_epi, ys_epi))]).astype(np.int32)
        color = np.uint8(np.ones(3)).tolist()  # [1, 1, 1]
        cv2.fillPoly(mask_lvcav, points_end, color)
        cv2.fillPoly(mask_lvwall, points_epi, color)

        # sectors
        sectors, sectors_32 = compute_bullseye_sector_mask_for_slice(
            mask_lvcav, mask_lvwall, mask_rvmid, mask_rvi1, 6)
        sectors = resize_for_vis(sectors, vis_res, is_mask=True)
        sector_row = np.concatenate((sectors, sectors, sectors), axis=1)
        img_mask = np.concatenate((sector_row, sector_row), axis=0)

        images.append(img_withoutmask)
        masks.append(img_mask)

        if i >= vis_n - 1:
            break

    # WandB
    wandb_images = []
    for image, mask in zip(images, masks):
        wandb_img = wandb.Image(image,
                                masks={
                                    "prediction": {
                                        "mask_data": mask,
                                        "class_labels": mask_classes,
                                    }
                                })
        wandb_images.append(wandb_img)
    wandb.log({"epoch": epoch, "images": wandb_images})

    return images, masks
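For orientation, vis_pose() is evidently meant to be called once per epoch from a training loop. A hypothetical call site follows; the cfg keys mirror the ones read above, while the model, loaders, paths, and class names are placeholders, not the project's actual configuration:

cfg = {
    'output': {'vis_every': 5, 'vis_n': 4, 'vis_res': (256, 256),
               'mask_classes': {i: f"sector_{i}" for i in range(1, 7)}},
    'training': {'device': 'cuda'},
    'export': {'landmark_model_path': './landmark_model.pts',
               'dicom_path_trainval': './dicoms'},
}

for epoch in range(num_epochs):
    train_one_epoch(model, train_loader)     # placeholder training step
    vis_pose(val_loader, model, epoch, cfg)  # logs a grid to WandB every 5th epoch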
Code example #4
    if USE_RDP:
        # Simplify the traced contours with Ramer-Douglas-Peucker
        len_pre_rdp = len(points_end[0]), len(points_epi[0])
        points_end = np.expand_dims(rdp(points_end[0]), 0)
        points_epi = np.expand_dims(rdp(points_epi[0]), 0)
        print(f"{len_pre_rdp} -> {len(points_end[0]), len(points_epi[0])}")

    # Recover the source .npy path by stripping any suffix after '.npy'
    npy_path = predicted_label_path.split('.npy')[0] + '.npy'

    npy = np.load(npy_path)
    t1w, t2w, pd, t1, t2 = np.transpose(npy, (2, 0, 1))

    t1_pre, t1_post, t2, t1w, t2w, pd, t1_t2 = get_normalized_channel_stack(
        t1, t2, t1w, t2w, pd, data_stack_format='all')
    t1_t2_crop, _top_left = center_crop(
        pad_if_needed(t1_t2, min_height=FOV, min_width=FOV),
        crop_height=FOV, crop_width=FOV)

    img_out = []

    # Channels 3-5 of the 'all' stack appear to be the weighted images (t1w, t2w, pd)
    for i_channel in (3, 4, 5):
        img = default_cmap(t1_t2_crop[:, :, i_channel])
        cv2.polylines(img, points_end, True, [1, 1, 1])
        cv2.polylines(img, points_epi, True, [1, 1, 1])
        img_out.append(img)

    img_out = (np.hstack(img_out) * 255).astype(np.uint8)
    #plt.imshow(img_out)
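The rdp() call in this example is Ramer-Douglas-Peucker polyline simplification, presumably via the rdp PyPI package. A standalone illustration of its effect on a contour (the epsilon tolerance here is arbitrary):

import numpy as np
from rdp import rdp

# A noisy, nearly straight 5-point polyline...
contour = np.array([[0, 0], [1, 0.05], [2, -0.04], [3, 0.02], [4, 0]])

# ...collapses to its endpoints once deviations below epsilon are discarded.
print(rdp(contour, epsilon=0.1))  # [[0. 0.] [4. 0.]]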