Пример #1
0
    def predict(self):
        """Score the network on the metric dataset using overlapping patches.

        Each image is split into patches by ``get_test_patches``; the patches
        are run through ``self.net`` in mini-batches and stitched back together
        by ``recompone_overlap``.  Channel 1 of the stitched output and the
        labels, both restricted to pixels where the mask equals 1, are
        collected over all images and passed to ``calc_metrics``.  The elapsed
        wall-clock time is printed at the end.
        """
        tic = time.time()
        self.net.eval()
        # The per-image patch pipeline below only supports batch size = 1.
        loader = DataLoader(self.metric_dataset, batch_size=1)
        os.makedirs(ARGS['prediction_save_folder'], exist_ok=True)

        score_chunks = []
        truth_chunks = []
        for sample in loader:
            images, labels, mask = (sample[key]
                                    for key in ('image', 'label', 'mask'))
            images = images.float()
            print('image shape:', images.size())

            patch_set, big_h, big_w = get_test_patches(
                images, ARGS['crop_size'], ARGS['stride_size'])
            patch_loader = DataLoader(
                patch_set,
                batch_size=ARGS['batch_size'],
                shuffle=False,
                drop_last=False,
            )
            print('Number of batches for testing:', len(patch_loader))

            outputs = []
            for batch in patch_loader:
                if ARGS['gpu']:
                    batch = batch.cuda()
                with torch.no_grad():
                    # The net yields two outputs; only the second is kept.
                    _, second_output = self.net(batch)
                outputs.append(second_output.cpu())

            # Merge the patch outputs back into a full-size prediction.
            stitched = recompone_overlap(
                torch.cat(outputs, dim=0),
                ARGS['crop_size'],
                ARGS['stride_size'],
                big_h,
                big_w,
            )
            # Keep channel 1 and crop back to the input's spatial size.
            height, width = images.size(2), images.size(3)
            stitched = stitched[:, 1, :height, :width]

            keep = mask == 1
            score_chunks.append(stitched[keep].reshape(-1))
            truth_chunks.append(labels[keep].reshape(-1))

        scores = torch.cat(score_chunks).numpy()
        truths = torch.cat(truth_chunks).numpy()
        calc_metrics(scores, truths)
        elapsed = time.time() - tic

        print('Calculating metric time consumed: {:.2f}s'.format(elapsed))
Пример #2
0
    def predict(self):
        """Predict every test image patch-wise and save the masked result.

        Each image is split into patches by ``get_test_patches``; the patches
        are run through ``self.net`` in mini-batches and stitched back together
        by ``recompone_overlap``.  Channel 1 of the stitched output is cropped
        to the input size, multiplied by the mask and written to
        ``ARGS['prediction_save_folder']`` under the sample's file name.  The
        elapsed wall-clock time is printed at the end.
        """
        tic = time.time()
        self.net.eval()
        # The per-image patch pipeline below only supports batch size = 1.
        loader = DataLoader(self.test_dataset, batch_size=1)
        os.makedirs(ARGS['prediction_save_folder'], exist_ok=True)

        for sample in loader:
            images, mask, filename = (sample[key]
                                      for key in ('image', 'mask', 'filename'))
            images = images.float()
            mask = mask.long()
            print('image shape:', images.size())

            patch_set, big_h, big_w = get_test_patches(
                images, ARGS['crop_size'], ARGS['stride_size'])
            patch_loader = DataLoader(
                patch_set,
                batch_size=ARGS['batch_size'],
                shuffle=False,
                drop_last=False,
            )
            print('Number of batches for testing:', len(patch_loader))

            outputs = []
            for batch in patch_loader:
                if ARGS['gpu']:
                    batch = batch.cuda()
                with torch.no_grad():
                    # The net yields two outputs; only the second is kept.
                    _, second_output = self.net(batch)
                outputs.append(second_output.cpu())

            # Merge the patch outputs back into a full-size prediction.
            stitched = recompone_overlap(
                torch.cat(outputs, dim=0),
                ARGS['crop_size'],
                ARGS['stride_size'],
                big_h,
                big_w,
            )
            # Keep channel 1, crop to the input's spatial size, apply the mask.
            height, width = images.size(2), images.size(3)
            masked = stitched[:, 1, :height, :width] * mask

            name = filename[0]
            picture = Image.fromarray(masked[0].numpy())
            picture.save(os.path.join(ARGS['prediction_save_folder'], name))
            print(f'Finish prediction for {name}')

        elapsed = time.time() - tic

        print('Predicting time consumed: {:.2f}s'.format(elapsed))
def predict(ACTIVATION='ReLU',
            dropout=0.2,
            minimum_kernel=32,
            epochs=50,
            crop_size=64,
            stride_size=3,
            DATASET='DRIVE'):
    """Predict the test split with a trained Res-unet and evaluate the result.

    For every test image the function extracts overlapping patches, runs the
    model on them, stitches the patch predictions back together, stores the
    raw probability map as ``NN.pickle`` and a min-max scaled ``NN.png`` under
    ``./output/<model_name>/crop_size_<crop_size>/out/``, and collects the
    in-mask ground-truth / prediction pairs that are finally passed to
    ``evaluate``.

    Args:
        ACTIVATION: Name of the activation, looked up in this module's
            globals.
        dropout: Forwarded to ``get_res_unet`` as ``do``.
        minimum_kernel: Forwarded to ``get_res_unet``.
        epochs: Used in the model name (and thus the weight path) and
            forwarded to ``evaluate``.
        crop_size: Patch size; also part of the model name and output path.
        stride_size: Stride between neighbouring patches.
        DATASET: Dataset name.  Only ``'DRIVE'`` has a known image size here.

    Raises:
        ValueError: If ``DATASET`` has no known image size.
        KeyError: If ``ACTIVATION`` is not a name in the module globals.
    """
    print('-' * 40)
    print('Loading and preprocessing test data...')
    print('-' * 40)

    network_name = "Res-unet"
    model_name = f"{network_name}_cropsize_{crop_size}_epochs_{epochs}"

    prepare_dataset.prepareDataset(DATASET)
    imgs = prepare_dataset.getTestData(0, DATASET)
    segs = prepare_dataset.getTestData(1, DATASET)
    masks = prepare_dataset.getTestData(2, DATASET)

    # Fail early with a clear message instead of crashing on
    # ``None[::-1]`` after the first image has already been predicted.
    if DATASET == 'DRIVE':
        IMAGE_SIZE = (565, 584)
    else:
        raise ValueError(
            f"No image size known for dataset {DATASET!r}; "
            "only 'DRIVE' is supported.")
    out_shape = IMAGE_SIZE[::-1]  # (584, 565)

    # Let directory-creation errors surface: every later write needs this
    # folder, so silently ignoring a failure only delays the crash.
    out_dir = f"./output/{model_name}/crop_size_{crop_size}/out/"
    os.makedirs(out_dir, exist_ok=True)
    gt_all = []
    pred_all = []

    print('-' * 30)
    print('Loading saved weights...')
    print('-' * 30)

    activation = globals()[ACTIVATION]
    model = get_res_unet(minimum_kernel=minimum_kernel,
                         do=dropout,
                         size=crop_size,
                         activation=activation)
    print("Model : %s" % model_name)
    load_path = f"./trained_model/{model_name}/{model_name}.hdf5"
    model.load_weights(load_path, by_name=False)

    print('-' * 30)
    print('Predicting masks on test data...')
    print('-' * 30)
    print('\n')

    for i, img in enumerate(tqdm(imgs)):
        # Shape notes carried over from the original code:
        # img (576,576,3), seg (576,576,1), mask (584,565,1).
        seg = segs[i]
        mask = masks[i]

        patches_pred, new_height, new_width, _ = crop_prediction.get_test_patches(
            img, crop_size, stride_size)
        pred = model.predict(patches_pred)  # run the model on the patches

        pred_patches = crop_prediction.pred_to_patches(pred, crop_size,
                                                       stride_size)
        pred_imgs = crop_prediction.recompone_overlap(pred_patches, crop_size,
                                                      stride_size, new_height,
                                                      new_width)
        # NOTE(review): DESIRED_DATA_SHAPE[0] is used for both axes, which is
        # only right for square data -- confirm against prepare_dataset.
        side = prepare_dataset.DESIRED_DATA_SHAPE[0]
        pred_ = pred_imgs[0, 0:side, 0:side, 0]  # (576,576)

        with open(f"{out_dir}{i + 1:02}.pickle", 'wb') as handle:
            pickle.dump(pred_, handle, protocol=pickle.HIGHEST_PROTOCOL)

        pred_ = resize(pred_, out_shape)  # (584,565)
        mask_ = resize(mask, out_shape)  # (584,565)
        seg_ = resize(seg, out_shape)  # (584,565)
        gt_ = (seg_ > 0.5).astype(int)

        # Keep only the pixels inside the mask.  Boolean indexing visits
        # pixels in the same row-major order as a nested row/column loop and
        # yields the same per-pixel elements, without the Python-level loop.
        inside = (mask_ > 0.5).reshape(pred_.shape)
        gt_all.extend(gt_[inside])
        pred_all.extend(pred_[inside])

        # Min-max scale to [0, 255]; a constant map would divide by zero.
        low = np.min(pred_)
        span = np.max(pred_) - low
        if span > 0:
            pred_ = 255. * (pred_ - low) / span
        else:
            pred_ = np.zeros_like(pred_)
        cv2.imwrite(f"{out_dir}{i + 1:02}.png", pred_)

    print('-' * 30)
    print('Prediction finished')
    print('-' * 30)
    print('\n')

    print('-' * 30)
    print('Evaluate the results')
    print('-' * 30)

    evaluate(gt_all, pred_all, epochs, crop_size, DATASET, network_name)

    print('-' * 30)
    print('Evaluate finished')
    print('-' * 30)
Пример #4
0
def predict(ACTIVATION='ReLU',
            dropout=0.1,
            batch_size=32,
            repeat=4,
            minimum_kernel=32,
            epochs=200,
            iteration=3,
            crop_size=128,
            stride_size=3,
            input_path='',
            output_path='',
            DATASET='ALL'):
    """Segment every image in a folder and write seg/artery/vein/final maps.

    Each image is resized to 576x576, split into overlapping patches, run
    through the model, and three of the model outputs (indices ``iteration``,
    ``2 * iteration + 1`` and ``3 * iteration + 2``) are stitched, min-max
    scaled to [0, 255], resized back to the input size and saved as PNG under
    ``<output_path>/out_seg``, ``out_art`` and ``out_vei``.  A 3-channel
    composite is saved under ``out_final``: the segmentation value goes to
    channel 2 where artery >= vein and to channel 0 where artery < vein.

    Args:
        ACTIVATION: Name of the activation, looked up in this module's
            globals.
        dropout: Forwarded to ``define_model.get_unet`` as ``do``.
        batch_size: Unused; kept for interface compatibility.
        repeat: Unused; kept for interface compatibility.
        minimum_kernel: Forwarded to ``define_model.get_unet``.
        epochs: Part of the model name and therefore of the weight path.
        iteration: Forwarded to ``define_model.get_unet``; also selects
            which model outputs are used.
        crop_size: Patch size; also part of the model name.
        stride_size: Stride between neighbouring patches.
        input_path: Folder containing the input images.
        output_path: Folder under which the ``out_*`` folders are created.
        DATASET: Sub-folder of ``trained_model`` that holds the weights.
    """
    exts = {'png', 'jpg', 'tif', 'bmp', 'gif'}

    if not input_path.endswith('/'):
        input_path += '/'
    # splitext avoids treating a dot-less file literally named "png" as an
    # image, and lower-casing also accepts extensions such as ".PNG".
    names = [
        name for name in sorted(os.listdir(input_path))
        if os.path.splitext(name)[1][1:].lower() in exts
    ]

    for folder in ('out_seg', 'out_art', 'out_vei', 'out_final'):
        os.makedirs(f"{output_path}/{folder}/", exist_ok=True)

    activation = globals()[ACTIVATION]
    model = define_model.get_unet(minimum_kernel=minimum_kernel,
                                  do=dropout,
                                  activation=activation,
                                  iteration=iteration)
    model_name = f"Final_Emer_Iteration_{iteration}_cropsize_{crop_size}_epochs_{epochs}"
    print("Model : %s" % model_name)
    load_path = f"trained_model/{DATASET}/{model_name}.hdf5"
    model.load_weights(load_path, by_name=False)

    def rebuild(pred, new_height, new_width):
        """Stitch patch predictions into one map scaled to [0, 255]."""
        pred_patches = crop_prediction.pred_to_patches(pred, crop_size,
                                                       stride_size)
        pred_imgs = crop_prediction.recompone_overlap(pred_patches, crop_size,
                                                      stride_size, new_height,
                                                      new_width)
        prob = pred_imgs[0, 0:576, 0:576, 0]
        # Min-max scale; a constant map would otherwise divide by zero.
        low = np.min(prob)
        span = np.max(prob) - low
        if span > 0:
            return 255. * (prob - low) / span
        return np.zeros_like(prob)

    def save(folder, filename, image, image_size):
        """Resize ``image`` back to the input size and write it as PNG."""
        # PIL reports size as (width, height); resize wants (rows, cols).
        cv2.imwrite(f"{output_path}/{folder}/{filename}.png",
                    resize(image, image_size[::-1]))

    for name in tqdm(names):
        filename = os.path.splitext(name)[0]
        # The context manager closes the file handle once the pixel data has
        # been copied into the array.
        # NOTE(review): the image is used in whatever mode PIL opens it in;
        # palette or RGBA inputs are not converted -- confirm the model's
        # expected channel layout.
        with Image.open(input_path + name) as handle:
            image_size = handle.size
            img = np.array(handle) / 255.
        img = resize(img, [576, 576])

        patches_pred, new_height, new_width, _ = crop_prediction.get_test_patches(
            img, crop_size, stride_size)
        preds = model.predict(patches_pred)

        # for segmentation
        pred_seg = rebuild(preds[iteration], new_height, new_width)
        save('out_seg', filename, pred_seg, image_size)

        # for artery
        pred_art = rebuild(preds[2 * iteration + 1], new_height, new_width)
        save('out_art', filename, pred_art, image_size)

        # for vein
        pred_vei = rebuild(preds[3 * iteration + 2], new_height, new_width)
        save('out_vei', filename, pred_vei, image_size)

        # for final: segmentation value in channel 2 where the artery score
        # is at least the vein score, in channel 0 where the vein score wins.
        artery_wins = pred_art >= pred_vei
        vein_wins = pred_art < pred_vei
        pred_final = np.zeros((*pred_seg.shape, 3), dtype=pred_seg.dtype)
        pred_final[artery_wins, 2] = pred_seg[artery_wins]
        pred_final[vein_wins, 0] = pred_seg[vein_wins]
        save('out_final', filename, pred_final, image_size)
Пример #5
0
def predict(ACTIVATION='ReLU',
            dropout=0.1,
            batch_size=32,
            repeat=4,
            minimum_kernel=32,
            epochs=200,
            iteration=3,
            crop_size=128,
            stride_size=3,
            DATASET='DRIVE'):
    """Predict the test split, save every model output and evaluate the last.

    For each test image the model is run on overlapping patches.  Every
    model output is stitched back together, stored as ``NN.pickle`` (raw
    map) and ``NN.png`` (min-max scaled to [0, 255]) under
    ``./output/<DATASET>/crop_size_<crop_size>/out<k>/``, and its
    ground-truth / prediction pairs are collected (restricted to in-mask
    pixels when masks are available).  Only the pairs of output
    ``out<iteration + 1>`` are passed to ``evaluate``.

    Args:
        ACTIVATION: Name of the activation, looked up in this module's
            globals.
        dropout: Forwarded to ``define_model.get_unet`` as ``do``.
        batch_size: Unused; kept for interface compatibility.
        repeat: Unused; kept for interface compatibility.
        minimum_kernel: Forwarded to ``define_model.get_unet``.
        epochs: Part of the model name and therefore of the weight path.
        iteration: Forwarded to ``define_model.get_unet``; ``iteration + 1``
            output folders / result lists are prepared.
        crop_size: Patch size; also part of the model name and output path.
        stride_size: Stride between neighbouring patches.
        DATASET: One of ``'DRIVE'``, ``'CHASEDB1'`` or ``'STARE'``.

    Raises:
        ValueError: If ``DATASET`` has no known image size.
        KeyError: If ``ACTIVATION`` is not a name in the module globals, or
            if the model yields more outputs than ``iteration + 1``.
    """
    prepare_dataset.prepareDataset(DATASET)
    imgs = prepare_dataset.getTestData(0, DATASET)
    segs = prepare_dataset.getTestData(1, DATASET)
    masks = prepare_dataset.getTestData(2, DATASET)

    image_sizes = {
        'DRIVE': (565, 584),
        'CHASEDB1': (999, 960),
        'STARE': (700, 605),
    }
    # Fail early with a clear message instead of crashing on
    # ``None[::-1]`` after the first image has already been predicted.
    if DATASET not in image_sizes:
        raise ValueError(
            f"No image size known for dataset {DATASET!r}; "
            f"expected one of {sorted(image_sizes)}.")
    out_shape = image_sizes[DATASET][::-1]

    base_dir = f"./output/{DATASET}/crop_size_{crop_size}"
    gt_list_out = {}
    pred_list_out = {}
    for out_id in range(iteration + 1):
        # Let directory-creation errors surface: every later write needs the
        # folder, so silently ignoring a failure only delays the crash.
        os.makedirs(f"{base_dir}/out{out_id + 1}/", exist_ok=True)
        gt_list_out[f"out{out_id + 1}"] = []
        pred_list_out[f"out{out_id + 1}"] = []

    activation = globals()[ACTIVATION]
    model = define_model.get_unet(minimum_kernel=minimum_kernel,
                                  do=dropout,
                                  activation=activation,
                                  iteration=iteration)
    model_name = f"Final_Emer_Iteration_{iteration}_cropsize_{crop_size}_epochs_{epochs}"
    print("Model : %s" % model_name)
    load_path = f"trained_model/{DATASET}/{model_name}.hdf5"
    model.load_weights(load_path, by_name=False)

    # Explicit test so that None, an empty list and an ndarray all work
    # (plain truthiness is ambiguous for multi-element arrays).
    has_mask = masks is not None and len(masks) > 0

    for i, img in enumerate(tqdm(imgs)):
        # Ground truth and mask do not depend on the model output, so they
        # are resized once per image rather than once per output.
        seg_ = resize(segs[i], out_shape)
        gt_ = (seg_ > 0.5).astype(int)
        if has_mask:
            mask_ = resize(masks[i], out_shape)
            # Inside the mask pixels only.
            inside = (mask_ > 0.5).reshape(mask_.shape[:2])
        else:
            inside = np.ones(out_shape, dtype=bool)
        # Boolean indexing visits pixels in the same row-major order as a
        # nested row/column loop and yields the same per-pixel elements.
        gt_flat = list(gt_[inside])

        patches_pred, new_height, new_width, _ = crop_prediction.get_test_patches(
            img, crop_size, stride_size)
        preds = model.predict(patches_pred)

        for out_id, pred in enumerate(preds):
            key = f"out{out_id + 1}"
            pred_patches = crop_prediction.pred_to_patches(
                pred, crop_size, stride_size)
            pred_imgs = crop_prediction.recompone_overlap(
                pred_patches, crop_size, stride_size, new_height, new_width)
            # NOTE(review): DESIRED_DATA_SHAPE[0] is used for both axes,
            # which is only right for square data -- confirm.
            side = prepare_dataset.DESIRED_DATA_SHAPE[0]
            pred_ = pred_imgs[0, 0:side, 0:side, 0]

            with open(f"{base_dir}/{key}/{i + 1:02}.pickle", 'wb') as handle:
                pickle.dump(pred_, handle, protocol=pickle.HIGHEST_PROTOCOL)

            pred_ = resize(pred_, out_shape)
            gt_list_out[key] += gt_flat
            pred_list_out[key] += list(pred_[inside])

            # Min-max scale to [0, 255]; a constant map would divide by zero.
            low = np.min(pred_)
            span = np.max(pred_) - low
            if span > 0:
                pred_ = 255. * (pred_ - low) / span
            else:
                pred_ = np.zeros_like(pred_)
            cv2.imwrite(f"{base_dir}/{key}/{i + 1:02}.png", pred_)

    # Only the last prepared output is evaluated.
    if iteration >= 0:
        last_key = f"out{iteration + 1}"
        print('\n\n', last_key)
        evaluate(gt_list_out[last_key], pred_list_out[last_key], DATASET)