Example #1
    def __getitem__(self, index):
        data_path = self.meta[index]
        images = np.load(data_path)

        # CT clip
        num_frames = len(images)
        left, right = int(num_frames * self.clip_range[0]), int(
            num_frames * self.clip_range[1])
        images = images[left:right]

        num_frames = len(images)
        shape = images.shape
        #h, w = shape[1:]

        # Debug hook (disabled): dump the clip to a video and drop into pdb.
        # Note: `masks` is not defined in this variant of __getitem__.
        if False:
            from zqlib import imgs2vid
            imgs2vid(np.concatenate([images, masks * 255], axis=2), "test.avi")
            import pdb
            pdb.set_trace()

        # To Tensor and Resize
        images = np.asarray(images, dtype=np.float32)
        images = images / 255.

        images = np.expand_dims(images,
                                axis=1)  # Bx1xHxW, add channel dimension

        #info = {"name": data_path, "num_frames": num_frames, "shape": shape, "pad": ((lh,uh),(lw,uw))}
        info = {"name": data_path, "num_frames": num_frames, "shape": shape}

        th_img = torch.from_numpy(images.copy()).float()
        th_label = torch.zeros_like(th_img)

        return th_img, th_label, info
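
A minimal consumption sketch for the test-time __getitem__ above. The dataset class name and constructor arguments are hypothetical; only the method itself appears in the snippet.

# Hypothetical usage; CTTestDataset and its constructor arguments are
# assumptions -- only __getitem__ is shown above.
from torch.utils.data import DataLoader

dataset = CTTestDataset(meta=["/data/scan_001.npy"], clip_range=(0.2, 0.7))
loader = DataLoader(dataset, batch_size=1, shuffle=False)
for th_img, th_label, info in loader:
    print(info["name"], th_img.shape)  # (1, num_frames, 1, H, W) after batching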
Example #2
def combine_ct_images_and_masks(ct_images, ct_masks):
    num_frames = len(ct_images)
    ct_images = np.asarray(ct_images, dtype=np.float32)
    ct_images = ct_images / 255.

    # Debug hook (disabled): write images and masks side by side to a video.
    if False:
        from zqlib import imgs2vid
        imgs2vid(np.concatenate([ct_images*255, ct_masks*255], axis=2), "combine_image_mask.avi")

    ct_imasks = np.concatenate(
                    [ct_images[None, ...], ct_masks[None, ...]], 
                    axis=0)
    th_imasks = torch.unsqueeze(
                    torch.from_numpy(ct_imasks.copy()), 0).float()
    return th_imasks
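
For instance, stacking a (T, H, W) volume with its mask yields a (1, 2, T, H, W) tensor ready for a two-channel 3D CNN. A quick shape check with dummy data:

import numpy as np

ct_images = np.random.randint(0, 256, size=(8, 64, 64), dtype=np.uint8)
ct_masks = np.random.randint(0, 2, size=(8, 64, 64)).astype(np.float32)
th_imasks = combine_ct_images_and_masks(ct_images, ct_masks)
print(th_imasks.shape)  # torch.Size([1, 2, 8, 64, 64])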
Example #3
def load_processed_ct_images(npy_filepath, patient, clip_range):
    files = os.listdir(os.path.join(npy_filepath, patient))
    files.sort()
    images = []
    for file in files:
        # cv2.COLOR_BGR2GRAY is a cvtColor code, not an imread flag;
        # cv2.IMREAD_GRAYSCALE loads each slice as a single channel.
        images.append(
            cv2.imread(os.path.join(npy_filepath, patient, file),
                       cv2.IMREAD_GRAYSCALE))
    images = np.array(images)
    num_frames = len(images)
    # Stash globals consumed later by cam_mask (Example #7).
    global left
    global origin_images
    origin_images = images.copy()

    left, right = int(num_frames * clip_range[0]), int(num_frames * clip_range[1])
    images = images[left:right]

    num_frames = len(images)
    shape = images.shape

    # Debug hook (disabled): dump the clipped volume to a video.
    if False:
        from zqlib import imgs2vid
        imgs2vid(images, "test_image.avi")

    images = np.asarray(images, dtype=np.float32)
    images = images / 255.
    # add channel dimension
    images = np.expand_dims(images, axis=1)
    img_info = {
        "name": npy_filepath,
        "num_frames": num_frames,
        "clip_range": clip_range,
        "shape": shape
    }
    th_images = torch.from_numpy(images.copy()).float()
    return th_images, img_info
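
A minimal usage sketch, assuming a layout of one grayscale slice image per file under <root>/<patient>/; the paths below are placeholders:

th_images, info = load_processed_ct_images(
    "/data/ct_slices", "patient_001", clip_range=(0.15, 0.85))
print(th_images.shape, info["num_frames"])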
Example #4
def load_processed_ct_images(npy_filepath, clip_range):
    images = np.load(npy_filepath)
    num_frames = len(images)

    left, right = int(num_frames*clip_range[0]), int(num_frames*clip_range[1])
    images = images[left:right]

    num_frames = len(images)
    shape = images.shape

    # Debug hook (disabled): dump the clipped volume to a video.
    if False:
        from zqlib import imgs2vid
        imgs2vid(images, "test_image.avi")
    
    images = np.asarray(images, dtype=np.float32)
    images = images / 255.
    # add channel dimension
    images = np.expand_dims(images, axis=1)
    img_info = {"name": npy_filepath, 
                "num_frames": num_frames,
                "clip_range": clip_range,
                "shape": shape}
    th_images = torch.from_numpy(images.copy()).float()
    return th_images, img_info
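
A quick check of the clipping behaviour: with clip_range=(0.2, 0.7), a 100-slice volume keeps slices 20..69. The dummy volume and temp path below are only for illustration:

import os
import tempfile

import numpy as np

vol = np.random.randint(0, 256, size=(100, 64, 64), dtype=np.uint8)
path = os.path.join(tempfile.mkdtemp(), "dummy.npy")
np.save(path, vol)
th_images, info = load_processed_ct_images(path, clip_range=(0.2, 0.7))
print(th_images.shape)     # torch.Size([50, 1, 64, 64])
print(info["num_frames"])  # 50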
Example #5
    mask_fn = "/your/mask_npy/path/" + patient_id + ".npy"
    mask = np.load(mask_fn)

    T, H, W = mask.shape
    final_mask = np.zeros(mask.shape)
    mask = mask[int(T * 0.1):int(T * 0.8)]

    win_size = 6
    bias_term = 0.05  # unused in this snippet
    ref_std = 0.5  # unused in this snippet
    res_list = []
    for m in mask:
        # slice-wise pre-segmentation; 3D connected components are labeled below
        res = sub_seg(m, win_size)
        res_list.append(res)
    res_list = np.asarray(res_list)
    labels = measure.label(res_list)
    # label 3D connected components; top5_label_volume returns the five
    # largest non-background labels, smallest to largest
    l_max = top5_label_volume(labels, bg=res_list[0, 0, 0])
    select_ = np.zeros(labels.shape)
    # keep only the largest connected component
    for l_i in [l_max[-1]]:
        select_ = select_ + (labels == l_i)
    final_mask[int(T * 0.1):int(T * 0.8)] = select_
    final_mask = final_mask.astype(np.uint8)
    np.save("binary-seg/" + patient_id + ".npy", final_mask)
    # transform images to video
    imgs2vid(np.concatenate([mask * 255, select_ * 255], axis=2),
             "./results/" + patient_id.split('.')[0] + ".avi")
Example #6
os.makedirs("visual", exist_ok=True)

with torch.no_grad():
    for i, (all_F, all_M, all_info) in enumerate(ValidLoader):
        logger.info(all_info)
        all_E = []
        images = all_F.cuda()
        #(lh, uh), (lw, uw) = all_info[0]["pad"]
        num = len(images)

        for ii in range(num):
            image = images[ii:ii+1]
            pred = model(image)
            pred = torch.argmax(F.softmax(pred, dim=1), dim=1)
            all_E.append(pred)

        all_E = torch.cat(all_E, dim=0).cpu().numpy().astype('uint8')
        all_OF = np.uint8(all_F[:, 0, :, :].cpu().numpy().astype('float32') * 255)

        unique_id = all_info[0]["name"].split('/')[-1].replace('.npy', '')
        np.save("{}/{}.npy".format(RESULE_HOME, unique_id), all_OF)
        np.save("{}/{}-dlmask.npy".format(RESULE_HOME, unique_id), all_E)

        # Debug hook (disabled): render predictions next to the input slices.
        if False:
            from zqlib import imgs2vid
            imgs2vid(np.concatenate([all_OF, all_E*255], axis=2), "visual/{}.avi".format(unique_id))
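
The loop assumes several names defined elsewhere in the script. A hypothetical setup sketch (all values are assumptions):

import logging
import os

import torch
import torch.nn.functional as F

RESULE_HOME = "results"  # output directory (name kept as-is from the snippet)
os.makedirs(RESULE_HOME, exist_ok=True)
logger = logging.getLogger(__name__)
# ValidLoader would wrap a dataset whose __getitem__ matches Example #1, and
# `model` is a 2D segmentation network applied slice by slice on the GPU.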


Example #7
def cam_mask(ct_images, ct_masks, model, filepath, uniq_id):
    # Relies on module-level globals populated by the loading code above.
    global origin_images
    global crop_box
    global zz, yy, xx
    global left
    global supplement_info, ct_zoom_images
    from zqlib import imgs2vid
    ct_immasks = combine_ct_images_and_masks(ct_images, ct_masks)
    preds, features = model(ct_immasks.cuda())
    # Calculate the CAM: chain the pooled head weights and the classifier
    # weights into a single per-channel weight vector.
    pool = torch.nn.AdaptiveAvgPool3d(output_size=(1, 1, 1))
    weight1 = pool(model.module.head[1].weight).squeeze()
    weight2 = pool(model.module.head[4].weight).squeeze()
    weight3 = pool(model.module.head[8].weight).squeeze()
    weight4 = model.module.classifier[0].weight
    weight5 = model.module.classifier[1].weight
    weight = torch.mm(weight1.t(), weight2.t())
    weight = torch.mm(weight, weight3.t())
    weight = torch.mm(weight, weight4.t())
    weight = torch.mm(weight, weight5.t())
    final_weight = weight[:, 1].unsqueeze(0)
    features = features.squeeze()
    channel, t, h, w = features.shape
    features = features.reshape(channel, t * h * w)
    cam = torch.mm(final_weight, features).reshape(t, h, w)
    cam = cam.detach().cpu().numpy()
    ori_h, ori_w = crop_box[3] - crop_box[1], crop_box[2] - crop_box[0]

    from scipy.ndimage import zoom
    num_slices, H, W = ct_images.shape  # avoid shadowing the builtin `slice`
    cam = zoom(cam, (num_slices / t, H / h, W / w), order=1)
    cam = (cam > 0) * cam  # keep positive activations only
    cam = cam - np.min(cam)
    cam = cam / np.max(cam)  # normalize to [0, 1]

    # Pick the intensity threshold that keeps roughly the top 5% of CAM voxels.
    num, thresholds = np.histogram(cam, bins=50)
    threshold = thresholds[0]  # fallback in case the 5% mass is never reached
    temp = 0
    for x in range(len(thresholds)):
        temp = temp + num[len(thresholds) - 2 - x]
        if temp > cam.shape[0] * cam.shape[1] * cam.shape[2] * 0.05:
            threshold = thresholds[len(thresholds) - 1 - x]
            break
    cam = np.round(cam * 255).astype(np.uint8)
    first_cam = cam.copy()
    cam = np.uint8(cam > threshold * 255)
    # suppress CAM responses outside bright (tissue) regions
    cam = np.where(ct_zoom_images > 75, cam, 0)

    global mask_left, mask_right
    from skimage import measure
    # 3D-label the thresholded CAM inside the kept slice range
    labels_out = measure.label(cam[mask_left:-mask_right])
    vals, counts = np.unique(labels_out, return_counts=True)
    # take the nine largest labels, skipping the very largest (assumed background)
    ids_ = counts.argsort()[-10:-1]
    extracted_images = []
    connected_comp_areas = []
    for segid in ids_:
        extracted_image = labels_out * (
            labels_out == vals[segid]) * supplement_info[mask_left:-mask_right]
        connected_comp_areas.append((extracted_image).sum())
        extracted_images.append(extracted_image)
    temp = np.zeros(cam.shape)
    connected_comp_ids = np.argsort(connected_comp_areas)[-1]

    temp[mask_left:-mask_right] = temp[
        mask_left:-mask_right] + extracted_images[connected_comp_ids]
    cam = temp
    del extracted_images, labels_out

    cam = zoom(cam, (1, ori_h / H, ori_w / W), order=0)
    ori_cam = np.zeros((origin_images.shape[0], 512, 512), dtype=np.uint8)
    ori_cam[left + zz.min():left + zz.max(), crop_box[1]:crop_box[3],
            crop_box[0]:crop_box[2]] = cam
    cam = ori_cam.astype(np.uint8)

    # Saturate channel 2 of an RGB copy of the slices wherever the CAM fires.
    origin_images = origin_images[:, :, :, np.newaxis]
    origin_images = np.concatenate(
        (origin_images, origin_images, origin_images), axis=3)
    origin_images[:, :, :, 2] = np.where(cam > 0, 255 * np.ones(cam.shape),
                                         origin_images[:, :, :, 2])
    origin_images = np.clip(origin_images, a_min=0, a_max=255)
    origin_images = np.uint8(origin_images)
    imgs2vid(origin_images, "visual-cam/" + uniq_id + ".avi")
    np.save("cam-mask/" + uniq_id + ".npy", cam)

for unique_id in normal_ids:
    data_type = "normal"
    print(unique_id)
    imgs = np.load(os.path.join(f"{data_type}", f"{unique_id}.npy"))
    masks = np.load(os.path.join(f"{data_type}", f"{unique_id}_lung_mask.npy"))

    zz, yy, xx = np.where(masks)
    cropbox = np.array([[np.min(zz), np.max(zz)],
                        [np.min(yy), np.max(yy)],
                        [np.min(xx), np.max(xx)]])

    crop_imgs = imgs[cropbox[0, 0]:cropbox[0, 1], cropbox[1, 0]:cropbox[1, 1],
                     cropbox[2, 0]:cropbox[2, 1]]

    crop_masks = masks[cropbox[0, 0]:cropbox[0, 1],
                       cropbox[1, 0]:cropbox[1, 1],
                       cropbox[2, 0]:cropbox[2, 1]]

    immasks = np.concatenate([crop_imgs, crop_masks * 255], axis=2)
    imgs2vid(immasks, f"visual/{unique_id}.avi")
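
A fuller __getitem__ from the same codebase follows: it additionally loads lung masks, optionally subsamples slices, and applies augmentation on the train split.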
    def __getitem__(self, index):
        data_path = self.meta[index]

        if self.split == 'test':
            mask_path = data_path
        else:
            mask_path = data_path.replace('.npy', '_lung_mask.npy')

        images = np.load(data_path)
        masks = np.uint8(np.load(mask_path) > 0)

        # CT clip
        num_frames = len(images)
        left, right = int(num_frames * self.clip_range[0]), int(
            num_frames * self.clip_range[1])
        images = images[left:right]
        masks = masks[left:right]

        # Random sample
        if self.sample_number > -1:
            num_frames = len(images)
            rand_index = np.random.choice([*range(0, num_frames)],
                                          self.sample_number,
                                          replace=False)
            images = images[rand_index]
            masks = masks[rand_index]

        num_frames = len(images)
        shape = images.shape
        #h, w = shape[1:]

        # Make it divisible by 16
        #new_h = h + 16 - h % 16
        #new_w = w + 16 - w % 16
        #lh, uh = (new_h-h) / 2, (new_h-h) / 2 + (new_h-h) % 2
        #lw, uw = (new_w-w) / 2, (new_w-w) / 2 + (new_w-w) % 2
        #lh, uh, lw, uw = int(lh), int(uh), int(lw), int(uw)
        #images = np.pad(images, ((0,0),(lh,uh),(lw,uw)), mode='constant')
        #masks  = np.pad(masks, ((0,0),(lh,uh),(lw,uw)), mode='constant')

        # Debug hook (disabled): dump images and masks side by side, then pdb.
        if False:
            from zqlib import imgs2vid
            imgs2vid(np.concatenate([images, masks * 255], axis=2), "test.avi")
            import pdb
            pdb.set_trace()

        # Data augmentation
        if self.split == "train":
            images, masks = Rand_Transforms(
                images,
                masks,
                #ANGLE_R=10, TRANS_R=0.1, SCALE_R=0.2, SHEAR_R=10,
                ANGLE_R=0,
                TRANS_R=0,
                SCALE_R=0,
                SHEAR_R=0,
                BRIGHT_R=0.5,
                CONTRAST_R=0.3)

        # To Tensor and Resize
        images = np.asarray(images, dtype=np.float32)
        images = images / 255.

        images = np.expand_dims(images,
                                axis=1)  # Bx1xHxW, add channel dimension

        #info = {"name": data_path, "num_frames": num_frames, "shape": shape, "pad": ((lh,uh),(lw,uw))}
        info = {"name": data_path, "num_frames": num_frames, "shape": shape}

        th_img = torch.from_numpy(images.copy()).float()
        th_label = torch.from_numpy(masks.copy()).long()

        if self.split == 'test':
            th_label = torch.zeros_like(th_label)

        return th_img, th_label, info
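
The Rand_Transforms helper used in the train branch is not shown. A hypothetical sketch of its contract: with the geometric ranges set to 0, only brightness/contrast jitter remains active.

import numpy as np

# Hypothetical sketch of Rand_Transforms; the real implementation is not shown.
def Rand_Transforms(images, masks, ANGLE_R=0, TRANS_R=0, SCALE_R=0,
                    SHEAR_R=0, BRIGHT_R=0.0, CONTRAST_R=0.0):
    imgs = images.astype(np.float32)
    # additive brightness jitter, up to +/- BRIGHT_R of the 8-bit range
    imgs += np.random.uniform(-BRIGHT_R, BRIGHT_R) * 255.0
    # multiplicative contrast jitter in [1 - CONTRAST_R, 1 + CONTRAST_R]
    imgs *= np.random.uniform(1.0 - CONTRAST_R, 1.0 + CONTRAST_R)
    return np.clip(imgs, 0, 255), masks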