def __init__(self, output_path, model_name, img_shape, nb_class):
    """Store the output settings and build the class palette.

    Args:
        output_path: kept as ``self.output_path``.
        model_name: kept as ``self.model_name``.
        img_shape: kept as ``self.img_shape``.
        nb_class: kept as ``self.nb_class`` and used to construct
            ``self.palette``.
    """
    # The epoch counter starts at zero; other code is expected to advance it.
    self.epoch = 0
    self.output_path, self.model_name = output_path, model_name
    self.img_shape, self.nb_class = img_shape, nb_class
    # Build the palette last, from the class count that was just stored.
    self.palette = VOCPalette(nb_class=self.nb_class)
class TrainCheck(Callback):
    """Training callback that counts epochs and can save a colour-mapped prediction.

    ``on_epoch_end`` only records the 1-based number of the finished epoch;
    per-epoch visualisation is disabled, so ``visualize`` writes an image
    only when it is called explicitly.
    """

    def __init__(self, output_path, model_name, img_shape, nb_class):
        """Store output settings and build the label palette.

        Args:
            output_path: directory that ``visualize`` saves its PNG into.
            model_name: prefix of the file name written by ``visualize``.
            img_shape: model input shape; ``visualize`` uses index 0 as the
                resize height and index 1 as the resize width.
            nb_class: number of classes, passed to ``VOCPalette``.
        """
        # Run the base-class constructor so whatever state it sets up exists
        # (the original skipped this call).
        super(TrainCheck, self).__init__()
        self.epoch = 0
        self.output_path = output_path
        self.model_name = model_name
        self.img_shape = img_shape
        self.nb_class = nb_class
        self.palette = VOCPalette(nb_class=nb_class)

    def result_map_to_img(self, res_map):
        """Collapse a per-class score map into a per-pixel class-index map.

        ``res_map`` is squeezed and then reduced with argmax over axis 2, so
        after squeezing it must have at least three dimensions (presumably
        (H, W, C) -- TODO confirm against the model output).

        Returns:
            ``uint8`` array of class indices.
        """
        res_map = np.squeeze(res_map)
        return np.argmax(res_map, axis=2).astype('uint8')

    def on_epoch_end(self, epoch, logs=None):
        """Record the 1-based number of the epoch that just finished.

        ``logs`` is accepted for the callback signature and is not used
        (``None`` default avoids a shared mutable dict).
        """
        self.epoch = epoch + 1
        # Per-epoch visualisation is intentionally disabled; call
        # ``self.visualize(path)`` manually when an image is wanted.

    def visualize(self, path):
        """Predict on the image at ``path`` and save a palette-coloured mask.

        The image is converted to RGB, resized to the model input size,
        scaled from [0, 255] to [-1, 1] and passed through ``self.model``.
        The argmax label map is coloured with ``self.palette``, resized back
        to the original image size and saved as
        ``<output_path>/<model_name>_epoch_<epoch>.png``.
        """
        # Close the source file promptly; convert() returns an independent copy.
        with Image.open(path) as src:
            imgorg = src.convert('RGB')
        # PIL sizes are (width, height); img_shape is indexed as (height, width, ...).
        # Image.LANCZOS is the filter formerly aliased as Image.ANTIALIAS,
        # which Pillow 10 removed.
        img = imgorg.resize((self.img_shape[1], self.img_shape[0]), Image.LANCZOS)
        img_arr = np.array(img) / 127.5 - 1
        img_arr = np.expand_dims(img_arr, 0)  # batch of one
        pred = self.model.predict(img_arr)
        res_img = self.result_map_to_img(pred[0])
        pil_img_pal = self.palette.genlabelpal(res_img)
        pil_img_pal = pil_img_pal.resize(imgorg.size, Image.LANCZOS)
        pil_img_pal.save(
            os.path.join(
                self.output_path,
                self.model_name + '_epoch_' + str(self.epoch) + '.png'))
def test_seg_model_handler():
    """Evaluate the saved U-Net segmentation weights on the BUS evaluation images.

    Builds a ``unet`` model and loads ``./log/unet_sobel_model_weight.h5``.
    It then evaluates every image listed in the data1/data2 good/bad
    validation lists and in ``../BUS/data3/test.txt``. For each image it
    predicts a mask, resizes the mask back to the original image size with
    OpenCV, and scores it with ``seg_accuracy`` against the ground-truth
    file. The mean score is printed as ``mOverlaps``.

    If the weights cannot be loaded, a message is printed and the function
    returns without evaluating. (Previously it went on to score an untrained
    network.) If no images are listed, that is reported instead of printing
    ``nan``.
    """
    seg_img_width = 256
    seg_img_height = 256
    nb_class = 2
    image_shape = (seg_img_height, seg_img_width, 3)

    # load saved model
    weight_path = './log/unet_sobel_model_weight.h5'
    pair_model = unet(input_shape=image_shape, num_classes=nb_class,
                      lr_init=1e-4, lr_decay=5e-4, vgg_weight_path=None)
    try:
        pair_model.load_weights(weight_path)
    except Exception as err:
        # load_weights can fail in several ways (missing file, mismatched
        # layers); report it and stop rather than evaluate random weights.
        print("You must train model and get weight before test.")
        print("load_weights({!r}) failed: {}".format(weight_path, err))
        return

    def _read_names(list_path):
        """Return the non-blank lines of ``list_path`` with line endings stripped."""
        with open(list_path, "r") as f:
            return [line.rstrip('\r\n') for line in f if line.strip()]

    def _pairs(list_path, root, lead=None, gt_suffix=''):
        """Build (image path, ground-truth path) pairs for one list file.

        Image paths are ``<root>original/<stem>.png`` and ground-truth paths
        are ``<root>GT/<stem><gt_suffix>.png``. The stem is the listed name
        itself, or, when ``lead`` is given, the name with its first
        character replaced by ``lead``.
        """
        pairs = []
        for name in _read_names(list_path):
            stem = name if lead is None else lead + name[1:]
            pairs.append((root + 'original/' + stem + '.png',
                          root + 'GT/' + stem + gt_suffix + '.png'))
        return pairs

    def _overlap(img_path, label_path):
        """Predict a mask for one image and score it against its ground truth."""
        with Image.open(img_path) as imgorg:
            # PIL size is (width, height), the order cv2.resize expects for dsize.
            orig_size = imgorg.size
            # Image.LANCZOS replaces Image.ANTIALIAS (removed in Pillow 10).
            img_seg = imgorg.resize((seg_img_width, seg_img_height), Image.LANCZOS)
        # NOTE(review): the image is not converted to RGB (the original had
        # that line commented out); a grayscale file would not match the
        # 3-channel input shape -- confirm the dataset is RGB.
        img_seg_arr = np.expand_dims(np.array(img_seg) / 127.5 - 1, 0)
        seg_r = seg_result(pair_model.predict(img_seg_arr))
        seg_res = cv2.resize(seg_r, dsize=orig_size, interpolation=cv2.INTER_LINEAR)
        return seg_accuracy(seg_res, label_path, image_shape)

    # collect evaluation data: data1/data2 validation splits plus data3 test split
    data1 = '../BUS/data1/'
    data2 = '../BUS/data2/'
    data3 = '../BUS/data3/'
    eval_pairs = (
        _pairs(data1 + 'data1_good_val.txt', data1)
        + _pairs(data1 + 'data1_bad_val.txt', data1)
        + _pairs(data2 + 'data2_good_val.txt', data2, lead='C')
        + _pairs(data2 + 'data2_bad_val.txt', data2, lead='C')
        + _pairs(data3 + 'test.txt', data3, lead='c', gt_suffix='_GT')
    )

    overlaps = [_overlap(img_path, label_path) for img_path, label_path in eval_pairs]
    if not overlaps:
        print("No evaluation images found; mOverlaps is undefined.")
        return

    # segmentation accuracy
    mOverlaps = np.mean(overlaps)
    print("mOverlaps: ", mOverlaps)
def test_cls_model_handler():
    """Evaluate the saved classification weights on the BUS validation images.

    Builds a model with ``create_cls_model`` and loads
    ``./log/test_model_checkpoint_weight.h5``. It then classifies every
    original image listed in the data1/data2 ``*_good_val.txt`` and
    ``*_bad_val.txt`` files. A predicted class of 1 is counted as "good" and
    0 as "bad". From the confusion counts it prints accuracy, sensitivity,
    specificity and precision, with "good" as the positive class. A metric
    whose denominator is zero is printed as ``nan`` instead of raising
    ``ZeroDivisionError``.

    If the weights cannot be loaded, a message is printed and the function
    returns without evaluating. (Previously it went on to score an untrained
    network.)
    """
    cls_img_width = 128
    cls_img_height = 128

    # load saved model
    weight_path = './log/test_model_checkpoint_weight.h5'
    pair_model = create_cls_model(cls_width=cls_img_width, cls_height=cls_img_height)
    try:
        pair_model.load_weights(weight_path)
    except Exception as err:
        # load_weights can fail in several ways (missing file, mismatched
        # layers); report it and stop rather than evaluate random weights.
        print("You must train model and get weight before test.")
        print("load_weights({!r}) failed: {}".format(weight_path, err))
        return

    def _read_names(list_path):
        """Return the non-blank lines of ``list_path`` with line endings stripped."""
        with open(list_path, "r") as f:
            return [line.rstrip('\r\n') for line in f if line.strip()]

    def _image_paths(list_path, root, lead=None):
        """Return ``<root>original/<stem>.png`` for each name in ``list_path``.

        The stem is the listed name itself, or, when ``lead`` is given, the
        name with its first character replaced by ``lead``.
        """
        paths = []
        for name in _read_names(list_path):
            stem = name if lead is None else lead + name[1:]
            paths.append(root + 'original/' + stem + '.png')
        return paths

    def _predict_class(img_path):
        """Run the classifier on one image and return ``cls_result`` of its output."""
        with Image.open(img_path) as src:
            # Image.LANCZOS replaces Image.ANTIALIAS (removed in Pillow 10).
            img_cls = src.convert('RGB').resize((cls_img_width, cls_img_height), Image.LANCZOS)
        img_cls_arr = np.expand_dims(np.array(img_cls) / 127.5 - 1, 0)
        return cls_result(pair_model.predict(img_cls_arr))

    def _ratio(num, den):
        """Return num/den as a float, or nan when the denominator is zero."""
        return float(num) / den if den else float('nan')

    # collect validation data
    data1 = '../BUS/data1/'
    data2 = '../BUS/data2/'
    good_paths = (_image_paths(data1 + 'data1_good_val.txt', data1)
                  + _image_paths(data2 + 'data2_good_val.txt', data2, lead='C'))
    bad_paths = (_image_paths(data1 + 'data1_bad_val.txt', data1)
                 + _image_paths(data2 + 'data2_bad_val.txt', data2, lead='C'))

    label_good = np.uint8(1)
    label_bad = np.uint8(0)

    # Confusion counts are named <true label>_<predicted label>; a prediction
    # that is neither 0 nor 1 is not counted.
    good_good = good_bad = bad_good = bad_bad = 0
    for img_path in good_paths:
        cls_r = _predict_class(img_path)
        if cls_r == label_good:
            good_good += 1
        elif cls_r == label_bad:
            good_bad += 1
    for img_path in bad_paths:
        cls_r = _predict_class(img_path)
        if cls_r == label_good:
            bad_good += 1
        elif cls_r == label_bad:
            bad_bad += 1

    # classification accuracy
    total = good_good + good_bad + bad_good + bad_bad
    accuracy = _ratio(good_good + bad_bad, total)
    sensitivity = _ratio(good_good, good_good + good_bad)
    specificity = _ratio(bad_bad, bad_good + bad_bad)
    precision = _ratio(good_good, good_good + bad_good)
    print("accuracy: ", accuracy)
    print("sensitivity: ", sensitivity)
    print("specificity: ", specificity)
    print("precision: ", precision)
model = unet(input_shape=(img_height, img_width, channels), num_classes=nb_class, lr_init=1e-3, lr_decay=5e-4, vgg_weight_path=vgg_path) elif model_name == "fuzzyunet": model = fuzzy_unet(input_shape=(img_height, img_width, channels), num_classes=nb_class, lr_init=1e-3, lr_decay=5e-4, vgg_weight_path=vgg_path) elif model_name == "pspnet": model = pspnet50(input_shape=(img_height, img_width, channels), num_classes=nb_class, lr_init=1e-3, lr_decay=5e-4) # load weights try: model.load_weights(model_name + '_model_weight.h5') except: print("You must train model and get weight before test.") # Palette, used to show the result palette = VOCPalette(nb_class=nb_class) # print the testing image's name print(img_path) # read testing image imgorg = Image.open(img_path) # read label imglab = Image.open(label_path) # resize the input image to the input layer's size img = imgorg.resize((img_width, img_height), Image.ANTIALIAS) # convert to numpy array img_arr = np.array(img) # Centering helps normalization image (-1 ~ 1 value) img_arr = img_arr / 127.5 - 1 # batch size is set to one img_arr = np.expand_dims(img_arr, 0) # img_arr = np.expand_dims(img_arr, 3)
with open(input_file, "r") as f: ls = f.readlines() namesimg = [l.rstrip('\n') for l in ls] nb_data_img = len(namesimg) val_num = math.ceil(nb_data_img * VAL_RATIO) random.seed(time.time) #indices = [n for n in range(len(namesimg))] #random.shuffle(indices) #print(indices) random.shuffle(namesimg) palette = VOCPalette(nb_class=NB_CLASS) # Make ImageDataGenerator. x_data_gen = ImageDataGenerator(**x_data_gen_args) y_data_gen = ImageDataGenerator(**y_data_gen_args) f_train = open(gen_txt_path + 'train.txt', "a") f_val = open(gen_txt_path + 'val.txt', "a") f_test = open(gen_txt_path + 'test_tumor.txt', "a") for i in range(nb_data_img): if i < val_num: f_test.writelines(namesimg[i] + "\n") Xpath = img_path + "{}.png".format(namesimg[i]) Ypath = label_path + "{}.png".format(namesimg[i]) print(Xpath)