コード例 #1
0
def test_all_nets(out_dir, Log, which_data):
    """Evaluate the residual-attention multi-stage DenseNet on one data split.

    Builds the TF1 inference graph (``multi_stage_densenet`` fed with a
    constant all-ones ``residual_attention_map``), restores the latest
    checkpoint from ``/exports/lkeb-hpc/syousefi/Code/<Log>unet_checkpoints/``
    and, for every slice of every scan in the chosen split, centre-crops the
    T1/ASL/PET inputs, runs the network and writes inputs and predictions as
    ``.mha`` files.  Per-slice PET SSIM/PSNR/MSE values (also grouped into
    HN/HY lists by file name) are written to ``all_ssim.xlsx`` in the output
    directory.

    NOTE(review): this file defines ``test_all_nets`` a second time further
    down; that later definition replaces this one at import time.

    Args:
        out_dir: Output sub-directory, concatenated directly after
            ``parent_path + Log`` (presumably ends with '/' -- confirm).
        Log: Experiment directory, concatenated directly after the code
            root (presumably ends with '/').
        which_data: 1 = validation split, 2 = test split, 3 = train split.

    Raises:
        ValueError: If ``which_data`` is not 1, 2 or 3.
        FileNotFoundError: If no checkpoint is found in the checkpoint dir.
    """
    data_path_AMUC = "/exports/lkeb-hpc/syousefi/Data/ASL2PET_high_res/AMUC_high_res/"
    data_path_LUMC = "/exports/lkeb-hpc/syousefi/Data/ASL2PET_high_res/LUMC_high_res/"
    parent_path = '/exports/lkeb-hpc/syousefi/Code/'

    def _center_slice(center, size):
        # Crop window around ``center`` using the original convention: it
        # starts one pixel left of (center - half), so it spans
        # 2 * int(size / 2) + 1 pixels (exactly ``size`` when size is odd).
        half = int(size / 2)
        return slice(center - half - 1, center + half)

    def _scan_subdir_name(path_parts):
        # Part of the file stem after the first 'ASL_' (up to any second
        # 'ASL_'); the whole stem when the file name contains no 'ASL_'.
        stem = path_parts[-1].split(".")[0]
        pieces = stem.split("ASL_")
        return pieces[1] if len(pieces) > 1 else stem

    _rd = _read_data(data_path_AMUC, data_path_LUMC)
    fold = 2
    train_data, validation_data, test_data = _rd.read_data_path(average_no=0,
                                                                fold=fold)

    splits = {1: validation_data, 2: test_data, 3: train_data}
    if which_data not in splits:
        # Fail fast, before building the graph and restoring weights.
        raise ValueError('which_data must be 1 (validation), 2 (test) or '
                         '3 (train), got %r' % (which_data,))
    data = splits[which_data]

    asl_plchld = tf.placeholder(tf.float32,
                                shape=[None, asl_size, asl_size, 1])
    t1_plchld = tf.placeholder(tf.float32, shape=[None, asl_size, asl_size, 1])
    pet_plchld = tf.placeholder(tf.float32,
                                shape=[None, pet_size, pet_size, 1])
    asl_out_plchld = tf.placeholder(tf.float32,
                                    shape=[None, pet_size, pet_size, 1])
    hybrid_training_flag = tf.placeholder(tf.bool, name='hybrid_training_flag')
    ave_loss_vali = tf.placeholder(tf.float32)
    residual_attention_map = tf.placeholder(
        tf.float32, shape=[None, asl_size, asl_size, 1])
    is_training = tf.placeholder(tf.bool, name='is_training')
    is_training_bn = tf.placeholder(tf.bool, name='is_training_bn')

    msdensnet = multi_stage_densenet()
    asl_y, pet_y, new_att_map = msdensnet.multi_stage_densenet(
        asl_img=asl_plchld,
        t1_img=t1_plchld,
        pet_img=pet_plchld,
        hybrid_training_flag=hybrid_training_flag,
        input_dim=asl_size,
        is_training=is_training,
        config=config,
        residual_attention_map=residual_attention_map)

    # Cost terms mirror the training objective; only ssim_pet, psnr and mse
    # are fetched below.
    alpha = .84
    with tf.name_scope('cost'):
        ssim_asl = tf.reduce_mean(
            1 - SSIM(x1=asl_out_plchld, x2=asl_y, max_val=34.0)[0])
        loss_asl = alpha * ssim_asl + (1 - alpha) * tf.reduce_mean(
            huber(labels=asl_out_plchld, logit=asl_y))

        ssim_pet = tf.reduce_mean(
            1 - SSIM(x1=pet_plchld, x2=pet_y, max_val=2.1)[0])
        loss_pet = alpha * ssim_pet + (1 - alpha) * tf.reduce_mean(
            huber(labels=pet_plchld, logit=pet_y))

        cost_withpet = tf.reduce_mean(loss_asl + loss_pet)

        cost_withoutpet = loss_asl

        mse = MSE(x1=pet_plchld, x2=pet_y)
        psnr = PSNR(x1=pet_plchld, x2=pet_y)

    saver = tf.train.Saver()
    chckpnt_dir = parent_path + Log + 'unet_checkpoints/'
    ckpt = tf.train.get_checkpoint_state(chckpnt_dir)
    if ckpt is None or not ckpt.model_checkpoint_path:
        raise FileNotFoundError('No checkpoint found in ' + chckpnt_dir)

    out_root = parent_path + Log + out_dir
    # Target crop used for the ASL output placeholder (inside the ASL crop).
    asl_out_win = _center_slice(int(asl_size / 2), pet_size)
    # Attention map fed at inference time: all ones.
    attention_ones = np.ones([1, asl_size, asl_size, 1])

    list_ssim_pet = []
    list_name = []
    list_ssim_NC, list_mse_NC, list_psnr_NC = [], [], []
    list_ssim_HC, list_mse_HC, list_psnr_HC = [], [], []

    # The context manager closes the session even if evaluation fails.
    with tf.Session() as sess:
        saver.restore(sess, ckpt.model_checkpoint_path)
        # Keep a copy of the evaluation script next to its outputs.
        copyfile(
            './test_semisupervised_multitast_hr_rest_residual_skip_attention.py',
            out_root +
            'test_semisupervised_multitast_hr_rest_residual_skip_attention.py')

        _image_class = image_class(data,
                                   bunch_of_images_no=1,
                                   is_training=0,
                                   inp_size=asl_size,
                                   out_size=pet_size)

        for scan in range(len(data)):
            ss = str(data[scan]['asl']).split("/")
            imm = _image_class.read_image(data[scan])

            subdir = _scan_subdir_name(ss)
            nm_fig = out_root + ss[-3] + '/' + subdir
            # Creates <out_root><ss[-3]> and its per-scan subdirectory;
            # already-existing directories (re-runs, repeated subjects) are OK.
            os.makedirs(nm_fig, exist_ok=True)

            center = int(imm[3].shape[-1] / 2)
            asl_win = _center_slice(center, asl_size)
            pet_win = _center_slice(center, pet_size)
            # Scans whose imm[5] has at most one element are treated as
            # having no PET image.
            has_pet = np.size(imm[5]) > 1

            for img_indx in range(np.shape(imm[3])[0]):
                print('img_indx:%s' % (img_indx))

                name = ss[-3] + '_' + ss[-2] + '_' + str(img_indx)
                t1 = imm[3][img_indx, asl_win, asl_win][np.newaxis, ...,
                                                        np.newaxis]
                asl = imm[4][np.newaxis, img_indx, asl_win, asl_win,
                             np.newaxis]
                if has_pet:
                    pet = imm[5][np.newaxis, img_indx, pet_win, pet_win,
                                 np.newaxis]
                else:
                    # No PET: feed NaNs (what an array of None converts to
                    # in a float32 feed); PET metrics are presumably NaN.
                    pet = np.full([1, pet_size, pet_size, 1], np.nan,
                                  dtype=np.float32)

                loss, psnr1, mse1, pet_out, asl_out = sess.run(
                    [ssim_pet, psnr, mse, pet_y, asl_y],
                    feed_dict={
                        asl_plchld: asl,
                        t1_plchld: t1,
                        pet_plchld: pet,
                        asl_out_plchld: asl[:, asl_out_win, asl_out_win, :],
                        is_training: False,
                        ave_loss_vali: -1,
                        is_training_bn: False,
                        hybrid_training_flag: False,
                        residual_attention_map: attention_ones,
                    })
                ssim = 1 - loss
                list_ssim_pet.append(ssim)

                str_nm = ss[-3] + '_' + subdir + '_t1_' + name
                if 'HN' in str_nm:
                    list_ssim_NC.append(ssim)
                    list_psnr_NC.append(psnr1)
                    list_mse_NC.append(mse1)
                elif 'HY' in str_nm:
                    list_ssim_HC.append(ssim)
                    list_psnr_HC.append(psnr1)
                    list_mse_HC.append(mse1)
                list_name.append(str_nm)
                # Print the current slice's name (list_name spans all scans).
                print(str_nm, ': ', ssim, ',PSNR: ', psnr1, ',MSE: ', mse1)

                sitk.WriteImage(sitk.GetImageFromArray(np.squeeze(t1)),
                                nm_fig + '/t1_' + name + '.mha')
                sitk.WriteImage(sitk.GetImageFromArray(np.squeeze(asl)),
                                nm_fig + '/asl_' + name + '_' + '.mha')
                if has_pet:
                    sitk.WriteImage(sitk.GetImageFromArray(np.squeeze(pet)),
                                    nm_fig + '/pet_' + name + '_' + '.mha')
                sitk.WriteImage(
                    sitk.GetImageFromArray(np.squeeze(pet_out)),
                    nm_fig + '/res_pet' + name + '_' + str(ssim) + '.mha')
                sitk.WriteImage(
                    sitk.GetImageFromArray(np.squeeze(asl_out)),
                    nm_fig + '/res_asl' + name + '_' + str(ssim) + '.mha')

    df = pd.DataFrame(list_ssim_pet, columns=pd.Index(['ssim'],
                                                      name='Genus')).round(2)

    metrics = {
        'SSIM_HC': list_ssim_HC,
        'SSIM_NC': list_ssim_NC,
        'MSE_HC': list_mse_HC,
        'MSE_NC': list_mse_NC,
        'PSNR_HC': list_psnr_HC,
        'PSNR_NC': list_psnr_NC
    }
    # orient='index' tolerates lists of different lengths; transpose so each
    # metric becomes a column.
    df2 = pd.DataFrame.from_dict(metrics, orient='index').transpose()

    xlsx_path = out_root + '/all_ssim.xlsx'
    # A single writer so both sheets end up in the same workbook (two writers
    # on the same path made the second overwrite the first).
    with pd.ExcelWriter(xlsx_path, engine='xlsxwriter') as writer:
        df.to_excel(writer, sheet_name='Sheet1')
        df2.to_excel(writer, sheet_name='Sheet2')

    print(xlsx_path)
def test_all_nets(out_dir, Log, which_data):
    """Evaluate the two-ASL-output multi-stage DenseNet on one data split.

    Builds the TF1 inference graph (``multi_stage_densenet`` returning two
    ASL outputs and one PET output), restores the latest checkpoint from
    ``/exports/lkeb-hpc/syousefi/Code/<Log>unet_checkpoints/`` and, for every
    slice of every scan in the chosen split, centre-crops the inputs, runs
    the network and saves the inputs, the second ASL output and the PET
    output as grayscale PNGs.  Per-slice PET SSIM values (also grouped into
    HN/HY lists by file name) are written to ``all_ssim.xlsx``.

    NOTE(review): this definition replaces an earlier ``test_all_nets``
    defined above in the same file.

    Args:
        out_dir: Output sub-directory, concatenated directly after
            ``parent_path + Log`` (presumably ends with '/' -- confirm).
        Log: Experiment directory, concatenated directly after the code
            root (presumably ends with '/').
        which_data: 1 = validation split, 2 = test split, 3 = train split.

    Raises:
        ValueError: If ``which_data`` is not 1, 2 or 3.
        FileNotFoundError: If no checkpoint is found in the checkpoint dir.
    """
    data_path_AMUC = "/exports/lkeb-hpc/syousefi/Data/ASL2PET_high_res/AMUC_high_res/"
    data_path_LUMC = "/exports/lkeb-hpc/syousefi/Data/ASL2PET_high_res/LUMC_high_res/"
    parent_path = '/exports/lkeb-hpc/syousefi/Code/'

    def _center_slice(center, size):
        # Crop window around ``center`` using the original convention: it
        # starts one pixel left of (center - half), so it spans
        # 2 * int(size / 2) + 1 pixels (exactly ``size`` when size is odd).
        half = int(size / 2)
        return slice(center - half - 1, center + half)

    def _scan_subdir_name(path_parts):
        # Part of the file stem after the first 'ASL_' (up to any second
        # 'ASL_'); the whole stem when the file name contains no 'ASL_'.
        stem = path_parts[-1].split(".")[0]
        pieces = stem.split("ASL_")
        return pieces[1] if len(pieces) > 1 else stem

    _rd = _read_data(data_path_AMUC, data_path_LUMC)

    train_data, validation_data, test_data = _rd.read_data_path(0)

    splits = {1: validation_data, 2: test_data, 3: train_data}
    if which_data not in splits:
        # Fail fast, before building the graph and restoring weights.
        raise ValueError('which_data must be 1 (validation), 2 (test) or '
                         '3 (train), got %r' % (which_data,))
    data = splits[which_data]

    asl_plchld = tf.placeholder(tf.float32, shape=[None, asl_size, asl_size, 1])
    t1_plchld = tf.placeholder(tf.float32, shape=[None, asl_size, asl_size, 1])
    pet_plchld = tf.placeholder(tf.float32, shape=[None, pet_size, pet_size, 1])
    asl_out_plchld = tf.placeholder(tf.float32, shape=[None, pet_size, pet_size, 1])
    hybrid_training_flag = tf.placeholder(tf.bool, name='hybrid_training_flag')
    ave_loss_vali = tf.placeholder(tf.float32)

    is_training = tf.placeholder(tf.bool, name='is_training')
    is_training_bn = tf.placeholder(tf.bool, name='is_training_bn')

    msdensnet = multi_stage_densenet()
    asl_y1, asl_y2, pet_y = msdensnet.multi_stage_densenet(
        asl_img=asl_plchld,
        t1_img=t1_plchld,
        pet_img=pet_plchld,
        hybrid_training_flag=hybrid_training_flag,
        input_dim=asl_size,
        is_training=is_training,
        config=config,
    )

    # Cost terms mirror the training objective; only ssim_pet is fetched
    # below.
    alpha = .84
    with tf.name_scope('cost'):
        ssim_asl1 = tf.reduce_mean(1 - SSIM(x1=asl_out_plchld, x2=asl_y1, max_val=34.0)[0])
        loss_asl1 = alpha * ssim_asl1 + (1 - alpha) * tf.reduce_mean(huber(labels=asl_out_plchld, logit=asl_y1))

        ssim_asl = tf.reduce_mean(1 - SSIM(x1=asl_out_plchld, x2=asl_y2, max_val=34.0)[0])
        loss_asl = alpha * ssim_asl + (1 - alpha) * tf.reduce_mean(huber(labels=asl_out_plchld, logit=asl_y2))

        ssim_pet = tf.reduce_mean(1 - SSIM(x1=pet_plchld, x2=pet_y, max_val=2.1)[0])
        loss_pet = alpha * ssim_pet + (1 - alpha) * tf.reduce_mean(huber(labels=pet_plchld, logit=pet_y))

    cost_withpet = tf.reduce_mean(loss_asl + loss_asl1 + loss_pet)

    cost_withoutpet = loss_asl + loss_asl1

    saver = tf.train.Saver()
    chckpnt_dir = parent_path + Log + 'unet_checkpoints/'
    ckpt = tf.train.get_checkpoint_state(chckpnt_dir)
    if ckpt is None or not ckpt.model_checkpoint_path:
        raise FileNotFoundError('No checkpoint found in ' + chckpnt_dir)

    out_root = parent_path + Log + out_dir
    # Target crop used for the ASL output placeholder (inside the ASL crop).
    asl_out_win = _center_slice(int(asl_size / 2), pet_size)

    list_ssim = []
    list_name = []
    list_ssim_NC = []
    list_ssim_HC = []

    # The context manager closes the session even if evaluation fails.
    with tf.Session() as sess:
        saver.restore(sess, ckpt.model_checkpoint_path)
        # Keep a copy of the evaluation script next to its outputs.
        copyfile('./test_semisupervised_multitast_hr_residual_asl.py',
                 out_root + 'test_semisupervised_multitast_hr_residual_asl.py')

        _image_class = image_class(data,
                                   bunch_of_images_no=1,
                                   is_training=0,
                                   inp_size=asl_size,
                                   out_size=pet_size)

        for scan in range(len(data)):
            ss = str(data[scan]['asl']).split("/")
            imm = _image_class.read_image(data[scan])

            subdir = _scan_subdir_name(ss)
            out_prefix = out_root + ss[-3] + '/' + subdir
            # Creates <out_root><ss[-3]> and its per-scan subdirectory even
            # when the subject directory already exists (a shared try/except
            # around two os.mkdir calls used to skip the second one).
            os.makedirs(out_prefix, exist_ok=True)

            center = int(imm[3].shape[-1] / 2)
            asl_win = _center_slice(center, asl_size)
            pet_win = _center_slice(center, pet_size)
            # Scans whose imm[5] has at most one element are treated as
            # having no PET image.
            has_pet = np.size(imm[5]) > 1

            for img_indx in range(np.shape(imm[3])[0]):
                print('img_indx:%s' % (img_indx))

                name = ss[-3] + '_' + ss[-2] + '_' + str(img_indx)
                t1 = imm[3][img_indx, asl_win, asl_win][np.newaxis, ..., np.newaxis]
                asl = imm[4][np.newaxis, img_indx, asl_win, asl_win, np.newaxis]
                if has_pet:
                    pet = imm[5][np.newaxis, img_indx, pet_win, pet_win, np.newaxis]
                else:
                    # No PET: feed NaNs; the PET SSIM is presumably NaN.
                    pet = np.full([1, pet_size, pet_size, 1], np.nan,
                                  dtype=np.float32)

                loss, pet_out, asl_out2 = sess.run(
                    [ssim_pet, pet_y, asl_y2],
                    feed_dict={
                        asl_plchld: asl,
                        t1_plchld: t1,
                        pet_plchld: pet,
                        asl_out_plchld: asl[:, asl_out_win, asl_out_win, :],
                        is_training: False,
                        ave_loss_vali: -1,
                        is_training_bn: False,
                        hybrid_training_flag: False,
                    })

                ssim = 1 - loss
                list_ssim.append(ssim)
                str_nm = ss[-3] + '_' + subdir + '_t1_' + name
                if 'HN' in str_nm:
                    list_ssim_NC.append(ssim)
                elif 'HY' in str_nm:
                    list_ssim_HC.append(ssim)
                list_name.append(str_nm)
                # Print the current slice's name (list_name spans all scans).
                print(str_nm, ': ', ssim)

                matplotlib.image.imsave(out_prefix + '/t1_' + name + '.png',
                                        np.squeeze(t1), cmap='gray')
                matplotlib.image.imsave(out_prefix + '/asl_' + name + '_' + '.png',
                                        np.squeeze(asl), cmap='gray')
                if has_pet:
                    matplotlib.image.imsave(out_prefix + '/pet_' + name + '_' + '.png',
                                            np.squeeze(pet), cmap='gray')
                matplotlib.image.imsave(out_prefix + '/res_asl' + name + '_' + str(ssim) + '.png',
                                        np.squeeze(asl_out2), cmap='gray')
                matplotlib.image.imsave(out_prefix + '/res_pet' + name + '_' + str(ssim) + '.png',
                                        np.squeeze(pet_out), cmap='gray')

    df = pd.DataFrame(list_ssim,
                      columns=pd.Index(['ssim'],
                                       name='Genus')).round(2)
    groups = {'HC': list_ssim_HC, 'NC': list_ssim_NC}
    # orient='index' tolerates lists of different lengths; transpose so each
    # group becomes a column.
    df2 = pd.DataFrame.from_dict(groups, orient='index').transpose()

    xlsx_path = out_root + '/all_ssim.xlsx'
    # A single writer so both sheets end up in the same workbook (two writers
    # on the same path made the second overwrite the first).
    with pd.ExcelWriter(xlsx_path, engine='xlsxwriter') as writer:
        df.to_excel(writer, sheet_name='Sheet1')
        df2.to_excel(writer, sheet_name='Sheet2')

    print(xlsx_path)