def test_model():
    model_name = "fcn_8"
    h = 224
    w = 256
    n_c = 100
    check_path = "/tmp/%d" % (random.randint(0, 199999))

    m = models.model_from_name[model_name](n_c, input_height=h, input_width=w)

    m.train(train_images=tr_im,
            train_annotations=tr_an,
            steps_per_epoch=2,
            epochs=2,
            checkpoints_path=check_path)

    m.predict_segmentation(np.zeros((h, w, 3))).shape

    predict_multiple(inp_dir=te_im,
                     checkpoints_path=check_path,
                     out_dir="/tmp")
    predict_multiple(inps=[np.zeros((h, w, 3))] * 3,
                     checkpoints_path=check_path,
                     out_dir="/tmp")

    o = predict(inp=np.zeros((h, w, 3)), checkpoints_path=check_path)
    o.shape
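
# Note: the names tr_im, tr_an, te_im and te_an used by these tests are assumed
# to be module-level paths to a small example dataset (training/test images and
# their annotation masks); they are not defined in this snippet.
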
def test_model():
    model_name = "fcn_8"
    h = 224
    w = 256
    n_c = 100
    check_path = tempfile.mktemp()

    m = all_models.model_from_name[model_name](n_c, input_height=h, input_width=w)

    m.train(train_images=tr_im,
            train_annotations=tr_an,
            steps_per_epoch=2,
            epochs=2,
            checkpoints_path=check_path
            )

    m.predict_segmentation(np.zeros((h, w, 3))).shape

    predict_multiple(
        inp_dir=te_im, checkpoints_path=check_path, out_dir="/tmp")
    predict_multiple(inps=[np.zeros((h, w, 3))]*3,
                     checkpoints_path=check_path, out_dir="/tmp")

    ev = m.evaluate_segmentation(inp_images_dir=te_im, annotations_dir=te_an)
    assert ev['frequency_weighted_IU'] > 0.01
    print(ev)
    o = predict(inp=np.zeros((h, w, 3)), checkpoints_path=check_path)
    o.shape

    ev = evaluate(inp_images_dir=te_im, annotations_dir=te_an, checkpoints_path=check_path)
    assert ev['frequency_weighted_IU'] > 0.01
from keras_segmentation.models.unet import vgg_unet

# NOTE: the opening lines were missing; the model construction and the
# training-data arguments below are reconstructed assumptions ("dataset1" paths).
model = vgg_unet(n_classes=51, input_height=416, input_width=608)
model.train(train_images="dataset1/images_prepped_train/",
            train_annotations="dataset1/annotations_prepped_train/",
            verify_dataset=True,
            # load_weights="weights/vgg_unet_1.4",
            optimizer_name='adadelta',
            do_augment=True,
            augmentation_name="aug_all",
            checkpoints_path="weights/vgg_unet_1",
            epochs=10)

# Display the model's architecture
model.summary()

# Save the entire model to a HDF5 file.
# The '.h5' extension indicates that the model should be saved to HDF5.
model.save('vgg_unet_1.h5')
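
# The full model saved above can later be restored with Keras' standard loader.
# A minimal sketch (left commented out so the training script is not affected):
# from tensorflow.keras.models import load_model
# restored_model = load_model('vgg_unet_1.h5')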

# predict an image from the test data
out = model.predict_segmentation(checkpoints_path="weights/vgg_unet_1",
                                 inp="dataset1/images_prepped_test/43.jpg",
                                 out_fname="newout.png")

from keras_segmentation.predict import predict_multiple

predict_multiple(checkpoints_path="weights/vgg_unet_1",
                 inp_dir="dataset1/images_prepped_test/",
                 out_dir="weights/out/")

#import matplotlib.pyplot as plt
#plt.imshow(out)
#
## evaluating the model
#print(model.evaluate_segmentation( inp_images_dir="dataset1/images_prepped_test/"  , annotations_dir="dataset1/annotations_prepped_test/" ) )
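
# The per-case metric helpers used by evaluate_model() below (dice_score,
# recall, precision) are not defined in this snippet. A minimal sketch of what
# the function appears to assume: each helper takes a stack of label masks, a
# stack of predicted masks and the pixel value marking the ROI, and returns a
# pixel-wise score over the whole stack. Names, signatures and edge-case
# handling here are assumptions, not the original implementation.
import numpy as np


def dice_score(labels, preds, target_roi_value=1):
    # Dice = 2*|A and B| / (|A| + |B|), computed over all pixels in the stack.
    a = (np.asarray(labels) == target_roi_value)
    b = (np.asarray(preds) == target_roi_value)
    denom = a.sum() + b.sum()
    return 2.0 * np.logical_and(a, b).sum() / denom if denom else 1.0


def recall(labels, preds, target_roi_value=1):
    # Pixel-wise recall = TP / (TP + FN).
    a = (np.asarray(labels) == target_roi_value)
    b = (np.asarray(preds) == target_roi_value)
    return np.logical_and(a, b).sum() / a.sum() if a.sum() else 1.0


def precision(labels, preds, target_roi_value=1):
    # Pixel-wise precision = TP / (TP + FP).
    a = (np.asarray(labels) == target_roi_value)
    b = (np.asarray(preds) == target_roi_value)
    return np.logical_and(a, b).sum() / b.sum() if b.sum() else 1.0
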
import os

import cv2
import numpy as np


def evaluate_model(image_dir,
                   label_dir=None,
                   checkpoints_path=None,
                   calculate_predicting_indicators=True,
                   output_predicted_result=False,
                   segment_out_predicted_region_from_original_images=False,
                   roi_description='roi',
                   preds=None,
                   batch_process_slice_point=None,
                   target_roi_value=None):
    from keras_segmentation.predict import predict_multiple

    if preds is None:
        print('----------Generating model predictions----------')
        preds = predict_multiple(checkpoints_path=checkpoints_path,
                                 inp_dir=image_dir)

    preds_batches = []

    if batch_process_slice_point:
        for i in range(len(batch_process_slice_point) + 1):
            if i == 0:
                preds_batches.append(preds[:batch_process_slice_point[i]])
            elif i == len(batch_process_slice_point):
                preds_batches.append(preds[batch_process_slice_point[i - 1]:])
            else:
                preds_batches.append(preds[
                    batch_process_slice_point[i -
                                              1]:batch_process_slice_point[i]])
    else:
        preds_batches.append(preds)

    base_index = 0

    dice_score_list = []
    recall_list = []
    precision_list = []

    globalDice_part = []

    TP = 0
    FP = 0
    FN = 0
    TN = 0

    counter1 = 0
    counter2 = 0
    counter3 = 0

    for idx, preds_batch in enumerate(preds_batches):

        print(f'----------Converting prediction data type (batch {idx + 1})----------')
        preds_batch = np.asarray(preds_batch).astype(np.uint8)

        if calculate_predicting_indicators:
            labels_dicePerCase = []
            preds_dicePerCase = []

            labels_globalDice = []
            preds_globalDice = []

            # Get the patient ID of the first patient in this batch.
            # If images are named "patientID_imageID", the patient ID here is "patientID".
            # If images are named "dataset_patientID_imageID", it is "dataset_patientID".
            # NOTE: label files are read in os.listdir() order, which is assumed
            # to match the order of the predictions.
            separator = '_'
            prev_patient_idx = separator.join(
                os.listdir(label_dir)[base_index + 0].split('_')[:-1])

            print(f'----------Computing prediction metrics (batch {idx + 1})----------')
            for i in range(preds_batch.shape[0]):

                label = cv2.imread(
                    os.path.join(label_dir,
                                 os.listdir(label_dir)[base_index + i]),
                    cv2.IMREAD_GRAYSCALE)
                labels_globalDice.append(label)

                if preds_batch[i].shape != label.shape:
                    # The predicted mask must match the size of the original
                    # label; cv2.resize takes dsize as (width, height).
                    pred = cv2.resize(preds_batch[i].copy(),
                                      (label.shape[1], label.shape[0]))
                else:
                    pred = preds_batch[i].copy()
                preds_globalDice.append(pred)

                # Update the image-level confusion matrix (ROI present in label vs. prediction).
                if (1 in np.unique(label)) and (1 in np.unique(pred)):
                    TP += 1
                elif (1 not in np.unique(label)) and (1 in np.unique(pred)):
                    FP += 1
                elif (1 in np.unique(label)) and (1 not in np.unique(pred)):
                    FN += 1
                else:
                    TN += 1

                # Get the patient ID of the image currently being processed.
                separator = '_'
                patient_idx = separator.join(
                    os.listdir(label_dir)[base_index + i].split('_')[:-1])

                if patient_idx != prev_patient_idx:  # reached the next patient's images

                    labels_dicePerCase = np.asarray(labels_dicePerCase)
                    preds_dicePerCase = np.asarray(preds_dicePerCase)

                    # print(f'Number of CT slices for patient {patient_idx}: {labels_dicePerCase.shape[0]}')

                    dice_score_list.append(
                        dice_score(labels_dicePerCase, preds_dicePerCase,
                                   target_roi_value))
                    recall_list.append(
                        recall(labels_dicePerCase, preds_dicePerCase,
                               target_roi_value))
                    precision_list.append(
                        precision(labels_dicePerCase, preds_dicePerCase,
                                  target_roi_value))

                    labels_dicePerCase = []
                    preds_dicePerCase = []

                labels_dicePerCase.append(label)
                preds_dicePerCase.append(pred)

                # If this is the last image in the batch, compute the last
                # patient's average Dice score per case.
                if i == preds_batch.shape[0] - 1:
                    labels_dicePerCase = np.asarray(labels_dicePerCase)
                    preds_dicePerCase = np.asarray(preds_dicePerCase)

                    # print(f'Number of CT slices for patient {patient_idx}: {labels_dicePerCase.shape[0]}')

                    dice_score_list.append(
                        dice_score(labels_dicePerCase, preds_dicePerCase,
                                   target_roi_value))
                    recall_list.append(
                        recall(labels_dicePerCase, preds_dicePerCase,
                               target_roi_value))
                    precision_list.append(
                        precision(labels_dicePerCase, preds_dicePerCase,
                                  target_roi_value))

                prev_patient_idx = patient_idx

                counter1 += 1
                if counter1 % 500 == 0:
                    print('Progress: image ' + str(counter1))

            ### compute global Dice over this batch ###
            labels_globalDice = np.asarray(labels_globalDice)
            preds_globalDice = np.asarray(preds_globalDice)
            globalDice_part.append(
                dice_score(labels_globalDice, preds_globalDice,
                           target_roi_value))

        if output_predicted_result:
            if label_dir:
                save_path = label_dir + '_predicted'
            else:
                save_path = image_dir.replace(
                    image_dir.split('\\')[-1],
                    'annotations') + '_' + roi_description + '_predicted'
            if not os.path.exists(save_path):
                os.makedirs(save_path)
                print('-----Created new folder: ' + save_path + '-----')

            print(f'---------Writing model predictions to disk (batch {idx + 1})----------')

            for i in range(preds_batch.shape[0]):
                image_ori = cv2.imread(
                    os.path.join(image_dir,
                                 os.listdir(image_dir)[base_index + i]),
                    cv2.IMREAD_GRAYSCALE)
                image_ori_name = os.listdir(image_dir)[base_index + i]

                if preds_batch[i].shape != image_ori.shape:
                    # cv2.resize takes dsize as (width, height).
                    pred = cv2.resize(preds_batch[i],
                                      (image_ori.shape[1], image_ori.shape[0]))
                else:
                    pred = preds_batch[i]

                cv2.imwrite(os.path.join(save_path, image_ori_name), pred)

                counter2 += 1
                if counter2 % 500 == 0:
                    print('Progress: image ' + str(counter2))

        if segment_out_predicted_region_from_original_images:
            save_path = image_dir + '_only_containing_predicted_roi_' + roi_description
            if not os.path.exists(save_path):
                os.makedirs(save_path)
                print('-----Created new folder: ' + save_path + '-----')

            print(f'----------Generating and saving images containing only the predicted region (batch {idx + 1})----------')

            for i in range(preds_batch.shape[0]):
                image_ori = cv2.imread(
                    os.path.join(image_dir,
                                 os.listdir(image_dir)[base_index + i]),
                    cv2.IMREAD_GRAYSCALE)

                if preds_batch[i].shape != image_ori.shape:
                    # cv2.resize takes dsize as (width, height).
                    pred = cv2.resize(preds_batch[i],
                                      (image_ori.shape[1], image_ori.shape[0]))
                else:
                    pred = preds_batch[i]

                image_pred_roi_region = pred * image_ori  # keep only the predicted region of the original image; everything else becomes zero-valued background
                cv2.imwrite(
                    os.path.join(save_path,
                                 os.listdir(image_dir)[base_index + i]),
                    image_pred_roi_region)

                counter3 += 1
                if counter3 % 500 == 0:
                    print('Progress: image ' + str(counter3))

        base_index += len(preds_batch)

    ### compute global Dice (mean of the per-batch values) ###
    if calculate_predicting_indicators:
        globalDice = np.mean(globalDice_part)

    print(f'total case number: {len(preds)}')

    if calculate_predicting_indicators:
        return np.mean(dice_score_list), np.mean(recall_list), np.mean(
            precision_list
        ), globalDice, preds, dice_score_list, recall_list, precision_list, TP, FP, FN, TN
    else:
        return preds
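
# A hypothetical call of evaluate_model(); the directory names, checkpoint path
# and ROI value below are placeholders, not paths from this repository.
if __name__ == '__main__':
    results = evaluate_model(image_dir='dataset/images_test',
                             label_dir='dataset/annotations_test',
                             checkpoints_path='weights/vgg_unet_1',
                             target_roi_value=1,
                             output_predicted_result=False)
    mean_dice, mean_recall, mean_precision, global_dice = results[:4]
    print(f'mean Dice per case: {mean_dice:.4f}, global Dice: {global_dice:.4f}')
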
import tensorflow as tf
from keras_segmentation.predict import predict_multiple
from keras_segmentation.predict import model_from_checkpoint_path

gpu_devices = tf.config.experimental.list_physical_devices('GPU')
for device in gpu_devices:
    tf.config.experimental.set_memory_growth(device, True)

print(f"[INFO] Predicting on test set mask..")

pdr = predict_multiple(
	checkpoints_path=r"C:\Users\matte\PycharmProjects\ecographic_breast_nn\segnet\checkpoints\psp_unet",
	inp_dir=r"D:\FISICA MEDICA\radiomics_eco\Dataset_BUSI_with_GT\segnet\images_val",
	out_dir=r"C:\Users\matte\PycharmProjects\ecographic_breast_nn\segnet\outputs"
)

model = model_from_checkpoint_path(r"C:\Users\matte\PycharmProjects\ecographic_breast_nn\segnet\checkpoints\psp_unet")


print(f"[INFO] Evaluating model..")
print(model.evaluate_segmentation( inp_images_dir=r"D:\FISICA MEDICA\radiomics_eco\Dataset_BUSI_with_GT\segnet\images_val"  , annotations_dir=r"D:\FISICA MEDICA\radiomics_eco\Dataset_BUSI_with_GT\segnet\masks_val" ) )
from keras_segmentation.predict import predict_multiple, evaluate, model_from_checkpoint_path
from keras_segmentation.data_utils.visualize_dataset import visualize_segmentation_dataset
import cv2

checkpoints_path = [
	"./saveModel/vgg_unet_1",
	"./saveModel/vgg_pspnet_1",
	"./saveModel/vgg_segnet_1",
	"./saveModel/fcn_32_vgg_1",
]

for ia in checkpoints_path:
	predict_multiple(
		checkpoints_path=ia,
		inp_dir="./demo/test",
		out_dir="./demo/predict",
		show_legends=True,
		class_names=["sky", "building", "pole", "road", "sidewalk", "vegetation", "traffic_light", "fence", "car", "person", "rider", "static"],
		# overlay_img=True,  # overlay the two images on top of each other
		# prediction_width = 1024,
		# prediction_height = 768
	)

# view summary
# for ia in checkpoints_path:
# 	model = model_from_checkpoint_path(ia)
# 	print(model.summary())

# visualize
# visualize_segmentation_dataset(images_path="./test/example_dataset/images_prepped_train",segs_path="./test/example_dataset/annotations_prepped_train",n_classes=6,no_show=False)
# visualize_segmentation_dataset(images_path="demo/images_preped_test_small/",segs_path="demo/annotations_preped_test_small/",n_classes=6,no_show=True)
# for seg_img, seg_path in visualize_segmentation_dataset(images_path="demo/images_preped_test_small/",segs_path="demo/annotations_preped_test_small/",n_classes=6,no_show=True):
# 	cv2.imwrite(f"demo/ground_truth/{seg_path.split('/')[-1]}", seg_img)