示例#1
0
def _center_crop_slice(full_size, crop_size):
    """Return the slice that cuts a centred window out of one image axis.

    The bounds are ``int(full_size / 2) - int(crop_size / 2) - 1`` (start) and
    ``int(full_size / 2) + int(crop_size / 2)`` (stop), so the window holds
    ``2 * int(crop_size / 2) + 1`` elements, i.e. exactly ``crop_size`` only
    when ``crop_size`` is odd.
    """
    centre = int(full_size / 2)
    half = int(crop_size / 2)
    return slice(centre - half - 1, centre + half)


def _scan_tag(path_parts):
    """Return the per-scan tag derived from the ASL file name.

    The last path component is cut at its first ``.``; when the remaining stem
    contains ``ASL_`` the text following it is returned, otherwise the whole
    stem is returned.
    """
    stem = path_parts[-1].split(".")[0]
    pieces = stem.split("ASL_")
    return pieces[1] if len(pieces) > 1 else stem


def test_all_nets(out_dir, Log, which_data):
    """Run the restored network over one data split and save images and metrics.

    The function builds the inference graph, restores the latest checkpoint
    found under ``<parent_path><Log>unet_checkpoints/``, feeds every slice of
    every scan of the selected split through the network, writes the input and
    output slices as ``.mha`` files and finally stores the collected metrics in
    ``all_ssim.xlsx`` (``Sheet1``: SSIM per slice, ``Sheet2``: SSIM/MSE/PSNR
    grouped by the ``HN`` / ``HY`` marker found in the scan name).

    ``asl_size``, ``pet_size`` and ``config`` are read from module scope.

    Args:
        out_dir: Output sub-directory. It is concatenated verbatim after
            ``parent_path + Log``, so it is presumably expected to end with
            ``/`` and to exist already (a copy of the test script is written
            into it before any directory is created here).
        Log: Experiment directory below ``parent_path``; concatenated
            verbatim, so presumably expected to end with ``/``.
        which_data: ``1`` for the validation split, ``2`` for the test split,
            ``3`` for the training split.

    Raises:
        ValueError: If ``which_data`` is not 1, 2 or 3.
        FileNotFoundError: If no checkpoint state is found in the checkpoint
            directory.
    """
    # Fail early instead of hitting an unbound ``data`` after the graph is built.
    if which_data not in (1, 2, 3):
        raise ValueError(
            'which_data must be 1 (validation), 2 (test) or 3 (train), got %r'
            % (which_data,))

    data_path_AMUC = "/exports/lkeb-hpc/syousefi/Data/ASL2PET_high_res/AMUC_high_res/"
    data_path_LUMC = "/exports/lkeb-hpc/syousefi/Data/ASL2PET_high_res/LUMC_high_res/"

    _rd = _read_data(data_path_AMUC, data_path_LUMC)
    fold = 2
    train_data, validation_data, test_data = _rd.read_data_path(average_no=0,
                                                                fold=fold)

    if which_data == 1:
        data = validation_data
    elif which_data == 2:
        data = test_data
    else:
        data = train_data

    # ---- graph inputs ----------------------------------------------------
    asl_plchld = tf.placeholder(tf.float32,
                                shape=[None, asl_size, asl_size, 1])
    t1_plchld = tf.placeholder(tf.float32, shape=[None, asl_size, asl_size, 1])
    pet_plchld = tf.placeholder(tf.float32,
                                shape=[None, pet_size, pet_size, 1])
    asl_out_plchld = tf.placeholder(tf.float32,
                                    shape=[None, pet_size, pet_size, 1])
    hybrid_training_flag = tf.placeholder(tf.bool, name='hybrid_training_flag')
    ave_loss_vali = tf.placeholder(tf.float32)
    residual_attention_map = tf.placeholder(
        tf.float32, shape=[None, asl_size, asl_size, 1])
    is_training = tf.placeholder(tf.bool, name='is_training')
    is_training_bn = tf.placeholder(tf.bool, name='is_training_bn')

    # ---- network and metrics ---------------------------------------------
    msdensnet = multi_stage_densenet()
    asl_y, pet_y, new_att_map = msdensnet.multi_stage_densenet(
        asl_img=asl_plchld,
        t1_img=t1_plchld,
        pet_img=pet_plchld,
        hybrid_training_flag=hybrid_training_flag,
        input_dim=asl_size,
        is_training=is_training,
        config=config,
        residual_attention_map=residual_attention_map)
    alpha = .84  # weight of the SSIM term; the Huber term gets (1 - alpha)
    with tf.name_scope('cost'):
        ssim_asl = tf.reduce_mean(
            1 - SSIM(x1=asl_out_plchld, x2=asl_y, max_val=34.0)[0])
        loss_asl = alpha * ssim_asl + (1 - alpha) * tf.reduce_mean(
            huber(labels=asl_out_plchld, logit=asl_y))

        ssim_pet = tf.reduce_mean(
            1 - SSIM(x1=pet_plchld, x2=pet_y, max_val=2.1)[0])
        loss_pet = alpha * ssim_pet + (1 - alpha) * tf.reduce_mean(
            huber(labels=pet_plchld, logit=pet_y))

        # The two cost tensors are not evaluated below; they are kept so the
        # graph definition stays unchanged.
        cost_withpet = tf.reduce_mean(loss_asl + loss_pet)
        cost_withoutpet = loss_asl

        mse = MSE(x1=pet_plchld, x2=pet_y)
        psnr = PSNR(x1=pet_plchld, x2=pet_y)

    parent_path = '/exports/lkeb-hpc/syousefi/Code/'
    out_root = parent_path + Log + out_dir
    chckpnt_dir = parent_path + Log + 'unet_checkpoints/'

    list_ssim_pet = []
    list_name = []
    list_ssim_NC = []
    list_mse_NC = []
    list_psnr_NC = []
    list_ssim_HC = []
    list_mse_HC = []
    list_psnr_HC = []

    # Loop-invariant feeds: the ASL target is the PET-sized centre of the ASL
    # input, and the attention map is fed as all ones.
    asl_to_pet = _center_crop_slice(asl_size, pet_size)
    ones_attention = np.ones([1, asl_size, asl_size, 1])

    sess = tf.Session()
    try:
        saver = tf.train.Saver()
        ckpt = tf.train.get_checkpoint_state(chckpnt_dir)
        if ckpt is None:
            # get_checkpoint_state returns None when nothing usable is found.
            raise FileNotFoundError(
                'no checkpoint state found in %s' % chckpnt_dir)
        saver.restore(sess, ckpt.model_checkpoint_path)
        # Unused below; the call is kept in case the constructor has side
        # effects -- TODO confirm and remove.
        _meas = _measure()
        # Keep a copy of this script next to the results it produced.
        copyfile(
            './test_semisupervised_multitast_hr_rest_residual_skip_attention.py',
            out_root +
            'test_semisupervised_multitast_hr_rest_residual_skip_attention.py')

        _image_class = image_class(data,
                                   bunch_of_images_no=1,
                                   is_training=0,
                                   inp_size=asl_size,
                                   out_size=pet_size)

        for scan_info in data:
            ss = str(scan_info['asl']).split("/")
            imm = _image_class.read_image(scan_info)

            # One output folder per scan: <out_root><ss[-3]>/<scan tag>.
            tag = _scan_tag(ss)
            nm_fig = out_root + ss[-3] + '/' + tag
            os.makedirs(nm_fig, exist_ok=True)

            # As used below: imm[3] is the T1 volume, imm[4] the ASL volume
            # and imm[5] the PET volume (or a size-<=1 stub when absent).
            full = imm[3].shape[-1]
            asl_window = _center_crop_slice(full, asl_size)
            pet_window = _center_crop_slice(full, pet_size)

            for img_indx in range(np.shape(imm[3])[0]):
                print('img_indx:%s' % (img_indx))

                name = ss[-3] + '_' + ss[-2] + '_' + str(img_indx)
                t1 = imm[3][img_indx, asl_window, asl_window]
                t1 = t1[np.newaxis, ..., np.newaxis]
                asl = imm[4][np.newaxis, img_indx, asl_window, asl_window,
                             np.newaxis]
                if np.size(imm[5]) > 1:
                    pet = imm[5][np.newaxis, img_indx, pet_window, pet_window,
                                 np.newaxis]
                    hybrid_training_f = True
                else:
                    hybrid_training_f = False
                    # No PET available: feed an array of None values.
                    # NOTE(review): presumably converted to NaN when cast to
                    # float32 for the placeholder, which would make the PET
                    # metrics NaN for such slices -- confirm.
                    pet = np.reshape([None] * pet_size * pet_size,
                                     [pet_size, pet_size])
                    pet = pet[..., np.newaxis]
                    pet = pet[np.newaxis, ...]

                [loss, psnr1, mse1, pet_out, asl_out] = sess.run(
                    [ssim_pet, psnr, mse, pet_y, asl_y],
                    feed_dict={
                        asl_plchld: asl,
                        t1_plchld: t1,
                        pet_plchld: pet,
                        asl_out_plchld: asl[:, asl_to_pet, asl_to_pet, :],
                        is_training: False,
                        ave_loss_vali: -1,
                        is_training_bn: False,
                        # Always fed as False, independent of hybrid_training_f.
                        hybrid_training_flag: False,
                        residual_attention_map: ones_attention
                    })
                # ssim_pet is (1 - SSIM), so undo that to report the SSIM.
                ssim = 1 - loss
                list_ssim_pet.append(ssim)

                str_nm = ss[-3] + '_' + tag + '_t1_' + name

                # Group metrics by the marker contained in the scan name.
                if 'HN' in str_nm:
                    list_ssim_NC.append(ssim)
                    list_psnr_NC.append(psnr1)
                    list_mse_NC.append(mse1)
                elif 'HY' in str_nm:
                    list_ssim_HC.append(ssim)
                    list_psnr_HC.append(psnr1)
                    list_mse_HC.append(mse1)

                list_name.append(str_nm)
                # list_name grows across scans, so print the entry that was
                # just added rather than indexing it with the slice index.
                print(str_nm, ': ', ssim, ',PSNR: ', psnr1, ',MSE: ', mse1)

                sitk.WriteImage(sitk.GetImageFromArray(np.squeeze(t1)),
                                nm_fig + '/t1_' + name + '.mha')
                sitk.WriteImage(sitk.GetImageFromArray(np.squeeze(asl)),
                                nm_fig + '/asl_' + name + '_' + '.mha')
                if hybrid_training_f:
                    sitk.WriteImage(sitk.GetImageFromArray(np.squeeze(pet)),
                                    nm_fig + '/pet_' + name + '_' + '.mha')
                sitk.WriteImage(
                    sitk.GetImageFromArray(np.squeeze(pet_out)),
                    nm_fig + '/res_pet' + name + '_' + str(ssim) + '.mha')
                sitk.WriteImage(
                    sitk.GetImageFromArray(np.squeeze(asl_out)),
                    nm_fig + '/res_asl' + name + '_' + str(ssim) + '.mha')
    finally:
        # Release the session's resources even if a scan fails.
        sess.close()

    df = pd.DataFrame(list_ssim_pet, columns=pd.Index(['ssim'],
                                                      name='Genus')).round(2)

    grouped = {
        'SSIM_HC': list_ssim_HC,
        'SSIM_NC': list_ssim_NC,
        'MSE_HC': list_mse_HC,
        'MSE_NC': list_mse_NC,
        'PSNR_HC': list_psnr_HC,
        'PSNR_NC': list_psnr_NC
    }
    # orient='index' tolerates lists of different length; transposing then
    # turns every metric into its own column.
    df2 = pd.DataFrame.from_dict(grouped, orient='index').transpose()

    # One writer for both sheets: two writers on the same path overwrite each
    # other, so only the sheet saved last would survive.
    xlsx_path = out_root + '/all_ssim.xlsx'
    with pd.ExcelWriter(xlsx_path, engine='xlsxwriter') as writer:
        df.to_excel(writer, sheet_name='Sheet1')
        df2.to_excel(writer, sheet_name='Sheet2')

    print(xlsx_path)
示例#2
0
    def run_net(self):
        """Build the training graph and run the training/validation loop.

        Starts the validation and training loader threads, builds the
        multi-stage DenseNet with a combined SSIM + Huber cost for the ASL and
        PET outputs, and then alternates validation passes and training
        batches, writing TensorBoard summaries to ``self.LOGDIR`` and
        checkpoints to ``self.chckpnt_dir``.

        Reads ``self.data_path``, ``self.asl_size``, ``self.pet_size``,
        ``self.batch_no``, ``self.batch_no_validation``,
        ``self.validation_samples``, ``self.config``, ``self.LOGDIR``,
        ``self.chckpnt_dir``, ``self.learning_rate``, ``self.total_iteration``,
        ``self.no_sample_per_each_itr``, ``self.sample_no`` and
        ``self.display_validation_step``; sets ``self.alpha_coeff`` and
        ``self.beta_coeff``.
        """

        self.alpha_coeff = 1

        # Read the image paths for train, validation and test.
        _rd = _read_data(self.data_path)
        train_data, validation_data, test_data = _rd.read_data_path()

        # ====================================== validation loader threads
        bunch_of_images_no = 1
        _image_class_vl = image_class(validation_data,
                                      bunch_of_images_no=bunch_of_images_no,
                                      is_training=0, inp_size=self.asl_size, out_size=self.pet_size)
        _patch_extractor_thread_vl = _patch_extractor_thread(_image_class=_image_class_vl,
                                                             img_no=bunch_of_images_no,
                                                             mutex=settings.mutex,
                                                             is_training=0,
                                                             )
        _fill_thread_vl = fill_thread(validation_data,
                                      _image_class_vl,
                                      mutex=settings.mutex,
                                      is_training=0,
                                      patch_extractor=_patch_extractor_thread_vl,
                                      )

        _read_thread_vl = read_thread(_fill_thread_vl, mutex=settings.mutex,
                                      validation_sample_no=self.validation_samples, is_training=0)
        _fill_thread_vl.start()
        _patch_extractor_thread_vl.start()
        _read_thread_vl.start()
        # ====================================== training loader threads
        bunch_of_images_no = 15
        _image_class_tr = image_class(train_data,
                                      bunch_of_images_no=bunch_of_images_no,
                                      is_training=1, inp_size=self.asl_size, out_size=self.pet_size
                                      )
        _patch_extractor_thread_tr = _patch_extractor_thread(_image_class=_image_class_tr,
                                                             img_no=bunch_of_images_no,
                                                             mutex=settings.mutex,
                                                             is_training=1,
                                                             )
        _fill_thread = fill_thread(train_data,
                                   _image_class_tr,
                                   mutex=settings.mutex,
                                   is_training=1,
                                   patch_extractor=_patch_extractor_thread_tr,
                                   )
        _read_thread = read_thread(_fill_thread, mutex=settings.mutex, is_training=1)
        _fill_thread.start()
        _patch_extractor_thread_tr.start()
        _read_thread.start()
        # ====================================== graph inputs
        asl_plchld = tf.placeholder(tf.float32, shape=[self.batch_no, self.asl_size, self.asl_size, 1])
        t1_plchld = tf.placeholder(tf.float32, shape=[self.batch_no, self.asl_size, self.asl_size, 1])
        pet_plchld = tf.placeholder(tf.float32, shape=[self.batch_no, self.pet_size, self.pet_size, 1])
        asl_out_plchld = tf.placeholder(tf.float32, shape=[self.batch_no, self.pet_size, self.pet_size, 1])

        ave_loss_vali = tf.placeholder(tf.float32)

        is_training = tf.placeholder(tf.bool, name='is_training')
        is_training_bn = tf.placeholder(tf.bool, name='is_training_bn')

        msdensnet = multi_stage_densenet()
        # NOTE(review): input_dim is hard-coded to 77 while the placeholders
        # use self.asl_size -- confirm the two are meant to agree.
        asl_y, pet_y = msdensnet.multi_stage_densenet(asl_img=asl_plchld,
                                                      t1_img=t1_plchld,
                                                      pet_img=pet_plchld,
                                                      input_dim=77,
                                                      is_training=is_training,
                                                      config=self.config)

        show_img = asl_plchld[:, :, :, 0, np.newaxis]
        tf.summary.image('00: input_asl', show_img, 3)

        show_img = t1_plchld[:, :, :, 0, np.newaxis]
        tf.summary.image('01: input_t1', show_img, 3)

        show_img = pet_plchld[:, :, :, 0, np.newaxis]
        tf.summary.image('02: target_pet', show_img, 3)
        #
        show_img = asl_y[:, :, :, 0, np.newaxis]
        tf.summary.image('03: output_asl', show_img, 3)

        show_img = pet_y[:, :, :, 0, np.newaxis]
        tf.summary.image('03: output_pet', show_img, 3)

        print('*****************************************')
        print('*****************************************')
        print('*****************************************')
        sess = tf.Session(config=tf.ConfigProto(log_device_placement=True))
        from tensorflow.python.client import device_lib
        print(device_lib.list_local_devices())
        print('*****************************************')
        print('*****************************************')
        print('*****************************************')

        train_writer = tf.summary.FileWriter(self.LOGDIR + '/train', graph=tf.get_default_graph())
        validation_writer = tf.summary.FileWriter(self.LOGDIR + '/validation', graph=sess.graph)
        # Best-effort snapshot of the code next to the logs. Only file-system
        # errors (e.g. the folder already exists, a source file is missing)
        # are tolerated, and they are logged instead of silently dropped.
        try:
            os.mkdir(self.LOGDIR + 'code/')
            copyfile('./run_net.py', self.LOGDIR + 'code/run_net.py')
            copyfile('./submit_job.py', self.LOGDIR + 'code/submit_job.py')
            copyfile('./test_file.py', self.LOGDIR + 'code/test_file.py')
            shutil.copytree('./functions/', self.LOGDIR + 'code/functions/')
        except OSError as err:
            logging.debug('could not snapshot the code into %s: %s',
                          self.LOGDIR + 'code/', err)

        saver = tf.train.Saver(tf.global_variables(), max_to_keep=1000)

        loadModel = 0
        alpha = .84  # weight of the SSIM term; the Huber term gets (1 - alpha)
        with tf.name_scope('cost'):
            ssim_asl = tf.reduce_mean(1 - SSIM(x1=asl_out_plchld, x2=asl_y, max_val=34.0)[0])
            loss_asl = alpha * ssim_asl + (1 - alpha) * tf.reduce_mean(huber(labels=asl_out_plchld, logit=asl_y))

            ssim_pet = tf.reduce_mean(1 - SSIM(x1=pet_plchld, x2=pet_y, max_val=2.1)[0])
            loss_pet = alpha * ssim_pet + (1 - alpha) * tf.reduce_mean(huber(labels=pet_plchld, logit=pet_y))
            cost = loss_asl + loss_pet

        tf.summary.scalar("cost", cost)
        # Run the collected update ops together with every optimizer step.
        extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(extra_update_ops):
            optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate, ).minimize(cost)

        with tf.name_scope('validation'):
            average_validation_loss = ave_loss_vali

        tf.summary.scalar("average_validation_loss", average_validation_loss)
        sess.run(tf.global_variables_initializer())

        logging.debug('total number of variables %s' % (
            np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()])))

        summ = tf.summary.merge_all()

        point = 0
        itr1 = 0
        if loadModel:
            chckpnt_dir = ''
            ckpt = tf.train.get_checkpoint_state(chckpnt_dir)
            saver.restore(sess, ckpt.model_checkpoint_path)
            # Plain int: a 16-bit integer cannot hold a step counter above 32767.
            point = int(ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])
            itr1 = point

        print("Number of trainable parameters: %d" % self.count_number_trainable_params())

        # The ASL target is the PET-sized centre window of the ASL input; the
        # bounds do not change during training, so compute them once.
        crop_lo = int(self.asl_size / 2) - int(self.pet_size / 2) - 1
        crop_hi = int(self.asl_size / 2) + int(self.pet_size / 2)

        # Loop over epochs.
        # NOTE(review): ``point`` is never reset, so once the inner ``while``
        # condition turns false the remaining iterations only save a checkpoint.
        for itr in range(self.total_iteration):
            # Defined up front so the end-of-epoch report below also works
            # when the inner loop does not run at all.
            startTime = time.time()
            while self.no_sample_per_each_itr * int(point / self.no_sample_per_each_itr) < self.sample_no:
                print("epoch #: %d" % (settings.epochs_no))
                startTime = time.time()
                step = 0
                self.beta_coeff = 1 + 1 * np.exp(-point / 2000)

                # =============validation================
                if itr1 % self.display_validation_step == 0:
                    loss_validation = 0
                    acc_validation = 0
                    validation_step = 0
                    dsc_validation = 0

                    while (validation_step * self.batch_no_validation < settings.validation_totalimg_patch):

                        [validation_asl_slices, validation_pet_slices,
                         validation_t1_slices] = _image_class_vl.return_patches_validation(
                            validation_step * self.batch_no_validation,
                            (validation_step + 1) * self.batch_no_validation)
                        if (len(validation_asl_slices) < self.batch_no_validation) | (
                                len(validation_pet_slices) < self.batch_no_validation) | (
                                len(validation_t1_slices) < self.batch_no_validation):
                            # Not enough patches yet: wake the reader and retry.
                            _read_thread_vl.resume()
                            time.sleep(0.5)
                            continue

                        tic = time.time()
                        [loss_vali] = sess.run([cost],
                                               feed_dict={asl_plchld: validation_asl_slices,
                                                          t1_plchld: validation_t1_slices,
                                                          pet_plchld: validation_pet_slices,
                                                          asl_out_plchld: validation_asl_slices[:,
                                                                          crop_lo:crop_hi,
                                                                          crop_lo:crop_hi, :],
                                                          is_training: False,
                                                          ave_loss_vali: -1,
                                                          is_training_bn: False,
                                                          })
                        elapsed = time.time() - tic
                        loss_validation += loss_vali
                        validation_step += 1
                        if np.isnan(dsc_validation) or np.isnan(loss_validation) or np.isnan(acc_validation):
                            print('nan problem')
                        process = psutil.Process(os.getpid())

                        print(
                            '%d - > %d: elapsed_time:%d  loss_validation: %f, memory_percent: %4s' % (
                                validation_step, validation_step * self.batch_no_validation
                                , elapsed, loss_vali, str(process.memory_percent()),
                            ))

                        # end while
                    settings.queue_isready_vl = False
                    # NOTE(review): if the loop above ran zero times these
                    # divisions fail and the validation slices used below are
                    # undefined -- confirm validation_totalimg_patch is > 0 here.
                    acc_validation = acc_validation / (validation_step)
                    loss_validation = loss_validation / (validation_step)
                    dsc_validation = dsc_validation / (validation_step)
                    if np.isnan(dsc_validation) or np.isnan(loss_validation) or np.isnan(acc_validation):
                        print('nan problem')
                    _fill_thread_vl.kill_thread()
                    print('******Validation, step: %d , accuracy: %.4f, loss: %f*******' % (
                        itr1, acc_validation, loss_validation))

                    # Summary is computed on the last validation batch, with
                    # the averaged loss fed through ave_loss_vali.
                    [sum_validation] = sess.run([summ],
                                                feed_dict={asl_plchld: validation_asl_slices,
                                                           t1_plchld: validation_t1_slices,
                                                           pet_plchld: validation_pet_slices,
                                                           asl_out_plchld: validation_asl_slices[:,
                                                                           crop_lo:crop_hi,
                                                                           crop_lo:crop_hi, :],
                                                           is_training: False,
                                                           ave_loss_vali: loss_validation,
                                                           is_training_bn: False,
                                                           })

                    validation_writer.add_summary(sum_validation, point)
                    validation_writer.flush()
                    print('end of validation---------%d' % (point))
                    # end if

                # Loop over training batches.
                while (step * self.batch_no < self.no_sample_per_each_itr):

                    [train_asl_slices, train_pet_slices, train_t1_slices] = _image_class_tr.return_patches(
                        self.batch_no)

                    if (len(train_asl_slices) < self.batch_no) | (len(train_pet_slices) < self.batch_no) \
                            | (len(train_t1_slices) < self.batch_no):
                        # Not enough patches yet: wait, wake the reader and retry.
                        time.sleep(0.5)
                        _read_thread.resume()
                        continue

                    tic = time.time()

                    [loss_train1, opt, ] = sess.run([cost, optimizer, ],
                                                    feed_dict={asl_plchld: train_asl_slices,
                                                               t1_plchld: train_t1_slices,
                                                               pet_plchld: train_pet_slices,
                                                               asl_out_plchld: train_asl_slices[:,
                                                                               crop_lo:crop_hi,
                                                                               crop_lo:crop_hi, :],
                                                               is_training: True,
                                                               ave_loss_vali: -1,
                                                               is_training_bn: True})
                    elapsed = time.time() - tic
                    # Second pass in inference mode, only to produce summaries.
                    [sum_train] = sess.run([summ],
                                           feed_dict={asl_plchld: train_asl_slices,
                                                      t1_plchld: train_t1_slices,
                                                      pet_plchld: train_pet_slices,
                                                      asl_out_plchld: train_asl_slices[:,
                                                                      crop_lo:crop_hi,
                                                                      crop_lo:crop_hi, :],
                                                      is_training: False,
                                                      ave_loss_vali: loss_train1,
                                                      is_training_bn: False
                                                      })
                    train_writer.add_summary(sum_train, point)
                    train_writer.flush()
                    step = step + 1

                    process = psutil.Process(os.getpid())

                    print(
                        'point: %d, elapsed_time:%d step*self.batch_no:%f , LR: %.15f, loss_train1:%f,memory_percent: %4s' % (
                            int((point)), elapsed,
                            step * self.batch_no, self.learning_rate, loss_train1,
                            str(process.memory_percent())))

                    point = int((point))

                    if point % 100 == 0:
                        # Save the model in the middle of an epoch.
                        chckpnt_path = os.path.join(self.chckpnt_dir,
                                                    ('densenet_unet_inter_epoch%d_point%d.ckpt' % (itr, point)))
                        saver.save(sess, chckpnt_path, global_step=point)

                    itr1 = itr1 + 1
                    point = point + 1

            endTime = time.time()

            # ==============end of epoch:

            # Save the model after each epoch.
            chckpnt_path = os.path.join(self.chckpnt_dir, 'densenet_unet.ckpt')
            saver.save(sess, chckpnt_path, global_step=itr)
            print("End of epoch----> %d, elapsed time: %d" % (settings.epochs_no, endTime - startTime))
示例#3
0
def test_all_nets(out_dir, Log, which_data):
    """Run the restored multi-task network over one data split and dump results.

    The function builds the graph, restores the latest checkpoint found in
    ``parent_path + Log + 'unet_checkpoints/'``, then, for every image of every
    scan in the selected split, evaluates the network and writes the cropped
    T1/ASL/PET inputs and both network outputs as grayscale PNG files. The
    per-image PET SSIM values are finally written to ``all_ssim.xlsx``.

    Args:
        out_dir: Path fragment appended to ``parent_path + Log`` to form the
            output location. It is concatenated as a plain string, so it is
            expected to end with a path separator.
        Log: Path fragment (under ``parent_path``) of the experiment; also
            concatenated as a plain string.
        which_data: Split selector: 1 = validation, 2 = test, 3 = train.

    Raises:
        ValueError: If ``which_data`` is not 1, 2 or 3.
        FileNotFoundError: If no checkpoint state is found in the checkpoint
            directory.
    """
    data_path = "/exports/lkeb-hpc/syousefi/Data/asl_pet/"
    _rd = _read_data(data_path)

    train_data, validation_data, test_data = _rd.read_data_path()

    # Fail early on an unknown selector instead of hitting an unbound `data`
    # much later in the function.
    splits = {1: validation_data, 2: test_data, 3: train_data}
    if which_data not in splits:
        raise ValueError(
            'which_data must be 1 (validation), 2 (test) or 3 (train), got %r'
            % (which_data,))
    data = splits[which_data]

    asl_plchld = tf.placeholder(tf.float32,
                                shape=[None, asl_size, asl_size, 1])
    t1_plchld = tf.placeholder(tf.float32, shape=[None, asl_size, asl_size, 1])
    pet_plchld = tf.placeholder(tf.float32,
                                shape=[None, pet_size, pet_size, 1])
    asl_out_plchld = tf.placeholder(tf.float32,
                                    shape=[None, pet_size, pet_size, 1])

    ave_loss_vali = tf.placeholder(tf.float32)

    is_training = tf.placeholder(tf.bool, name='is_training')
    is_training_bn = tf.placeholder(tf.bool, name='is_training_bn')

    msdensnet = multi_stage_densenet()
    # NOTE(review): input_dim is hard-coded to 77 while the placeholders are
    # sized with asl_size -- confirm the two agree.
    asl_y, pet_y = msdensnet.multi_stage_densenet(asl_img=asl_plchld,
                                                  t1_img=t1_plchld,
                                                  pet_img=pet_plchld,
                                                  input_dim=77,
                                                  is_training=is_training,
                                                  config=config)
    alpha = .84
    # Only `ssim_pet` is evaluated below; the remaining loss ops are kept so
    # the graph is built exactly as before.
    with tf.name_scope('cost'):
        ssim_asl = tf.reduce_mean(
            1 - SSIM(x1=asl_out_plchld, x2=asl_y, max_val=34.0)[0])
        loss_asl = alpha * ssim_asl + (1 - alpha) * tf.reduce_mean(
            huber(labels=asl_out_plchld, logit=asl_y))

        ssim_pet = tf.reduce_mean(
            1 - SSIM(x1=pet_plchld, x2=pet_y, max_val=2.1)[0])
        loss_pet = alpha * ssim_pet + (1 - alpha) * tf.reduce_mean(
            huber(labels=pet_plchld, logit=pet_y))
        cost = loss_asl + loss_pet

    parent_path = '/exports/lkeb-hpc/syousefi/Code/'
    chckpnt_dir = parent_path + Log + 'unet_checkpoints/'
    out_root = parent_path + Log + out_dir

    # Crop windows. The input crops are taken around index 40 of each image
    # axis (presumably the slice centre -- TODO confirm); each window spans
    # 2 * half + 1 pixels.
    half_asl = int(asl_size / 2)
    half_pet = int(pet_size / 2)
    asl_crop = slice(40 - half_asl - 1, 40 + half_asl)
    pet_crop = slice(40 - half_pet - 1, 40 + half_pet)
    # Central pet-sized window inside the already cropped ASL patch; it is fed
    # as the ASL target.
    asl_target_crop = slice(half_asl - half_pet - 1, half_asl + half_pet)

    list_ssim = []
    list_name = []

    # The context manager guarantees the session is closed even on errors.
    with tf.Session() as sess:
        saver = tf.train.Saver()
        ckpt = tf.train.get_checkpoint_state(chckpnt_dir)
        if ckpt is None or not ckpt.model_checkpoint_path:
            raise FileNotFoundError(
                'No checkpoint found in %s' % chckpnt_dir)
        saver.restore(sess, ckpt.model_checkpoint_path)
        # Unused in this function; retained because the constructor of
        # _measure is not visible here (TODO confirm it can be dropped).
        _meas = _measure()
        # Keep a copy of the test script next to the results.
        copyfile('./test_multitask.py', out_root + 'test_multitask.py')

        _image_class = image_class(data,
                                   bunch_of_images_no=1,
                                   is_training=0,
                                   inp_size=asl_size,
                                   out_size=pet_size)

        for scan in range(len(data)):
            ss = str(data[scan]['asl']).split("/")
            imm = _image_class.read_image(data[scan])

            # Output sub-directory: <third-from-last path part>/<text after
            # "ASL_" in the file stem>.
            series = ss[-1].split(".")[0].split("ASL_")[1]
            scan_dir = out_root + ss[-3] + '/' + series
            # Creates missing parents and tolerates reruns, replacing the
            # former bare `except` around mkdir plus an unguarded mkdir.
            os.makedirs(scan_dir, exist_ok=True)

            for img_indx in range(np.shape(imm[3])[0]):
                print('img_indx:%s' % (img_indx))

                name = ss[-3] + '_' + ss[-2] + '_' + str(img_indx)

                # Add batch and channel axes so the arrays match the
                # 4-D placeholders.
                t1 = imm[3][img_indx, asl_crop, asl_crop]
                t1 = t1[np.newaxis, ..., np.newaxis]
                asl = imm[4][np.newaxis, img_indx, asl_crop, asl_crop,
                             np.newaxis]
                pet = imm[5][np.newaxis, img_indx, pet_crop, pet_crop,
                             np.newaxis]

                loss, asl_out, pet_out = sess.run(
                    [ssim_pet, asl_y, pet_y],
                    feed_dict={
                        asl_plchld: asl,
                        t1_plchld: t1,
                        pet_plchld: pet,
                        asl_out_plchld: asl[:, asl_target_crop,
                                            asl_target_crop, :],
                        is_training: False,
                        ave_loss_vali: -1,
                        is_training_bn: False
                    })

                # `ssim_pet` is defined as mean(1 - SSIM), so invert it back.
                ssim = 1 - loss
                list_ssim.append(ssim)
                list_name.append(ss[-3] + '_' + series + '_t1_' + name)
                # list_name grows across scans, so the entry for the current
                # image is the last one, not list_name[img_indx].
                print(list_name[-1], ': ', ssim)

                images = (
                    ('/t1_' + name + '.png', t1),
                    ('/asl_' + name + '_' + '.png', asl),
                    ('/pet_' + name + '_' + '.png', pet),
                    ('/res_asl' + name + '_' + str(ssim) + '.png', asl_out),
                    ('/res_pet' + name + '_' + str(ssim) + '.png', pet_out),
                )
                for file_name, image in images:
                    matplotlib.image.imsave(scan_dir + file_name,
                                            np.squeeze(image),
                                            cmap='gray')

    df = pd.DataFrame(list_ssim, columns=pd.Index(['ssim'],
                                                  name='Genus')).round(2)
    xlsx_path = parent_path + Log + out_dir + '/all_ssim.xlsx'
    # Using the writer as a context manager saves and closes the workbook;
    # ExcelWriter.save() no longer exists in pandas >= 2.0.
    with pd.ExcelWriter(xlsx_path, engine='xlsxwriter') as writer:
        df.to_excel(writer, sheet_name='Sheet1')

    print(xlsx_path)