def main(argv):
    """Train the PVC U-Net on one MRI/PET subject pair (getopt interface).

    Recognised options (all required):
        -m / --mri NAME           MRI subject, loaded from './/files//NAME_mri.nii.gz'
        -p / --pet NAME           PET subject, loaded from './/files//NAME_pet.nii.gz'
        -e / --epoch N            number of epochs; N + 1 is passed on
        -i / --id ID              model id; a "-YYYY-mm-dd-HH-MM" stamp is appended
        -t / --th VALUE           threshold published to the GL store
        -w / --weight_PGWC VALUE  weight published to the GL store as "W_PGWC"

    NOTE(review): options are parsed from ``sys.argv[1:]``; *argv* is kept
    only for interface compatibility and is not consulted.

    On an unparsable, malformed or missing option this prints a message to
    stderr, calls ``usage()`` and exits with status 2.
    """
    try:
        opts, _ = getopt.getopt(
            sys.argv[1:], 'm:p:e:i:t:w:',
            ['mri=', 'pet=', 'epoch=', 'id=', 'th=', 'weight_PGWC='])
    except getopt.GetoptError as err:
        print("Error: {}".format(err), file=sys.stderr)
        usage()
        # Non-zero status: a bare sys.exit() would report success.
        sys.exit(2)

    # Every option is consumed unconditionally below; start from None so a
    # missing one is reported instead of surfacing as an UnboundLocalError.
    dir_mri = dir_pet = n_epoch = model_id = MRI_TH = W_PGWC = None

    # getopt.getopt only ever returns the options declared above (unknown
    # ones raise GetoptError), so no catch-all branch is needed.
    for flag, arg in opts:
        if flag in ('-m', '--mri'):
            dir_mri = arg
        elif flag in ('-p', '--pet'):
            dir_pet = arg
        elif flag in ('-e', '--epoch'):
            try:
                n_epoch = int(arg) + 1
            except ValueError:
                print("Error: --epoch expects an integer, got {!r}".format(arg),
                      file=sys.stderr)
                usage()
                sys.exit(2)
        elif flag in ('-i', '--id'):
            model_id = arg
        elif flag in ('-t', '--th'):
            MRI_TH = arg
        elif flag in ('-w', '--weight_PGWC'):
            W_PGWC = arg

    required = (('--mri', dir_mri), ('--pet', dir_pet), ('--epoch', n_epoch),
                ('--id', model_id), ('--th', MRI_TH),
                ('--weight_PGWC', W_PGWC))
    missing = [name for name, value in required if value is None]
    if missing:
        print("Error: missing required option(s): " + ", ".join(missing),
              file=sys.stderr)
        usage()
        sys.exit(2)

    dir_mri = './/files//' + dir_mri + '_mri.nii.gz'
    dir_pet = './/files//' + dir_pet + '_pet.nii.gz'

    time_stamp = datetime.datetime.now().strftime("-%Y-%m-%d-%H-%M")
    print("MODEL_ID: ", model_id + time_stamp)

    GL_set_value("MODEL_ID", model_id + time_stamp)
    GL_set_value("MRI_TH", MRI_TH)
    # data_pre_PVC reads the threshold under the lower-case key "mri_th";
    # publish it under that name too so the command-line value reaches it.
    # NOTE(review): confirm the GL store keys are case-sensitive.
    GL_set_value("mri_th", MRI_TH)
    GL_set_value("W_PGWC", W_PGWC)

    model, opt, loss, callbacks_list, conf = set_configuration(n_epoch=n_epoch,
                                                               flag_aug=False)
    data_mri, data_pet = set_dataset(dir_mri=dir_mri, dir_pet=dir_pet)
    X, Y = data_pre_PVC(data_mri=data_mri, data_pet=data_pet)

    model.compile(opt, loss)

    w_train(model=model, X=X, Y=Y, n_epoch=n_epoch)

    # Drop the large arrays and the model explicitly before returning.
    del model
    del data_mri
    del data_pet
    gc.collect()
Exemple #2
0
def set_dataset(dir_mri, dir_pet):
    """Load an MRI/PET pair of NIfTI files and return their voxel arrays.

    The PET image's header and affine are published to the GL store under
    "header" and "affine" (presumably so results can later be written in
    PET space — confirm against the writer).

    Args:
        dir_mri: path of the MRI NIfTI file.
        dir_pet: path of the PET NIfTI file.

    Returns:
        (data_mri, data_pet): the arrays returned by ``get_fdata()``.
    """
    mri_img = nib.load(dir_mri)
    pet_img = nib.load(dir_pet)

    pet_geometry = {"header": pet_img.header, "affine": pet_img.affine}
    for key in ("header", "affine"):
        GL_set_value(key, pet_geometry[key])

    volumes = (mri_img.get_fdata(), pet_img.get_fdata())
    for caption, volume in zip(("MRI_img shape:", "PET_img shape:"), volumes):
        print(caption, volume.shape)
    return volumes
Exemple #3
0
def data_pre_breast_practical(data_mri_water, data_mri_fat, data_pet):
    """Build the (X, Y) pair for one breast slice from PET and water/fat MRI.

    The slice index and image size come from the GL store
    ("IDX_SLICE", "IMG_ROWS", "IMG_COLS").

    X has shape (1, IMG_ROWS, IMG_COLS, 3):
        channel 0     - PET slice divided by the max of the whole PET volume
        channels 1, 2 - water slice divided by the max of the whole water volume
    Y is the same array object as X.

    Side effects on the GL store:
        "FA_NORM" - maximum of the whole PET volume
        "img_ff"  - masked fat fraction f / (f + w + 1e-6) of the slice. The
                    mask is (PET > 350) AND (water + fat > 150), cleaned with
                    sm.opening(sm.disk(5)) then sm.closing(sm.square(5)).

    Raises:
        ValueError: if the selected slice is not IMG_ROWS x IMG_COLS.
    """
    IMG_ROWS = GL_get_value("IMG_ROWS")
    IMG_COLS = GL_get_value("IMG_COLS")
    IDX_SLICE = GL_get_value("IDX_SLICE")

    X = np.zeros((1, IMG_ROWS, IMG_COLS, 3))

    pet_max = np.amax(data_pet)
    GL_set_value("FA_NORM", pet_max)

    img_p = data_pet[:, :, IDX_SLICE]
    # Thresholds are applied to raw (un-normalised) intensities. The reshape
    # follows the configured image size rather than a hard-coded 256 x 256.
    mask_pet = (img_p > 350).reshape((IMG_ROWS, IMG_COLS)).astype(int)

    img_w = data_mri_water[:, :, IDX_SLICE]
    img_f = data_mri_fat[:, :, IDX_SLICE]

    # The small epsilon keeps the fat-fraction division finite on empty voxels.
    img_sum = img_f + img_w + 1e-6
    mask_sum = (img_sum > 150).reshape((IMG_ROWS, IMG_COLS)).astype(int)
    mask = mask_pet * mask_sum
    mask = sm.opening(mask, sm.disk(5))
    mask = sm.closing(mask, sm.square(5))

    img_ff = np.divide(img_f, img_sum) * mask

    water_norm = np.divide(img_w, np.amax(data_mri_water))
    X[0, :, :, 0] = np.divide(img_p, pet_max)
    X[0, :, :, 1] = water_norm
    X[0, :, :, 2] = water_norm
    print(img_ff.shape)
    GL_set_value("img_ff", img_ff)
    # Y aliases X (same array object), as callers have always received.
    Y = X

    print("X shape:", X.shape)
    print("Y shape:", Y.shape)

    return X, Y
Exemple #4
0
def data_pre_breast_m2p(data_mri_water, data_mri_fat, data_pet):
    """Build an MRI -> PET (X, Y) pair for one breast slice.

    The slice index and image size come from the GL store
    ("IDX_SLICE", "IMG_ROWS", "IMG_COLS").

    X[0, :, :, 0]: water + fat MRI sum of the slice (+1e-6), divided by its
                   own maximum.
    Y[0, :, :, 0]: PET slice divided by the max of the whole PET volume.

    Side effects on the GL store:
        "FA_NORM"  - maximum of the whole PET volume
        "img_ff"   - masked fat fraction  f / (f + w + 1e-6)
        "img_wf"   - masked water fraction w / (f + w + 1e-6)
        "mask_pet" - PET > 350 mask (before combining and morphology)
    The combined mask is (PET > 350) AND (water + fat > 150), cleaned with
    sm.opening(sm.disk(5)) then sm.closing(sm.square(5)).

    Raises:
        ValueError: if the selected slice is not IMG_ROWS x IMG_COLS.
    """
    IMG_ROWS = GL_get_value("IMG_ROWS")
    IMG_COLS = GL_get_value("IMG_COLS")
    IDX_SLICE = GL_get_value("IDX_SLICE")

    X = np.zeros((1, IMG_ROWS, IMG_COLS, 1))
    Y = np.zeros((1, IMG_ROWS, IMG_COLS, 1))

    pet_max = np.amax(data_pet)
    GL_set_value("FA_NORM", pet_max)

    # target: PET slice normalised by the whole-volume maximum
    img_p = data_pet[:, :, IDX_SLICE]
    Y[0, :, :, 0] = np.divide(img_p, pet_max)

    # input: water + fat sum, normalised by its own slice maximum
    img_w = data_mri_water[:, :, IDX_SLICE]
    img_f = data_mri_fat[:, :, IDX_SLICE]
    img_sum = img_f + img_w + 1e-6

    X[0, :, :, 0] = np.divide(img_sum, np.amax(img_sum))

    # Raw-intensity thresholds; the reshape follows the configured image size
    # instead of a hard-coded 256 x 256.
    mask_pet = (img_p > 350).reshape((IMG_ROWS, IMG_COLS)).astype(int)
    mask_sum = (img_sum > 150).reshape((IMG_ROWS, IMG_COLS)).astype(int)
    mask = mask_pet * mask_sum
    mask = sm.opening(mask, sm.disk(5))
    mask = sm.closing(mask, sm.square(5))

    # water/fat fractions inside the cleaned mask
    img_ff = np.divide(img_f, img_sum) * mask
    img_wf = np.divide(img_w, img_sum) * mask

    GL_set_value("img_ff", img_ff)
    GL_set_value("img_wf", img_wf)
    GL_set_value("mask_pet", mask_pet)

    print("X shape:", X.shape)
    print("Y shape:", Y.shape)

    return X, Y
Exemple #5
0
def set_dataset_brest(dir_mri_water, dir_mri_fat, dir_pet):
    """Load the water MRI, fat MRI and PET NIfTI files for the breast pipeline.

    The PET image's header and affine are published to the GL store under
    "header" and "affine" (presumably for writing outputs in PET space —
    confirm against the writer).

    Args:
        dir_mri_water: path of the water-MRI NIfTI file.
        dir_mri_fat: path of the fat-MRI NIfTI file.
        dir_pet: path of the PET NIfTI file.

    Returns:
        (data_water, data_fat, data_pet): arrays from ``get_fdata()``.
    """
    images = [nib.load(path) for path in (dir_mri_water, dir_mri_fat, dir_pet)]
    pet_img = images[-1]

    header, affine = pet_img.header, pet_img.affine
    GL_set_value("header", header)
    GL_set_value("affine", affine)

    volumes = [img.get_fdata() for img in images]
    captions = ("WATER_img shape:", "FAT_img shape:", "PET_img shape:")
    for caption, volume in zip(captions, volumes):
        print(caption, volume.shape)

    return tuple(volumes)
Exemple #6
0
def data_pre_breast(data_mri_water, data_mri_fat, data_pet):
    """Build the (X, Y) pair for one breast slice from PET and water/fat MRI.

    Each modality is normalised by the maximum of its whole volume. Only the
    slice selected by the GL value "IDX_SLICE" is divided, which gives the
    same values as normalising the full volumes first.

    X has shape (1, IMG_ROWS, IMG_COLS, 3):
        channel 0     - normalised PET slice
        channels 1, 2 - normalised water slice
    Y is the same array object as X.

    Side effects on the GL store:
        "FA_NORM" - maximum of the whole PET volume
        "img_ff"  - fat fraction f / (f + w + 1e-6) of the normalised slices,
                    masked by raw PET > 350

    Raises:
        ValueError: if the selected slice is not IMG_ROWS x IMG_COLS.
    """
    IMG_ROWS = GL_get_value("IMG_ROWS")
    IMG_COLS = GL_get_value("IMG_COLS")
    IDX_SLICE = GL_get_value("IDX_SLICE")

    X = np.zeros((1, IMG_ROWS, IMG_COLS, 3))

    pet_max = np.amax(data_pet)
    GL_set_value("FA_NORM", pet_max)

    img_p = data_pet[:, :, IDX_SLICE]
    # Threshold on raw PET intensity; reshape follows the configured size
    # instead of a hard-coded 256 x 256.
    mask = (img_p > 350).reshape((IMG_ROWS, IMG_COLS)).astype(int)

    X[0, :, :, 0] = np.divide(img_p, pet_max)
    img_w = np.divide(data_mri_water[:, :, IDX_SLICE], np.amax(data_mri_water))
    img_f = np.divide(data_mri_fat[:, :, IDX_SLICE], np.amax(data_mri_fat))

    # The small epsilon keeps the fat-fraction division finite on empty voxels.
    img_sum = img_f + img_w + 1e-6

    img_ff = np.divide(img_f, img_sum) * mask

    X[0, :, :, 2] = img_w
    X[0, :, :, 1] = img_w
    print(img_ff.shape)
    GL_set_value("img_ff", img_ff)
    # Y aliases X (same array object), as callers have always received.
    Y = X

    print("X shape:", X.shape)
    print("Y shape:", Y.shape)

    return X, Y
Exemple #7
0
def data_pre_PVC(data_mri, data_pet):
    """Build the 4-channel (X, Y) pair for one PET/MRI slice used in PVC training.

    Reads "IMG_ROWS", "IMG_COLS", "IDX_SLICE" and "mri_th" from the GL store
    and publishes the maximum of the PET slice as "FA_NORM". The PET slice is
    divided by that maximum before thresholding against mri_th.

    Channels of X (shape (1, IMG_ROWS, IMG_COLS, 4)), as 0.0/1.0 floats except
    channel 0:
        0 - normalised PET, zeroed where the label map is 0        (PET)
        1 - label map == 1                                          (CSF)
        2 - label map == 2 OR normalised PET > mri_th               (gray matter)
        3 - label map == 3 AND normalised PET <= mri_th             (white matter)

    Returns:
        (X, Y) where Y is the same array object as X.
    """
    n_rows = GL_get_value("IMG_ROWS")
    n_cols = GL_get_value("IMG_COLS")
    slice_idx = GL_get_value("IDX_SLICE")

    pet_peak = np.amax(data_pet[:, :, slice_idx])
    GL_set_value('FA_NORM', pet_peak)
    threshold = float(GL_get_value("mri_th"))

    # Normalising just the slice yields the same values as dividing the
    # whole volume and then slicing.
    pet_slice = np.divide(data_pet[:, :, slice_idx], pet_peak)
    labels = data_mri[:, :, slice_idx]

    features = np.zeros((1, n_rows, n_cols, 4))
    features[0, :, :, 0] = pet_slice * (labels != 0).astype(float)
    features[0, :, :, 1] = labels == 1
    features[0, :, :, 2] = np.logical_or(pet_slice > threshold, labels == 2)
    features[0, :, :, 3] = np.logical_and(pet_slice <= threshold, labels == 3)

    gc.collect()
    return features, features
Exemple #8
0
def main():
    parser = argparse.ArgumentParser(
        description=
        '''This is a beta script for Partial Volume Correction in PET/MRI system. ''',
        epilog="""All's well that ends well.""")
    parser.add_argument('--mri_water',
                        metavar='',
                        type=str,
                        default="subj02",
                        help='Name of MRI water subject.(subj02)')
    parser.add_argument('--mri_fat',
                        metavar='',
                        type=str,
                        default="subj02",
                        help='Name of MRI fat subject.(subj02)')
    parser.add_argument('-p',
                        '--pet',
                        metavar='',
                        type=str,
                        default="subj02",
                        help='Name of PET subject.(subj02)')
    parser.add_argument('-e',
                        '--epoch',
                        metavar='',
                        type=int,
                        default=2000,
                        help='Number of epoches of training.(2000)')
    parser.add_argument('-i',
                        '--id',
                        metavar='',
                        type=str,
                        default="eeVee",
                        help='ID of the current model.(eeVee)')
    parser.add_argument('--w_pet',
                        metavar='',
                        type=int,
                        default=7000,
                        help='Weight of PET')
    parser.add_argument('--w_water',
                        metavar='',
                        type=int,
                        default=100,
                        help='Weight of water MRI')
    parser.add_argument('--w_fat',
                        metavar='',
                        type=int,
                        default=1,
                        help='Weight of fat MRI')
    parser.add_argument('--flag_BN',
                        metavar='',
                        type=bool,
                        default=True,
                        help='Flag of BatchNormlization(True)')
    parser.add_argument('--flag_Dropout',
                        metavar='',
                        type=bool,
                        default=True,
                        help='Flag of Dropout(True)')
    parser.add_argument('--flag_reg',
                        metavar='',
                        type=bool,
                        default=False,
                        help='Flag of regularizer(False)')
    parser.add_argument('--type_wr',
                        metavar='',
                        type=str,
                        default='None',
                        help='Flag of weight regularizer(l2/l1)')
    parser.add_argument('--type_yr',
                        metavar='',
                        type=str,
                        default='None',
                        help='Flag of y regularizer(l2/l1)')
    parser.add_argument('--para_wr',
                        metavar='',
                        type=float,
                        default=0.01,
                        help='Para of weight regularizer(0.01)')
    parser.add_argument('--para_yr',
                        metavar='',
                        type=float,
                        default=0.01,
                        help='Para of y regularizer(0.01)')
    parser.add_argument('--n_filter',
                        metavar='',
                        type=int,
                        default=64,
                        help='The initial filter number')
    parser.add_argument('--depth',
                        metavar='',
                        type=int,
                        default=4,
                        help='The depth of U-Net')
    parser.add_argument('--gap_flash',
                        metavar='',
                        type=int,
                        default=100,
                        help='How many epochs between two flash shoot')
    parser.add_argument('--flag_whole',
                        metavar='',
                        type=bool,
                        default=False,
                        help='Whether process the whole PET image')
    parser.add_argument('--idx_slice',
                        metavar='',
                        type=int,
                        default=47,
                        help='The idx to be processed.')
    parser.add_argument('--flag_smooth',
                        metavar='',
                        type=bool,
                        default=False,
                        help='Flag of Smooth loss function')

    args = parser.parse_args()

    dir_mri_water = './/files//' + args.mri_water + '_water.nii'
    dir_mri_fat = './/files//' + args.mri_fat + '_fat.nii'
    dir_pet = './/files//' + args.pet + '_pet.nii'
    n_epoch = args.epoch + 1

    time_stamp = datetime.datetime.now().strftime("-%Y-%m-%d-%H-%M")
    GL_set_value("MODEL_ID", args.id + time_stamp)
    GL_set_value("w_pet", args.w_pet)
    GL_set_value("w_water", args.w_water)
    GL_set_value("w_fat", args.w_fat)
    GL_set_value("flag_BN", args.flag_BN)
    GL_set_value("flag_Dropout", args.flag_Dropout)
    GL_set_value("flag_reg", args.flag_reg)
    GL_set_value("flag_wr", args.type_wr)
    GL_set_value("flag_yr", args.type_yr)
    GL_set_value("para_wr", args.para_wr)
    GL_set_value("para_yr", args.para_yr)
    GL_set_value("n_filter", args.n_filter)
    GL_set_value("depth", args.depth)
    GL_set_value("gap_flash", args.gap_flash)
    GL_set_value("flag_whole", args.flag_whole)
    GL_set_value("IDX_SLICE", args.idx_slice)
    GL_set_value("flag_smooth", args.flag_smooth)

    # model establishment
    if args.flag_whole:
        GL_set_value("MODEL_ID", args.id)

    model, opt, loss, callbacks_list, conf = set_configuration(n_epoch=n_epoch,
                                                               flag_aug=False)
    # add_regularizer(model)
    data_mri_water, data_mri_fat, data_pet = set_dataset_brest(
        dir_mri_water=dir_mri_water, dir_mri_fat=dir_mri_fat, dir_pet=dir_pet)

    GL_set_value("IDX_SLICE", args.idx_slice)

    X, Y = data_pre_breast_p2p(data_mri_water=data_mri_water,
                               data_mri_fat=data_mri_fat,
                               data_pet=data_pet)
    model.summary()
    model.compile(opt, loss)

    if args.flag_whole:
        w_pred_breast(model=model, X=X, Y=Y, n_epoch=n_epoch)
        print('The slice has been completed. ' + str(args.idx_slice))
    else:
        w_train_breast(model=model, X=X, Y=Y, n_epoch=n_epoch)

    save_all()
    del model
    del data_mri_water
    del data_mri_fat
    del data_pet
    gc.collect()
Exemple #9
0
import gc
import os
# import sys
# import getopt
import datetime
import argparse
import numpy as np
from config.main_config import set_configuration
from data.load_data import set_dataset_brest
from data.set_X_Y import data_pre_breast, data_pre_breast_practical, data_pre_breast_p2p
from run.run_pvc import w_train_breast, w_pred_breast
from GL.w_global import GL_set_value, GL_get_value
from eval.output import w_output
from run.save_para import save_all

# Image geometry published through the GL store; the data_pre_* helpers read
# "IMG_ROWS"/"IMG_COLS" back via GL_get_value. "IMG_DEPT" is presumably the
# number of slices per volume — confirm against its readers.
GL_set_value("IMG_ROWS", 256)
GL_set_value("IMG_COLS", 256)
GL_set_value("IMG_DEPT", 80)

# Seed NumPy's global random generator so runs are repeatable.
np.random.seed(591)


def usage():
    """Print a short notice that the command-line arguments were invalid."""
    message = "Error in input argv"
    print(message)


def main():
    """Entry point of the argparse-based breast PVC script.

    NOTE(review): in this listing the body stops after constructing the
    ArgumentParser; argument definitions and the training steps are not
    shown here, so as visible this function only builds the parser and
    returns None.
    """
    parser = argparse.ArgumentParser(
        description=
        '''This is a beta script for Partial Volume Correction in PET/MRI system. ''',
        epilog="""All's well that ends well.""")
Exemple #10
0
import os
# import sys
# import getopt
import datetime
import argparse
import numpy as np
from config.main_config import set_configuration
from data.load_data import set_dataset
from data.set_X_Y import data_pre_PVC, data_pre_seg
from run.run_pvc import w_train, w_pred
from GL.w_global import GL_set_value, GL_get_value
from eval.output import w_output
from run.save_para import save_all


# Image geometry and a default PET normalisation published through the GL
# store. "IMG_DEPT" is presumably the number of slices per volume — confirm.
# NOTE: data_pre_PVC overwrites "FA_NORM" with the maximum of the selected
# PET slice, so this 35000.0 is only a fallback for code that reads it first.
GL_set_value("IMG_ROWS", 512)
GL_set_value("IMG_COLS", 512)
GL_set_value("IMG_DEPT", 284)
GL_set_value("FA_NORM", 35000.0)

# Seed NumPy's global random generator so runs are repeatable.
np.random.seed(591)

def usage():
    """Report invalid command-line arguments on standard output."""
    text = "Error in input argv"
    print(text)
    return None


def main():
    """Entry point of the argparse-based PVC script.

    NOTE(review): this definition is cut off in the listing in the middle of
    the first ``parser.add_argument`` call; the remaining options and the
    training steps are not visible here.
    """
    parser = argparse.ArgumentParser(
        description='''This is a beta script for Partial Volume Correction in PET/MRI system. ''',
        epilog="""All's well that ends well.""")
    parser.add_argument('-m', '--mri', metavar='', type=str, default="subj01",
import gc
import sys
import getopt
import datetime
import numpy as np
from config.main_config import set_configuration
from data.load_data import set_dataset
from data.set_X_Y import data_pre_PVC
from run.run_pvc import w_train
from GL.w_global import GL_set_value

# Default geometry, slice index and PET normalisation for the getopt-based
# script, shared through the GL store (read back with GL_get_value).
# Module-level ``global`` declarations are no-ops, so none are used here:
# these values live only in the GL store, not as module variables.
# NOTE: data_pre_PVC overwrites "FA_NORM" with the maximum of the selected
# PET slice, so 35000.0 is only a fallback.
GL_set_value("IMG_ROWS", 512)
GL_set_value("IMG_COLS", 512)
GL_set_value("IDX_SLICE", 142)
GL_set_value("FA_NORM", 35000.0)

# Seed NumPy's global random generator so runs are repeatable.
np.random.seed(591)


def usage():
    """Emit the generic bad-arguments message to stdout."""
    print(*["Error in input argv"])


def main(argv):
    try:
        opts, args = getopt.getopt(
            sys.argv[1:], 'm:p:e:i:t:w:',