Ejemplo n.º 1
0
def gen_mask_by_seg(image=image, size_seg=512, flag_use_dn=True):
    """Detect masks tile-by-tile with ``model`` and stitch them into one stack.

    Note: the default for ``image`` is bound to the module-level ``image``
    at the moment this ``def`` executes, not at call time.

    Args:
        image: input image to segment.
        size_seg: tile size handed to ``utils.img_split`` (tiles overlap by 0.4).
        flag_use_dn: when True the image is first passed through ``image_dn``.

    Returns:
        Tuple ``(stitched_masks, stitched_scores, mask_size)``. ``mask_size`` is
        the score-weighted mean of sqrt(sum of each mask over axes 0 and 1),
        or 16.0 when the stitched stack holds no masks.
    """
    # Optional preprocessing (original comment: flip sign of dark field images).
    source = image_dn(image) if flag_use_dn else image

    tiles = utils.img_split(source, size_seg=size_seg, overlap=0.4)

    tile_masks, tile_scores = {}, {}
    for tile_loc, tile in tiles.items():
        result = model.detect([tile], verbose=0)[0]
        masks, scores = result['masks'], result['scores']
        # No-detection special case: the returned masks do not match the tile's
        # height/width, so substitute an empty (H, W, 0) stack with no scores.
        if masks.shape[:2] != tile.shape[:2]:
            masks = np.zeros(shape=tile.shape[:2] + (0, ), dtype='int')
            scores = np.zeros(shape=(0, ))
        tile_masks[tile_loc] = masks
        tile_scores[tile_loc] = scores

    stitched, _, stitched_scores = utils.img_stitch(
        tile_masks, mode='mask', info_mask_dict=tile_scores)

    if stitched.shape[2] == 0:
        # Nothing detected: np.average would have no weights to normalise by.
        mask_size = 16.0
    else:
        mask_size = np.average(np.sqrt(np.sum(stitched, axis=(0, 1))),
                               weights=stitched_scores)

    return stitched, stitched_scores, mask_size
def gen_mask_by_seg(image=image, size_seg=256, flag_use_dn=False):
    """Detect masks tile-by-tile with ``model`` and stitch them into one stack.

    Note: the default for ``image`` is bound to the module-level ``image``
    at the moment this ``def`` executes, not at call time.

    Args:
        image: input image to segment.
        size_seg: tile size handed to ``utils.img_split`` (tiles overlap by 0.2).
        flag_use_dn: when True the image is first passed through ``image_dn``.

    Returns:
        Tuple ``(mask_stitched, score_stiched, mask_size)``. ``mask_size`` is the
        score-weighted mean of sqrt(sum of each mask over axes 0 and 1), or
        16.0 when no mask was detected anywhere in the image.
    """
    # Optional preprocessing (original comment: flip sign of dark field images).
    if flag_use_dn:
        image = image_dn(image)

    image_segs = utils.img_split(image, size_seg=size_seg, overlap=0.2)

    # split, compute mask per segment, stitch
    mask_segs = {}
    scores_segs = {}
    for loc_seg, image_seg in image_segs.items():
        image_detect_res = model.detect([image_seg], verbose=0)[0]
        masks_seg = image_detect_res['masks']
        scores_seg = image_detect_res['scores']
        if masks_seg.shape[:2] != image_seg.shape[:2]:  # special case if no mask
            masks_seg = np.zeros(shape=image_seg.shape[:2] + (0, ),
                                 dtype='int')
            scores_seg = np.zeros(shape=(0, ))
        mask_segs[loc_seg] = masks_seg
        scores_segs[loc_seg] = scores_seg
    mask_stitched, _, score_stiched = utils.img_stitch(
        mask_segs, mode='mask', info_mask_dict=scores_segs)

    # TODO: post-processing - filter out masks that are not darker than average.

    # With zero masks the weights array is empty (sums to 0) and np.average
    # raises ZeroDivisionError, so fall back to a fixed default size.
    if mask_stitched.shape[2] != 0:
        mask_size = np.average(np.sqrt(np.sum(mask_stitched, axis=(0, 1))),
                               weights=score_stiched)
    else:
        mask_size = 16.0
    return mask_stitched, score_stiched, mask_size
Ejemplo n.º 3
0
def _flip_dark_field(data):
    # Invert (255 - pixel) the image of every entry whose 'class' is 'dark', in place.
    for entry in data.values():
        if entry['class'] == 'dark':
            entry['image'] = 255 - entry['image']


def _augment_by_rotation(data):
    # Return a new dict with four 90-degree rotations (keys '<id>_rot0'..'_rot3') per entry.
    augmented = {}
    for id_img, entry in data.items():
        for i in range(4):
            new_entry = {
                'image': np.rot90(entry['image'], i),
                'mask': np.rot90(entry['mask'], i),
            }
            if 'class' in entry:
                new_entry['class'] = entry['class']
            augmented[id_img + '_rot{}'.format(i)] = new_entry
    return augmented


def _split_by_nuclei_size(data, amplification_ideal=12, min_size_seg=64):
    # Split each entry into segments sized relative to its estimated nucleus size.
    split = {}
    for id_img in tqdm(data.keys()):
        entry = data[id_img]
        img_cur = entry['image']
        mask_cur = entry['mask']

        # NOTE(review): unique labels include background, so the nucleus count
        # is presumably len(unique) - 1; the "+ 1" here is kept as in the
        # original but looks questionable - confirm intent.
        size_nuclei = np.sqrt(
            np.sum(mask_cur > 0) * 1.0 / (len(np.unique(mask_cur)) + 1))
        size_nuclei = max(size_nuclei, 8)
        # ideal segment side is amplification_ideal x nucleus size, floored
        # to a power of two via utils.floor_pow2 and clamped below.
        size_seg_0 = max(min_size_seg,
                         utils.floor_pow2(size_nuclei * amplification_ideal))

        for size_seg in (size_seg_0, size_seg_0 * 2):
            img_cur_seg = utils.img_split(img=img_cur, size_seg=size_seg)
            mask_cur_seg = utils.img_split(img=mask_cur, size_seg=size_seg)
            for start_loc in img_cur_seg.keys():
                seg = {
                    'image': img_cur_seg[start_loc],
                    'mask': mask_cur_seg[start_loc],
                }
                # Class is attached to every segment (previously only the
                # last segment of each size received it).
                if 'class' in entry:
                    seg['class'] = entry['class']
                split[(id_img, start_loc, size_seg)] = seg
    return split


def _balance_segments(data, max_num_seg):
    # Down-sample or replicate segments so each source image yields about max_num_seg.
    segs_by_img = {}
    for seg_id in data:
        # seg_id is (id_img, start_loc, size_seg); group by id_img, keeping
        # every segment (the first one used to be dropped).
        segs_by_img.setdefault(seg_id[0], []).append(seg_id)

    balanced = {}
    for seg_ids in segs_by_img.values():
        num_split = len(seg_ids)
        if num_split > max_num_seg:  # too many: pick distinct segments
            for seg_id in random.sample(seg_ids, k=max_num_seg):
                balanced[seg_id + (0, )] = data[seg_id]
        else:  # too few: repeat them, tagging each copy with its index
            for i in range(max_num_seg // num_split):
                for seg_id in seg_ids:
                    balanced[seg_id + (i, )] = data[seg_id]
    return balanced


def process_data(data_train,
                 yn_flip_dark_field=True,
                 yn_augment=True,
                 yn_split=True,
                 yn_split_balance=True,
                 max_num_seg=12):
    """Prepare a training dict of {id: {'image', 'mask', ...}} entries.

    Steps, each optional:
      1. invert images of entries whose 'class' is 'dark';
      2. augment with the four 90-degree rotations;
      3. split into segments sized from the estimated nucleus size;
      4. (only with splitting) balance so each image contributes about
         ``max_num_seg`` segments.

    The input is deep-copied first, so ``data_train`` is not modified.
    ``add_class_to_data_dict`` is assumed to add a 'class' key to every entry
    (step 1 reads it unconditionally) - confirm against its definition.

    Returns:
        A new dict. Keys are the original ids, '<id>_rotK' after augmentation,
        (id, start_loc, size_seg) after splitting, and that tuple plus a copy
        index after balancing.
    """
    data = add_class_to_data_dict(copy.deepcopy(data_train))

    if yn_flip_dark_field:
        _flip_dark_field(data)

    if yn_augment:
        data = _augment_by_rotation(data)

    if yn_split:
        data = _split_by_nuclei_size(data)
        if yn_split_balance:
            data = _balance_segments(data, max_num_seg)

    return data
def main():
    """Run tiled DeConvNet prediction on every ``*ORIG.tif`` under ./photo/.

    Each image is predicted four times:
      - on the raw image;
      - on a copy zero-padded by 160 px left/right;
      - on a copy zero-padded by 160 px top/bottom;
      - on a copy zero-padded by 160 px on all sides.

    The padding shifts the tile grid, presumably so that seams of one pass
    fall inside tiles of another. The four boolean masks are OR-ed (numpy
    ``+`` on bool arrays), and ``(1 - mask) * 255`` is saved as
    ``./photo/<name>_PRED.tif``.

    Command-line arguments:
        --bs: batch size handed to ``predict`` (default 10).
    """
    parser = argparse.ArgumentParser(description='Prediction')
    parser.add_argument('--bs', default=10, type=int, help='Batch size = image width/320')
    args = parser.parse_args()

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print('Device is {}!'.format(device))

    # Hyperparameters
    batch_size = args.bs
    image_size = 320
    print('Batch size: {}'.format(batch_size))

    # Model
    print('==> Building model..')
    net = DeConvNet()
    net = net.to(device)

    # Enabling cudnn, which may lead to about 2 GB extra memory
    if device == 'cuda':
        cudnn.benchmark = True
        print('cudnn benchmark enabled!')

    # Load checkpoint
    print('==> Resuming from checkpoint..')
    assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
    # map_location remaps tensors saved on a GPU onto the selected device, so
    # a CUDA-trained checkpoint can also be loaded on a CPU-only machine.
    checkpoint = torch.load('./checkpoint/training_saved.t7', map_location=device)
    net.load_state_dict(checkpoint['net'])

    # Transformation applied to each tile before it enters the network
    img_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.6672,  0.5865,  0.5985), (1.0, 1.0, 1.0)),
    ])

    # Pad = (left, top, right, bottom): shift the tile grid by half a tile
    # (160 = image_size / 2) horizontally, vertically, or both.
    X_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.transforms.Pad((160, 0, 160, 0), fill=(0, 0, 0), padding_mode='constant'),
    ])

    Y_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.transforms.Pad((0, 160, 0, 160), fill=(0, 0, 0), padding_mode='constant'),
    ])

    XY_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.transforms.Pad((160, 160, 160, 160), fill=(0, 0, 0), padding_mode='constant'),
    ])

    # Prediction
    print('==> Prediction begins..')
    net.eval()
    with torch.no_grad():
        for photo_folder in os.listdir('./photo/'):
            img_files = [os.path.join(photo_folder, file) for file in os.listdir(
                os.path.join('./photo/', photo_folder)) if file.endswith("ORIG.tif")]
            for img_file in img_files:
                start_time = time.time()

                print('{}:'.format(img_file))
                img = plt.imread(os.path.join('./photo/', img_file))

                simgs = img_split(img, cut_size=image_size)
                simgs_X = img_split(np.asarray(X_transform(img)), cut_size=image_size)
                simgs_Y = img_split(np.asarray(Y_transform(img)), cut_size=image_size)
                simgs_XY = img_split(np.asarray(XY_transform(img)), cut_size=image_size)

                # NOTE(review): padded passes along X use batch_size + 1,
                # presumably because horizontal padding adds one tile per row
                # (matching the --bs help text); confirm against predict().
                mask = predict(device, net, img_transform, simgs, overlap_mode=0, batch_size=batch_size, image_size=image_size).astype(bool)
                mask_X = predict(device, net, img_transform, simgs_X, overlap_mode=1, batch_size=batch_size+1, image_size=image_size).astype(bool)
                mask_Y = predict(device, net, img_transform, simgs_Y, overlap_mode=2, batch_size=batch_size, image_size=image_size).astype(bool)
                mask_XY = predict(device, net, img_transform, simgs_XY, overlap_mode=3, batch_size=batch_size+1, image_size=image_size).astype(bool)

                # bool + bool in numpy is a logical OR; uint8 gives 0/1 values.
                mask_combined = (mask+mask_X+mask_Y+mask_XY).astype(np.uint8)
                # [:-9] strips the trailing "_ORIG.tif" (9 chars) from the name.
                imsave('./photo/{}_PRED.tif'.format(img_file[:-9]), (1-mask_combined)*255)
                print('The mask of {} was predicted and saved!'.format(img_file[:-9]))

                print("--- %s seconds ---" % (time.time() - start_time))
            print('{} complete!\n'.format(photo_folder))

    # NOTE(review): everything below references `data_train_selection`,
    # `yn_dark_nuclei` and `utils`, none of which are defined in this
    # function. Unless they exist as module globals this raises NameError
    # after prediction finishes. It looks like code from a different script;
    # confirm before relying on it.
    data_train_seg = {}
    # split every image so that the diameter of every nuclei takes 1/16 ~ 1/8 of the image length
    # list_amplification = [8, 16]
    list_amplification = [16, 32]
    for id_img in tqdm(data_train_selection.keys()):
        img_cur = data_train_selection[id_img]['image']
        mask_cur = data_train_selection[id_img]['mask']
        size_nuclei = int(
            np.mean(
                np.sqrt(
                    np.sum(utils.mask_2Dto3D(
                        data_train_selection[id_img]['mask']),
                           axis=(0, 1)))))
        for amplification in list_amplification:
            img_cur_seg = utils.img_split(img=img_cur,
                                          size_seg=size_nuclei * amplification)
            mask_cur_seg = utils.img_split(img=mask_cur,
                                           size_seg=size_nuclei *
                                           amplification)
            for start_loc in img_cur_seg.keys():
                data_train_seg[(id_img, start_loc, amplification)] = {}
                data_train_seg[(
                    id_img, start_loc,
                    amplification)]['image'] = img_cur_seg[start_loc]
                data_train_seg[(
                    id_img, start_loc,
                    amplification)]['mask'] = mask_cur_seg[start_loc]

    if yn_dark_nuclei:
        with open('./data/data_train_dn_seg.pickle', 'wb') as f:
            pickle.dump(data_train_seg, f)
Ejemplo n.º 6
0
print(score)


""" ========== 4. image split and stitch ========== """

# ----- 4.1 image split
# Pick a random example image and its ground-truth mask from data_tr.
id_eg = np.random.choice(list(data_tr.keys()))
image = data_tr[id_eg]['image']
mask_true = data_tr[id_eg]['mask']

# full image
plt.figure()
utils.plot_img_and_mask_from_dict(data_tr, id_eg)

# split image segments and show them on a grid indexed by segment start (row, col)
img_seg = utils.img_split(image, size_seg=128, overlap=0.2)
img_seg_start = np.array(list(img_seg.keys()))
rs_start = np.unique(img_seg_start[:, 0])
cs_start = np.unique(img_seg_start[:, 1])
# squeeze=False keeps h_axes 2-D even when there is only one row or one
# column of segments, so h_axes[i_r, i_c] indexing always works.
h_fig, h_axes = plt.subplots(len(rs_start), len(cs_start), squeeze=False)
for i_r, r in enumerate(rs_start):
    for i_c, c in enumerate(cs_start):
        ax = h_axes[i_r, i_c]
        ax.imshow(img_seg[(r, c)])
        ax.axis('off')
        ax.set_title((r, c), fontsize='x-small')

# NOTE(review): the mask is split with overlap=0.5 while the image above uses
# overlap=0.2, so mask and image segments do not share start locations -
# confirm this is intentional.
mask_seg = utils.img_split(mask_true, size_seg=128, overlap=0.5)
mask_seg = {key: utils.mask_2Dto3D(mask_seg[key]) for key in mask_seg}