Exemple #1
0
 def __init__(self, output_path, model_name, img_shape, nb_class):
     """Remember where/how to save outputs and build the label palette.

     Args:
         output_path: destination directory for generated files.
         model_name: name of the model (stored as-is).
         img_shape: input image shape (stored as-is).
         nb_class: number of label classes; also sizes the palette.
     """
     # Epoch counter starts at zero.
     self.epoch = 0
     self.output_path, self.model_name = output_path, model_name
     self.img_shape, self.nb_class = img_shape, nb_class
     # Palette with one colour per class (presumably used to render label maps).
     self.palette = VOCPalette(nb_class=nb_class)
Exemple #2
0
class TrainCheck(Callback):
    """Keras callback that tracks the epoch number and can save a colourised prediction.

    Args:
        output_path: directory where visualisations are written.
        model_name: prefix used for the saved file names.
        img_shape: model input shape; ``img_shape[0]`` is the height and
            ``img_shape[1]`` the width used when resizing input images.
        nb_class: number of segmentation classes (sizes the palette).
    """

    def __init__(self, output_path, model_name, img_shape, nb_class):
        # Let Callback initialise its own state (e.g. ``model``) before ours.
        super().__init__()
        self.epoch = 0
        self.output_path = output_path
        self.model_name = model_name
        self.img_shape = img_shape
        self.nb_class = nb_class
        self.palette = VOCPalette(nb_class=nb_class)

    def result_map_to_img(self, res_map):
        """Collapse a per-pixel class-score map into a uint8 label image.

        ``res_map`` is squeezed and the argmax is taken over axis 2, so after
        squeezing it must be a (height, width, classes) array.
        """
        res_map = np.squeeze(res_map)
        return np.argmax(res_map, axis=2).astype('uint8')

    def on_epoch_end(self, epoch, logs=None):
        """Record the 1-based number of the epoch that just finished.

        ``logs`` defaults to ``None`` (not a shared ``{}``) and is unused.
        """
        self.epoch = epoch + 1
        # self.visualize(os.path.join(self.output_path,'test.png'))

    def visualize(self, path):
        """Predict a label map for the image at *path* and save it as a palette PNG.

        The image is resized to the model input size, scaled to [-1, 1], run
        through ``self.model`` and the resulting label map is resized back to
        the original resolution and written to
        ``<output_path>/<model_name>_epoch_<epoch>.png``.
        """
        # Close the source file promptly; convert() returns an independent copy.
        with Image.open(path) as src:
            imgorg = src.convert('RGB')
        # PIL sizes are (width, height); img_shape is (height, width, ...).
        # LANCZOS is the filter formerly exposed as ANTIALIAS (removed in Pillow 10).
        img = imgorg.resize((self.img_shape[1], self.img_shape[0]),
                            Image.LANCZOS)
        # Scale [0, 255] pixels to [-1, 1] and add a batch axis of size one.
        img_arr = np.expand_dims(np.array(img) / 127.5 - 1, 0)
        pred = self.model.predict(img_arr)
        res_img = self.result_map_to_img(pred[0])

        pil_img_pal = self.palette.genlabelpal(res_img)
        pil_img_pal = pil_img_pal.resize(imgorg.size, Image.LANCZOS)
        out_name = f'{self.model_name}_epoch_{self.epoch}.png'
        pil_img_pal.save(os.path.join(self.output_path, out_name))
Exemple #3
0
def _seg_eval_read_names(list_path):
    """Return the image names listed one per line in *list_path*.

    Surrounding whitespace (including CRLF line endings) is stripped and blank
    lines are skipped, so a trailing empty line cannot yield a bogus path.
    """
    with open(list_path, "r") as f:
        return [line.strip() for line in f if line.strip()]


def _seg_eval_overlap(model, img_path, label_path, input_size, image_shape):
    """Segment one image with *model* and return its ``seg_accuracy`` score.

    The image is resized to *input_size* (width, height), scaled to [-1, 1],
    predicted as a batch of one, decoded with ``seg_result`` and resized back
    to the original resolution before being compared with *label_path*.
    """
    with Image.open(img_path) as imgorg:
        orig_size = imgorg.size
        # LANCZOS is the filter formerly exposed as ANTIALIAS (removed in Pillow 10).
        img_seg = imgorg.resize(input_size, Image.LANCZOS)
    # NOTE(review): no convert('RGB') is applied; this assumes the files are
    # already 3-channel to match the (H, W, 3) model input — TODO confirm.
    img_seg_arr = np.expand_dims(np.array(img_seg) / 127.5 - 1, 0)
    seg_r = seg_result(model.predict(img_seg_arr))
    # cv2 takes dsize as (width, height), which is PIL's ``size`` order.
    seg_res = cv2.resize(seg_r, dsize=orig_size, interpolation=cv2.INTER_LINEAR)
    return seg_accuracy(seg_res, label_path, image_shape)


def test_seg_model_handler():
    """Evaluate the saved U-Net weights on the BUS segmentation splits.

    Builds a 256x256x3 U-Net and loads ``./log/unet_sobel_model_weight.h5``.
    For every image listed in the data1/data2 good/bad validation lists and
    the data3 test list, it predicts a mask and scores it with ``seg_accuracy``.
    It then prints the mean score as ``mOverlaps``. Returns early, after
    printing a message, when the weights cannot be loaded or no image is
    listed.
    """
    seg_img_width = 256
    seg_img_height = 256
    nb_class = 2
    image_shape = (seg_img_height, seg_img_width, 3)

    # load saved model
    model_name = './log/unet_sobel_model_weight.h5'
    pair_model = unet(input_shape=image_shape, num_classes=nb_class,
                      lr_init=1e-4, lr_decay=5e-4, vgg_weight_path=None)
    try:
        pair_model.load_weights(model_name)
    except (OSError, ValueError):
        # Scoring randomly initialised weights would print a meaningless number.
        print("You must train model and get weight before test.")
        return

    data1 = '../BUS/data1/'
    data2 = '../BUS/data2/'
    data3 = '../BUS/data3/'

    # Read every list first so a missing list fails before any prediction.
    data3_names = _seg_eval_read_names(data3 + 'test.txt')
    data1_good = _seg_eval_read_names(data1 + 'data1_good_val.txt')
    data1_bad = _seg_eval_read_names(data1 + 'data1_bad_val.txt')
    data2_good = _seg_eval_read_names(data2 + 'data2_good_val.txt')
    data2_bad = _seg_eval_read_names(data2 + 'data2_bad_val.txt')

    # (image, ground truth) pairs in evaluation order. data1 uses the listed
    # name verbatim; data2/data3 replace its first character with 'C'/'c';
    # data3 ground-truth files carry a '_GT' suffix.
    pairs = []
    for names in (data1_good, data1_bad):
        pairs += [(data1 + 'original/' + n + '.png', data1 + 'GT/' + n + '.png')
                  for n in names]
    for names in (data2_good, data2_bad):
        pairs += [(data2 + 'original/C' + n[1:] + '.png',
                   data2 + 'GT/C' + n[1:] + '.png')
                  for n in names]
    pairs += [(data3 + 'original/c' + n[1:] + '.png',
               data3 + 'GT/c' + n[1:] + '_GT.png')
              for n in data3_names]

    input_size = (seg_img_width, seg_img_height)
    overlaps = [
        _seg_eval_overlap(pair_model, img_path, label_path, input_size,
                          image_shape)
        for img_path, label_path in pairs
    ]

    # segmentation accuracy
    if not overlaps:
        print("No validation images found; nothing to evaluate.")
        return
    m_overlaps = np.mean(overlaps)
    print("mOverlaps: ", m_overlaps)
Exemple #4
0
def _cls_eval_read_names(list_path):
    """Return the image names listed one per line in *list_path*.

    Surrounding whitespace (including CRLF line endings) is stripped and blank
    lines are skipped, so a trailing empty line cannot yield a bogus path.
    """
    with open(list_path, "r") as f:
        return [line.strip() for line in f if line.strip()]


def _cls_eval_predict(model, img_path, input_size):
    """Classify the image at *img_path* and return ``cls_result`` of the prediction.

    The image is converted to RGB, resized to *input_size* (width, height),
    scaled to [-1, 1] and predicted as a batch of one.
    """
    with Image.open(img_path) as src:
        # LANCZOS is the filter formerly exposed as ANTIALIAS (removed in Pillow 10).
        img_cls = src.convert('RGB').resize(input_size, Image.LANCZOS)
    img_cls_arr = np.expand_dims(np.array(img_cls) / 127.5 - 1, 0)
    return cls_result(model.predict(img_cls_arr))


def _cls_eval_ratio(num, den):
    """Return ``num / den``, or NaN when *den* is zero (metric undefined)."""
    return num / den if den else float('nan')


def test_cls_model_handler():
    """Evaluate the saved classifier on the BUS data1/data2 validation splits.

    Images in the "good" lists are expected to be predicted as label 1 and
    images in the "bad" lists as label 0; any other prediction is not counted.
    Prints accuracy, sensitivity, specificity and precision, with "good" as
    the positive class. A metric whose denominator is zero prints as ``nan``
    instead of raising ``ZeroDivisionError``. Returns early, after printing a
    message, when the weights cannot be loaded.
    """
    cls_img_width = 128
    cls_img_height = 128

    # load saved model
    model_name = './log/test_model_checkpoint_weight.h5'
    pair_model = create_cls_model(cls_width=cls_img_width,
                                  cls_height=cls_img_height)
    try:
        pair_model.load_weights(model_name)
    except (OSError, ValueError):
        # Scoring randomly initialised weights would print meaningless metrics.
        print("You must train model and get weight before test.")
        return

    data1 = '../BUS/data1/'
    data2 = '../BUS/data2/'
    data1_good = _cls_eval_read_names(data1 + 'data1_good_val.txt')
    data1_bad = _cls_eval_read_names(data1 + 'data1_bad_val.txt')
    data2_good = _cls_eval_read_names(data2 + 'data2_good_val.txt')
    data2_bad = _cls_eval_read_names(data2 + 'data2_bad_val.txt')

    # (image path, is_good) in evaluation order. data1 uses the listed name
    # verbatim; data2 replaces its first character with 'C'.
    samples = []
    samples += [(data1 + 'original/' + n + '.png', True) for n in data1_good]
    samples += [(data1 + 'original/' + n + '.png', False) for n in data1_bad]
    samples += [(data2 + 'original/C' + n[1:] + '.png', True) for n in data2_good]
    samples += [(data2 + 'original/C' + n[1:] + '.png', False) for n in data2_bad]

    # Confusion counts named <ground truth>_<prediction>.
    good_good = good_bad = bad_good = bad_bad = 0
    input_size = (cls_img_width, cls_img_height)
    for img_path, is_good in samples:
        cls_r = _cls_eval_predict(pair_model, img_path, input_size)
        if cls_r == 1:
            if is_good:
                good_good += 1
            else:
                bad_good += 1
        elif cls_r == 0:
            if is_good:
                good_bad += 1
            else:
                bad_bad += 1

    # classification accuracy
    total = good_good + good_bad + bad_good + bad_bad
    accuracy = _cls_eval_ratio(good_good + bad_bad, total)
    sensitivity = _cls_eval_ratio(good_good, good_good + good_bad)
    specificity = _cls_eval_ratio(bad_bad, bad_good + bad_bad)
    precision = _cls_eval_ratio(good_good, good_good + bad_good)

    print("accuracy: ", accuracy)
    print("sensitivity: ", sensitivity)
    print("specificity: ", specificity)
    print("precision: ", precision)
Exemple #5
0
    model = unet(input_shape=(img_height, img_width, channels), num_classes=nb_class,
                 lr_init=1e-3, lr_decay=5e-4, vgg_weight_path=vgg_path)
elif model_name == "fuzzyunet":
    model = fuzzy_unet(input_shape=(img_height, img_width, channels), num_classes=nb_class,
                 lr_init=1e-3, lr_decay=5e-4, vgg_weight_path=vgg_path)
elif model_name == "pspnet":
    model = pspnet50(input_shape=(img_height, img_width, channels), num_classes=nb_class, lr_init=1e-3, lr_decay=5e-4)

# load weights
try:
    model.load_weights(model_name + '_model_weight.h5')
except Exception as err:
    # Best effort: the script continues, but predictions made with untrained
    # weights are meaningless. ``Exception`` (not a bare except) lets Ctrl-C
    # and SystemExit through.
    print("You must train model and get weight before test.")
    print("load_weights failed:", err)

# Palette, used to show the result
palette = VOCPalette(nb_class=nb_class)
# print the testing image's name
print(img_path)
# read testing image
imgorg = Image.open(img_path)
# read label
imglab = Image.open(label_path)
# resize the input image to the input layer's size; PIL takes (width, height).
# LANCZOS is the filter formerly exposed as ANTIALIAS (removed in Pillow 10).
img = imgorg.resize((img_width, img_height), Image.LANCZOS)
# convert to numpy array
img_arr = np.array(img)
# Scale pixel values from [0, 255] to [-1, 1]
img_arr = img_arr / 127.5 - 1
# batch size is set to one
img_arr = np.expand_dims(img_arr, 0)
# img_arr = np.expand_dims(img_arr, 3)
Exemple #6
0
# Read the list of image base names, one per line.
with open(input_file, "r") as f:
    ls = f.readlines()
namesimg = [l.rstrip('\n') for l in ls]
nb_data_img = len(namesimg)

# Size of the held-out portion, rounded up.
val_num = math.ceil(nb_data_img * VAL_RATIO)

# Seed from the current time so each run shuffles differently. time.time must
# be *called*: passing the function object is a TypeError on Python 3.11+ and
# before that seeds from the object's hash rather than the time.
random.seed(time.time())
random.shuffle(namesimg)

palette = VOCPalette(nb_class=NB_CLASS)
# Make ImageDataGenerator.
x_data_gen = ImageDataGenerator(**x_data_gen_args)
y_data_gen = ImageDataGenerator(**y_data_gen_args)
# Split lists are opened in append mode, so re-running the script adds to
# existing files rather than overwriting them.
# NOTE(review): these handles are not closed here — confirm they are closed
# after the loop that writes to them.
f_train = open(gen_txt_path + 'train.txt', "a")
f_val = open(gen_txt_path + 'val.txt', "a")
f_test = open(gen_txt_path + 'test_tumor.txt', "a")
for i in range(nb_data_img):
    if i < val_num:
        f_test.writelines(namesimg[i] + "\n")

    Xpath = img_path + "{}.png".format(namesimg[i])
    Ypath = label_path + "{}.png".format(namesimg[i])
    print(Xpath)