コード例 #1
0
    def test(self, dicom_dir):
        # X = tf.placeholder(shape=[batch_size, input_shape[0], input_shape[1], input_shape[2]], dtype=tf.float32)
        test_input_shape = input_shape
        test_batch_size = batch_size * 2
        threshold = tf.placeholder(tf.float32)
        training = tf.placeholder(tf.bool)
        X = tf.placeholder(shape=[
            test_batch_size, test_input_shape[0], test_input_shape[1],
            test_input_shape[2]
        ],
                           dtype=tf.float32)
        with tf.variable_scope('ae'):
            Y_pred, Y_pred_modi, Y_pred_nosig = self.ae_u(
                X, training, test_batch_size, threshold)

        print tools.Ops.variable_count()
        sum_merged = tf.summary.merge_all()
        saver = tf.train.Saver(max_to_keep=1)
        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.visible_device_list = GPU0
        with tf.Session(config=config) as sess:
            # if os.path.exists(self.train_models_dir):
            #     saver.restore(sess, self.train_models_dir + 'model.cptk')
            sum_writer_train = tf.summary.FileWriter(self.train_sum_dir,
                                                     sess.graph)
            sum_write_test = tf.summary.FileWriter(self.test_sum_dir)

            if os.path.exists(self.train_models_dir) and os.path.isfile(
                    self.train_models_dir + 'model.cptk.data-00000-of-00001'):
                print "restoring saved model"
                saver.restore(sess, self.train_models_dir + 'model.cptk')
            else:
                sess.run(tf.global_variables_initializer())
            test_data = tools.Test_data(dicom_dir, input_shape)
            test_data.organize_blocks()
            block_numbers = test_data.blocks.keys()
            for i in range(0, len(block_numbers), test_batch_size):
                batch_numbers = []
                if i + test_batch_size < len(block_numbers):
                    temp_input = np.zeros([
                        test_batch_size, input_shape[0], input_shape[1],
                        input_shape[2]
                    ])
                    for j in range(test_batch_size):
                        temp_num = block_numbers[i + j]
                        temp_block = test_data.blocks[temp_num]
                        batch_numbers.append(temp_num)
                        block_array = temp_block.load_data()
                        block_shape = np.shape(block_array)
                        temp_input[j, 0:block_shape[0], 0:block_shape[1],
                                   0:block_shape[2]] += block_array
                    Y_temp_pred, Y_temp_modi, Y_temp_pred_nosig = sess.run(
                        [Y_pred, Y_pred_modi, Y_pred_nosig],
                        feed_dict={
                            X: temp_input,
                            training: False,
                            threshold: upper_threshold
                        })
                    for j in range(test_batch_size):
                        test_data.upload_result(batch_numbers[j],
                                                Y_temp_modi[j, :, :, :])
                else:
                    temp_batch_size = len(block_numbers) - i
                    temp_input = np.zeros([
                        temp_batch_size, input_shape[0], input_shape[1],
                        input_shape[2]
                    ])
                    for j in range(temp_batch_size):
                        temp_num = block_numbers[i + j]
                        temp_block = test_data.blocks[temp_num]
                        batch_numbers.append(temp_num)
                        block_array = temp_block.load_data()
                        block_shape = np.shape(block_array)
                        temp_input[j, 0:block_shape[0], 0:block_shape[1],
                                   0:block_shape[2]] += block_array
                    X_temp = tf.placeholder(shape=[
                        temp_batch_size, input_shape[0], input_shape[1],
                        input_shape[2]
                    ],
                                            dtype=tf.float32)
                    with tf.variable_scope('ae', reuse=True):
                        Y_pred_temp, Y_pred_modi_temp, Y_pred_nosig_temp = self.ae_u(
                            X_temp, training, temp_batch_size, threshold)
                    Y_temp_pred, Y_temp_modi, Y_temp_pred_nosig = sess.run(
                        [Y_pred_temp, Y_pred_modi_temp, Y_pred_nosig_temp],
                        feed_dict={
                            X_temp: temp_input,
                            training: False,
                            threshold: upper_threshold
                        })
                    for j in range(temp_batch_size):
                        test_data.upload_result(batch_numbers[j],
                                                Y_temp_modi[j, :, :, :])
            test_result_array = test_data.get_result()
            print "result shape: ", np.shape(test_result_array)
            r_s = np.shape(test_result_array)  # result shape
            e_t = 10  # edge thickness
            to_be_transformed = np.zeros(r_s, np.float32)
            to_be_transformed[e_t:r_s[0] - e_t, e_t:r_s[1] - e_t, 0:r_s[2] -
                              e_t] += test_result_array[e_t:r_s[0] - e_t,
                                                        e_t:r_s[1] - e_t,
                                                        0:r_s[2] - e_t]
            print np.max(to_be_transformed)
            print np.min(to_be_transformed)
            final_img = ST.GetImageFromArray(
                np.transpose(to_be_transformed, [2, 1, 0]))
            final_img.SetSpacing(test_data.space)
            print "writing final testing result"
            print './test_result/test_result_final.vtk'
            ST.WriteImage(final_img, './test_result/test_result_final.vtk')
            return final_img
コード例 #2
0
 def full_testing(self, sess, X, w, threshold, test_merge_op,
                  sum_write_test, training, weight_for, total_acc, Y_pred,
                  Y_pred_modi, Y_pred_nosig, epoch):
     print '********************** FULL TESTING ********************************'
     # X = tf.placeholder(shape=[batch_size, input_shape[0], input_shape[1], input_shape[2]], dtype=tf.float32)
     # w = tf.placeholder(tf.float32)
     # threshold = tf.placeholder(tf.float32)
     time_begin = time.time()
     origin_data = read_dicoms(test_dir + "original1")
     mask_dir = test_dir + "airway"
     test_batch_size = batch_size
     # test_data = tools.Test_data(dicom_dir,input_shape)
     test_data = tools.Test_data(origin_data, input_shape, 'vtk_data')
     test_data.organize_blocks()
     test_mask = read_dicoms(mask_dir)
     array_mask = ST.GetArrayFromImage(test_mask)
     array_mask = np.transpose(array_mask, (2, 1, 0))
     print "mask shape: ", np.shape(array_mask)
     block_numbers = test_data.blocks.keys()
     for i in range(0, len(block_numbers), test_batch_size):
         batch_numbers = []
         if i + test_batch_size < len(block_numbers):
             temp_input = np.zeros([
                 test_batch_size, input_shape[0], input_shape[1],
                 input_shape[2]
             ])
             for j in range(test_batch_size):
                 temp_num = block_numbers[i + j]
                 temp_block = test_data.blocks[temp_num]
                 batch_numbers.append(temp_num)
                 block_array = temp_block.load_data()
                 block_shape = np.shape(block_array)
                 temp_input[j, 0:block_shape[0], 0:block_shape[1],
                            0:block_shape[2]] += block_array
             Y_temp_pred, Y_temp_modi, Y_temp_pred_nosig = sess.run(
                 [Y_pred, Y_pred_modi, Y_pred_nosig],
                 feed_dict={
                     X: temp_input,
                     training: False,
                     w: weight_for,
                     threshold: upper_threshold + test_extra_threshold
                 })
             for j in range(test_batch_size):
                 test_data.upload_result(batch_numbers[j],
                                         Y_temp_modi[j, :, :, :])
         else:
             temp_batch_size = len(block_numbers) - i
             temp_input = np.zeros([
                 temp_batch_size, input_shape[0], input_shape[1],
                 input_shape[2]
             ])
             for j in range(temp_batch_size):
                 temp_num = block_numbers[i + j]
                 temp_block = test_data.blocks[temp_num]
                 batch_numbers.append(temp_num)
                 block_array = temp_block.load_data()
                 block_shape = np.shape(block_array)
                 temp_input[j, 0:block_shape[0], 0:block_shape[1],
                            0:block_shape[2]] += block_array
             X_temp = tf.placeholder(shape=[
                 temp_batch_size, input_shape[0], input_shape[1],
                 input_shape[2]
             ],
                                     dtype=tf.float32)
             with tf.variable_scope('generator', reuse=True):
                 Y_pred_temp, Y_pred_modi_temp, Y_pred_nosig_temp = self.ae_u(
                     X_temp, training, temp_batch_size, threshold)
             Y_temp_pred, Y_temp_modi, Y_temp_pred_nosig = sess.run(
                 [Y_pred_temp, Y_pred_modi_temp, Y_pred_nosig_temp],
                 feed_dict={
                     X_temp: temp_input,
                     training: False,
                     w: weight_for,
                     threshold: upper_threshold + test_extra_threshold
                 })
             for j in range(temp_batch_size):
                 test_data.upload_result(batch_numbers[j],
                                         Y_temp_modi[j, :, :, :])
     test_result_array = test_data.get_result()
     print "result shape: ", np.shape(test_result_array)
     to_be_transformed = self.post_process(test_result_array)
     if epoch == total_test_epoch:
         mask_img = ST.GetImageFromArray(np.transpose(
             array_mask, [2, 1, 0]))
         mask_img.SetSpacing(test_data.space)
         ST.WriteImage(mask_img, self.test_results_dir + 'test_mask.vtk')
     test_IOU = 2 * np.sum(to_be_transformed * array_mask) / (
         np.sum(to_be_transformed) + np.sum(array_mask))
     test_summary = sess.run(test_merge_op, feed_dict={total_acc: test_IOU})
     sum_write_test.add_summary(test_summary, global_step=epoch)
     print "IOU accuracy: ", test_IOU
     time_end = time.time()
     print '******************** time of full testing: ' + str(
         time_end - time_begin) + 's ********************'
コード例 #3
0
    def train(self, configure):
        data = tools.Data(configure, epoch_walked)
        best_acc = 0
        # X = tf.placeholder(shape=[batch_size, input_shape[0], input_shape[1], input_shape[2]], dtype=tf.float32)
        X = tf.placeholder(
            shape=[batch_size, input_shape[0], input_shape[1], input_shape[2]],
            dtype=tf.float32)
        # Y = tf.placeholder(shape=[batch_size, output_shape[0], output_shape[1], output_shape[2]], dtype=tf.float32)
        Y = tf.placeholder(shape=[
            batch_size, output_shape[0], output_shape[1], output_shape[2]
        ],
                           dtype=tf.float32)
        print X.get_shape()
        lr = tf.placeholder(tf.float32)
        training = tf.placeholder(tf.bool)
        threshold = tf.placeholder(tf.float32)
        with tf.variable_scope('ae'):
            Y_pred, Y_pred_modi, Y_pred_nosig = self.ae_u(
                X, training, batch_size, threshold)

        with tf.variable_scope('dis'):
            XY_real_pair = self.dis(X, Y, training)
        with tf.variable_scope('dis', reuse=True):
            XY_fake_pair = self.dis(X, Y_pred, training)

        with tf.device('/gpu:' + GPU0):
            ################################ ae loss
            Y_ = tf.reshape(Y, shape=[batch_size, -1])
            Y_pred_modi_ = tf.reshape(Y_pred_modi, shape=[batch_size, -1])
            w = tf.placeholder(
                tf.float32)  # power of foreground against background
            ae_loss = tf.reduce_mean(
                -tf.reduce_mean(w * Y_ * tf.log(Y_pred_modi_ + 1e-8),
                                reduction_indices=[1]) -
                tf.reduce_mean((1 - w) *
                               (1 - Y_) * tf.log(1 - Y_pred_modi_ + 1e-8),
                               reduction_indices=[1]))
            sum_ae_loss = tf.summary.scalar('ae_loss', ae_loss)

            ################################ wgan loss
            gan_g_loss = -tf.reduce_mean(XY_fake_pair)
            gan_d_loss = tf.reduce_mean(XY_fake_pair) - tf.reduce_mean(
                XY_real_pair)
            sum_gan_g_loss = tf.summary.scalar('gan_g_loss', gan_g_loss)
            sum_gan_d_loss = tf.summary.scalar('gan_d_loss', gan_d_loss)
            alpha = tf.random_uniform(shape=[
                batch_size, input_shape[0] * input_shape[1] * input_shape[2]
            ],
                                      minval=0.0,
                                      maxval=1.0)

            Y_pred_ = tf.reshape(Y_pred, shape=[batch_size, -1])
            differences_ = Y_pred_ - Y_
            interpolates = Y_ + alpha * differences_
            with tf.variable_scope('dis', reuse=True):
                XY_fake_intep = self.dis(X, interpolates, training)
            gradients = tf.gradients(XY_fake_intep, [interpolates])[0]
            slopes = tf.sqrt(
                tf.reduce_sum(tf.square(gradients), reduction_indices=[1]))
            gradient_penalty = tf.reduce_mean((slopes - 1.0)**2)
            gan_d_loss += 10 * gradient_penalty

            #################################  ae + gan loss
            gan_g_w = 5
            ae_w = 100 - gan_g_w
            ae_gan_g_loss = ae_w * ae_loss + gan_g_w * gan_g_loss

        with tf.device('/gpu:' + GPU0):
            ae_var = [
                var for var in tf.trainable_variables()
                if var.name.startswith('ae')
            ]
            dis_var = [
                var for var in tf.trainable_variables()
                if var.name.startswith('dis')
            ]
            ae_g_optim = tf.train.AdamOptimizer(learning_rate=lr,
                                                beta1=0.9,
                                                beta2=0.999,
                                                epsilon=1e-8).minimize(
                                                    ae_gan_g_loss,
                                                    var_list=ae_var)
            dis_optim = tf.train.AdamOptimizer(learning_rate=lr,
                                               beta1=0.9,
                                               beta2=0.999,
                                               epsilon=1e-8).minimize(
                                                   gan_d_loss,
                                                   var_list=dis_var)

        print tools.Ops.variable_count()
        sum_merged = tf.summary.merge_all()

        saver = tf.train.Saver(max_to_keep=1)
        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.visible_device_list = GPU0
        with tf.Session(config=config) as sess:
            # if os.path.exists(self.train_models_dir):
            #     try:
            #         saver.restore(sess,self.train_models_dir+'model.cptk')
            #     except Exception,e:
            #         saver.restore(sess,'./regular/'+'model.cptk')
            sum_writer_train = tf.summary.FileWriter(self.train_sum_dir,
                                                     sess.graph)
            sum_write_test = tf.summary.FileWriter(self.test_sum_dir)

            if os.path.isfile(self.train_models_dir +
                              'model.cptk.data-00000-of-00001'):
                print "restoring saved model"
                saver.restore(sess, self.train_models_dir + 'model.cptk')
            else:
                sess.run(tf.global_variables_initializer())

            learning_rate_g = ori_lr * pow(power, (epoch_walked / 4))
            for epoch in range(epoch_walked, 15000):
                # data.shuffle_X_Y_files(label='train')
                #### select data randomly each 10 epochs
                if epoch % 2 == 0 and epoch > 0:
                    del data
                    gc.collect()
                    data = tools.Data(configure, epoch)
                #### full testing
                # ...
                train_amount = len(data.train_numbers)
                test_amount = len(data.test_numbers)
                if train_amount >= test_amount and train_amount > 0 and test_amount > 0 and data.total_train_batch_num > 0 and data.total_test_seq_batch > 0:
                    weight_for = 0.35 * (1 - epoch * 1.0 / 15000) + 0.5
                    if epoch % 4 == 0:
                        print '********************** FULL TESTING ********************************'
                        time_begin = time.time()
                        lung_img = ST.ReadImage('./WANG_REN/lung_img.vtk')
                        mask_dir = "./WANG_REN/airway"
                        test_batch_size = batch_size
                        # test_data = tools.Test_data(dicom_dir,input_shape)
                        test_data = tools.Test_data(lung_img, input_shape,
                                                    'vtk_data')
                        test_data.organize_blocks()
                        test_mask = read_dicoms(mask_dir)
                        array_mask = ST.GetArrayFromImage(test_mask)
                        array_mask = np.transpose(array_mask, (2, 1, 0))
                        print "mask shape: ", np.shape(array_mask)
                        time1 = time.time()
                        block_numbers = test_data.blocks.keys()
                        for i in range(0, len(block_numbers), test_batch_size):
                            batch_numbers = []
                            if i + test_batch_size < len(block_numbers):
                                temp_input = np.zeros([
                                    test_batch_size, input_shape[0],
                                    input_shape[1], input_shape[2]
                                ])
                                for j in range(test_batch_size):
                                    temp_num = block_numbers[i + j]
                                    temp_block = test_data.blocks[temp_num]
                                    batch_numbers.append(temp_num)
                                    block_array = temp_block.load_data()
                                    block_shape = np.shape(block_array)
                                    temp_input[j, 0:block_shape[0],
                                               0:block_shape[1],
                                               0:block_shape[2]] += block_array
                                Y_temp_pred, Y_temp_modi, Y_temp_pred_nosig = sess.run(
                                    [Y_pred, Y_pred_modi, Y_pred_nosig],
                                    feed_dict={
                                        X: temp_input,
                                        training: False,
                                        w: weight_for,
                                        threshold: upper_threshold
                                    })
                                for j in range(test_batch_size):
                                    test_data.upload_result(
                                        batch_numbers[j],
                                        Y_temp_modi[j, :, :, :])
                            else:
                                temp_batch_size = len(block_numbers) - i
                                temp_input = np.zeros([
                                    temp_batch_size, input_shape[0],
                                    input_shape[1], input_shape[2]
                                ])
                                for j in range(temp_batch_size):
                                    temp_num = block_numbers[i + j]
                                    temp_block = test_data.blocks[temp_num]
                                    batch_numbers.append(temp_num)
                                    block_array = temp_block.load_data()
                                    block_shape = np.shape(block_array)
                                    temp_input[j, 0:block_shape[0],
                                               0:block_shape[1],
                                               0:block_shape[2]] += block_array
                                X_temp = tf.placeholder(shape=[
                                    temp_batch_size, input_shape[0],
                                    input_shape[1], input_shape[2]
                                ],
                                                        dtype=tf.float32)
                                with tf.variable_scope('ae', reuse=True):
                                    Y_pred_temp, Y_pred_modi_temp, Y_pred_nosig_temp = self.ae_u(
                                        X_temp, training, temp_batch_size,
                                        threshold)
                                Y_temp_pred, Y_temp_modi, Y_temp_pred_nosig = sess.run(
                                    [
                                        Y_pred_temp, Y_pred_modi_temp,
                                        Y_pred_nosig_temp
                                    ],
                                    feed_dict={
                                        X_temp: temp_input,
                                        training: False,
                                        w: weight_for,
                                        threshold: upper_threshold
                                    })
                                for j in range(temp_batch_size):
                                    test_data.upload_result(
                                        batch_numbers[j],
                                        Y_temp_modi[j, :, :, :])
                        test_result_array = test_data.get_result()
                        print "result shape: ", np.shape(test_result_array)
                        r_s = np.shape(test_result_array)  # result shape
                        e_t = 10  # edge thickness
                        to_be_transformed = np.zeros(r_s, np.float32)
                        to_be_transformed[
                            e_t:r_s[0] - e_t, e_t:r_s[1] - e_t, e_t:r_s[2] -
                            e_t] += test_result_array[e_t:r_s[0] - e_t,
                                                      e_t:r_s[1] - e_t,
                                                      e_t:r_s[2] - e_t]
                        print np.max(to_be_transformed)
                        print np.min(to_be_transformed)
                        final_img = ST.GetImageFromArray(
                            np.transpose(to_be_transformed, [2, 1, 0]))
                        final_img.SetSpacing(test_data.space)
                        print "writing full testing result"
                        print '/usr/analyse_airway/test_result/test_result' + str(
                            epoch) + '.vtk'
                        ST.WriteImage(
                            final_img,
                            '/usr/analyse_airway/test_result/test_result' +
                            str(epoch) + '.vtk')
                        if epoch == 0:
                            mask_img = ST.GetImageFromArray(
                                np.transpose(array_mask, [2, 1, 0]))
                            mask_img.SetSpacing(test_data.space)
                            ST.WriteImage(
                                mask_img,
                                '/usr/analyse_airway/test_result/test_mask.vtk'
                            )
                        test_IOU = 2 * np.sum(
                            to_be_transformed * array_mask) / (
                                np.sum(to_be_transformed) + np.sum(array_mask))
                        print "IOU accuracy: ", test_IOU
                        time_end = time.time()
                        print '******************** time of full testing: ' + str(
                            time_end - time_begin) + 's ********************'
                    data.shuffle_X_Y_pairs()
                    total_train_batch_num = data.total_train_batch_num
                    # train_files=data.X_train_files
                    # test_files=data.X_test_files
                    # total_train_batch_num = 500
                    print "total_train_batch_num:", total_train_batch_num
                    for i in range(total_train_batch_num):

                        #### training
                        X_train_batch, Y_train_batch = data.load_X_Y_voxel_train_next_batch(
                        )
                        # X_train_batch, Y_train_batch = data.load_X_Y_voxel_grids_train_next_batch()
                        # Y_train_batch=np.reshape(Y_train_batch,[batch_size, output_shape[0], output_shape[1], output_shape[2], 1])
                        gan_d_loss_c, = sess.run(
                            [gan_d_loss],
                            feed_dict={
                                X: X_train_batch,
                                Y: Y_train_batch,
                                training: False,
                                w: weight_for,
                                threshold: upper_threshold
                            })
                        ae_loss_c, gan_g_loss_c, sum_train = sess.run(
                            [ae_loss, gan_g_loss, sum_merged],
                            feed_dict={
                                X: X_train_batch,
                                Y: Y_train_batch,
                                training: False,
                                w: weight_for,
                                threshold: upper_threshold
                            })
                        if epoch % 4 == 0 and epoch > 0 and i == 0:
                            learning_rate_g = learning_rate_g * power
                        sess.run(
                            [ae_g_optim],
                            feed_dict={
                                X: X_train_batch,
                                threshold: upper_threshold,
                                Y: Y_train_batch,
                                lr: learning_rate_g,
                                training: True,
                                w: weight_for
                            })
                        if epoch <= 5:
                            sess.run(
                                [dis_optim],
                                feed_dict={
                                    X: X_train_batch,
                                    threshold: upper_threshold,
                                    Y: Y_train_batch,
                                    lr: learning_rate_g,
                                    training: True,
                                    w: weight_for
                                })
                        elif epoch <= 20:
                            sess.run(
                                [dis_optim],
                                feed_dict={
                                    X: X_train_batch,
                                    threshold: upper_threshold,
                                    Y: Y_train_batch,
                                    lr: learning_rate_g,
                                    training: True,
                                    w: weight_for
                                })
                        else:
                            sess.run(
                                [dis_optim],
                                feed_dict={
                                    X: X_train_batch,
                                    threshold: upper_threshold,
                                    Y: Y_train_batch,
                                    lr: learning_rate_g,
                                    training: True,
                                    w: weight_for
                                })

                        sum_writer_train.add_summary(
                            sum_train, epoch * total_train_batch_num + i)
                        if i % 2 == 0:
                            print "epoch:", epoch, " i:", i, " train ae loss:", ae_loss_c, " gan g loss:", gan_g_loss_c, " gan d loss:", gan_d_loss_c, " learning rate: ", learning_rate_g
                        #### testing
                        if i % 20 == 0 and epoch % 1 == 0:
                            try:
                                X_test_batch, Y_test_batch = data.load_X_Y_voxel_test_next_batch(
                                    fix_sample=False)
                                # Y_test_batch = np.reshape(Y_test_batch,[batch_size, output_shape[0], output_shape[1], output_shape[2], 1])
                                ae_loss_t,gan_g_loss_t,gan_d_loss_t, Y_test_pred,Y_test_modi, Y_test_pred_nosig= \
                                    sess.run([ae_loss, gan_g_loss,gan_d_loss, Y_pred,Y_pred_modi,Y_pred_nosig],feed_dict={X: X_test_batch, threshold:upper_threshold, Y: Y_test_batch,training:False, w: weight_for})
                                predict_result = np.float32(Y_test_modi > 0.01)
                                predict_result = np.reshape(
                                    predict_result, [
                                        batch_size, input_shape[0],
                                        input_shape[1], input_shape[2]
                                    ])
                                # Foreground
                                # if np.sum(Y_test_batch)>0:
                                #     accuracy_for = np.sum(predict_result*Y_test_batch)/np.sum(Y_test_batch)
                                # Background
                                # accuracy_bac = np.sum((1-predict_result)*(1-Y_test_batch))/(np.sum(1-Y_test_batch))
                                # IOU
                                predict_probablity = np.float32(
                                    (Y_test_modi - 0.01) > 0)
                                predict_probablity = np.reshape(
                                    predict_probablity, [
                                        batch_size, input_shape[0],
                                        input_shape[1], input_shape[2]
                                    ])
                                accuracy = 2 * np.sum(
                                    np.abs(predict_probablity *
                                           Y_test_batch)) / np.sum(
                                               np.abs(predict_result) +
                                               np.abs(Y_test_batch))
                                # if epoch%30==0 and epoch>0:
                                #     to_save = {'X_test': X_test_batch, 'Y_test_pred': Y_test_pred,'Y_test_true': Y_test_batch}
                                #     scipy.io.savemat(self.test_results_dir + 'X_Y_pred_' + str(epoch).zfill(2) + '_' + str(i).zfill(4) + '.mat', to_save, do_compression=True)
                                print "epoch:", epoch, " i:", "\nIOU accuracy: ", accuracy, "\ntest ae loss:", ae_loss_t, " gan g loss:", gan_g_loss_t, " gan d loss:", gan_d_loss_t
                                if accuracy > best_acc:
                                    saver.save(
                                        sess,
                                        save_path=self.train_models_dir +
                                        'model.cptk')
                                    print "epoch:", epoch, " i:", i, "best model saved!"
                                    best_acc = accuracy
                            except Exception, e:
                                print e
                        #### model saving
                        if i % 30 == 0 and epoch % 1 == 0:
                            # regular_train_dir = "./regular/"
                            # if not os.path.exists(regular_train_dir):
                            #     os.makedirs(regular_train_dir)
                            saver.save(sess,
                                       save_path=self.train_models_dir +
                                       'model.cptk')
                            print "epoch:", epoch, " i:", i, "regular model saved!"
                else:
                    print "bad data , next epoch", epoch
コード例 #4
0
    def test(self, dicom_dir):
        flags = self.FLAGS
        block_shape = self.block_shape
        batch_size_test = self.batch_size_test
        data_type = "dicom_data"
        X = tf.placeholder(dtype=tf.float32,
                           shape=[
                               batch_size_test, block_shape[0], block_shape[1],
                               block_shape[2]
                           ])
        training = tf.placeholder(tf.bool)
        artery_pred, artery_sig = self.Dense_Net_Test(X, training,
                                                      flags.batch_size_test,
                                                      flags.accept_threshold)
        artery_pred = tf.reshape(
            artery_pred,
            [batch_size_test, block_shape[0], block_shape[1], block_shape[2]])

        # binary predict mask
        artery_pred_mask = tf.cast((artery_pred > 0.01), tf.float32)

        saver = tf.train.Saver(max_to_keep=1)
        config = tf.ConfigProto(allow_soft_placement=True)
        with tf.Session(config=config) as sess:
            # load variables if saved before
            if len(os.listdir(self.train_models_dir)) > 0:
                print "load saved model"
                sess.run(
                    tf.group(tf.global_variables_initializer(),
                             tf.local_variables_initializer()))
                saver.restore(sess,
                              self.train_models_dir + "train_models.ckpt")
            else:
                print "no model detected from %s" % (self.train_models_dir)
                exit(1)

            test_data = tools.Test_data(dicom_dir, block_shape, data_type)
            test_data.organize_blocks()
            block_numbers = test_data.blocks.keys()
            blocks_num = len(block_numbers)
            print "block count: ", blocks_num
            time1 = time.time()
            sys.stdout.write("\r>>>deep learning calculating : %f" % (0.0) +
                             "%")
            sys.stdout.flush()
            for i in range(0, blocks_num, batch_size_test):
                batch_numbers = []
                if i + batch_size_test < blocks_num:
                    temp_batch_size = batch_size_test
                else:
                    temp_batch_size = blocks_num - i
                temp_input = np.zeros([
                    batch_size_test, block_shape[0], block_shape[1],
                    block_shape[2]
                ])
                for j in range(temp_batch_size):
                    temp_num = block_numbers[i + j]
                    temp_block = test_data.blocks[temp_num]
                    batch_numbers.append(temp_num)
                    block_array = temp_block.load_data()
                    data_block_shape = np.shape(block_array)
                    temp_input[j, 0:data_block_shape[0], 0:data_block_shape[1],
                               0:data_block_shape[2]] += block_array
                artery_predict = sess.run(artery_pred_mask,
                                          feed_dict={
                                              X: temp_input,
                                              training: False
                                          })
                for j in range(temp_batch_size):
                    test_data.upload_result(batch_numbers[j],
                                            artery_predict[j, :, :, :])
                if (i) % (batch_size_test * 10) == 0:
                    sys.stdout.write("\r>>>deep learning calculating : %f" %
                                     ((1.0 * i) * 100 / blocks_num) + "%")
                    sys.stdout.flush()

            sys.stdout.write("\r>>>deep learning calculating : %f" % (100.0) +
                             "%")
            sys.stdout.flush()
            time2 = time.time()
            print "\ndeep learning time consume : ", str(time2 - time1)
            time3 = time.time()
            test_result_array = test_data.get_result()
            print "result shape: ", np.shape(test_result_array)
            r_s = np.shape(test_result_array)  # result shape
            e_t = 10  # edge thickness
            to_be_transformed = np.zeros(r_s, np.float32)
            to_be_transformed[e_t:r_s[0] - e_t, e_t:r_s[1] - e_t, 0:r_s[2] -
                              e_t] += test_result_array[e_t:r_s[0] - e_t,
                                                        e_t:r_s[1] - e_t,
                                                        0:r_s[2] - e_t]
            print "maximum value in mask: ", np.max(to_be_transformed)
            print "minimum value in mask: ", np.min(to_be_transformed)
            final_img = ST.GetImageFromArray(
                np.transpose(np.int8(to_be_transformed), [2, 1, 0]))
            final_img.SetSpacing(test_data.space)
            time4 = time.time()
            print "post processing time consume : ", str(time4 - time3)
            print "writing final testing result"
            print './test_result/test_result_final.vtk'
            ST.WriteImage(final_img, './test_result/test_result_final.vtk')
            return final_img
コード例 #5
0
    def train(self, configure):
        # data
        data = tools.Data(configure, epoch_walked / re_example_epoch)
        # network
        X = tf.placeholder(
            shape=[batch_size, input_shape[0], input_shape[1], input_shape[2]],
            dtype=tf.float32)
        Y = tf.placeholder(shape=[
            batch_size, output_shape[0], output_shape[1], output_shape[2]
        ],
                           dtype=tf.float32)
        lr = tf.placeholder(tf.float32)
        training = tf.placeholder(tf.bool)
        threshold = tf.placeholder(tf.float32)
        with tf.variable_scope('segmenter'):
            Y_pred, Y_pred_modi, Y_pred_nosig = self.ae_u(
                X, training, batch_size, threshold)

        # loss function
        Y_ = tf.reshape(Y, shape=[batch_size, -1])
        Y_pred_modi_ = tf.reshape(Y_pred_modi, shape=[batch_size, -1])
        w = tf.placeholder(tf.float32)  # foreground weight
        cross_loss = tf.reduce_mean(
            -tf.reduce_mean(w * Y_ * tf.log(Y_pred_modi_ + 1e-8),
                            reduction_indices=[1]) -
            tf.reduce_mean((1 - w) *
                           (1 - Y_) * tf.log(1 - Y_pred_modi_ + 1e-8),
                           reduction_indices=[1]))
        loss_sum = tf.summary.scalar("cross entropy", cross_loss)

        # trainers
        optim = tf.train.AdamOptimizer(learning_rate=lr,
                                       beta1=0.9,
                                       beta2=0.999,
                                       epsilon=1e-8).minimize(cross_loss)

        # accuracy
        block_acc = tf.placeholder(tf.float32)
        total_acc = tf.placeholder(tf.float32)
        train_sum = tf.summary.scalar("train_block_accuracy", block_acc)
        test_sum = tf.summary.scalar("total_test_accuracy", total_acc)
        train_merge_op = tf.summary.merge([train_sum, loss_sum])
        test_merge_op = tf.summary.merge([test_sum])

        saver = tf.train.Saver(max_to_keep=1)
        # config = tf.ConfigProto(allow_soft_placement=True)
        # config.gpu_options.visible_device_list = GPU0

        with tf.Session() as sess:
            # define tensorboard writer
            sum_writer_train = tf.summary.FileWriter(self.train_sum_dir,
                                                     sess.graph)
            sum_write_test = tf.summary.FileWriter(self.test_sum_dir,
                                                   sess.graph)
            # load model data if pre-trained
            sess.run(
                tf.group(tf.global_variables_initializer(),
                         tf.local_variables_initializer()))
            if os.path.isfile(self.train_models_dir +
                              'model.cptk.data-00000-of-00001'):
                print "restoring saved model"
                saver.restore(sess, self.train_models_dir + 'model.cptk')
            learning_rate_g = ori_lr * pow(power, (epoch_walked / decay_step))
            # start training loop
            global_step = step_walked
            for epoch in range(epoch_walked, MAX_EPOCH):
                if epoch % re_example_epoch == 0 and epoch > 0:
                    del data
                    gc.collect()
                    data = tools.Data(configure, epoch / re_example_epoch)
                train_amount = len(data.train_numbers)
                test_amount = len(data.test_numbers)
                if train_amount >= test_amount and train_amount > 0 and test_amount > 0 and data.total_train_batch_num > 0 and data.total_test_seq_batch > 0:
                    # actual foreground weight
                    weight_for = 0.5 + (1 - 1.0 * epoch / MAX_EPOCH) * 0.35
                    if epoch % total_test_epoch == 0 and epoch > 0:
                        print '********************** FULL TESTING ********************************'
                        time_begin = time.time()
                        origin_dir = read_dicoms(test_dir + "original1")
                        mask_dir = test_dir + "artery"
                        test_batch_size = batch_size
                        # test_data = tools.Test_data(dicom_dir,input_shape)
                        test_data = tools.Test_data(origin_dir, input_shape,
                                                    'vtk_data')
                        test_data.organize_blocks()
                        test_mask = read_dicoms(mask_dir)
                        array_mask = ST.GetArrayFromImage(test_mask)
                        array_mask = np.transpose(array_mask, (2, 1, 0))
                        print "mask shape: ", np.shape(array_mask)
                        block_numbers = test_data.blocks.keys()
                        for i in range(0, len(block_numbers), test_batch_size):
                            batch_numbers = []
                            if i + test_batch_size < len(block_numbers):
                                temp_input = np.zeros([
                                    test_batch_size, input_shape[0],
                                    input_shape[1], input_shape[2]
                                ])
                                for j in range(test_batch_size):
                                    temp_num = block_numbers[i + j]
                                    temp_block = test_data.blocks[temp_num]
                                    batch_numbers.append(temp_num)
                                    block_array = temp_block.load_data()
                                    block_shape = np.shape(block_array)
                                    temp_input[j, 0:block_shape[0],
                                               0:block_shape[1],
                                               0:block_shape[2]] += block_array
                                Y_temp_pred, Y_temp_modi, Y_temp_pred_nosig = sess.run(
                                    [Y_pred, Y_pred_modi, Y_pred_nosig],
                                    feed_dict={
                                        X:
                                        temp_input,
                                        training:
                                        False,
                                        w:
                                        weight_for,
                                        threshold:
                                        upper_threshold + test_extra_threshold
                                    })
                                for j in range(test_batch_size):
                                    test_data.upload_result(
                                        batch_numbers[j],
                                        Y_temp_modi[j, :, :, :])
                            else:
                                temp_batch_size = len(block_numbers) - i
                                temp_input = np.zeros([
                                    temp_batch_size, input_shape[0],
                                    input_shape[1], input_shape[2]
                                ])
                                for j in range(temp_batch_size):
                                    temp_num = block_numbers[i + j]
                                    temp_block = test_data.blocks[temp_num]
                                    batch_numbers.append(temp_num)
                                    block_array = temp_block.load_data()
                                    block_shape = np.shape(block_array)
                                    temp_input[j, 0:block_shape[0],
                                               0:block_shape[1],
                                               0:block_shape[2]] += block_array
                                X_temp = tf.placeholder(shape=[
                                    temp_batch_size, input_shape[0],
                                    input_shape[1], input_shape[2]
                                ],
                                                        dtype=tf.float32)
                                with tf.variable_scope('segmenter',
                                                       reuse=True):
                                    Y_pred_temp, Y_pred_modi_temp, Y_pred_nosig_temp = self.ae_u(
                                        X_temp, training, temp_batch_size,
                                        threshold)
                                Y_temp_pred, Y_temp_modi, Y_temp_pred_nosig = sess.run(
                                    [
                                        Y_pred_temp, Y_pred_modi_temp,
                                        Y_pred_nosig_temp
                                    ],
                                    feed_dict={
                                        X_temp:
                                        temp_input,
                                        training:
                                        False,
                                        w:
                                        weight_for,
                                        threshold:
                                        upper_threshold + test_extra_threshold
                                    })
                                for j in range(temp_batch_size):
                                    test_data.upload_result(
                                        batch_numbers[j],
                                        Y_temp_modi[j, :, :, :])
                        test_result_array = test_data.get_result()
                        print "result shape: ", np.shape(test_result_array)
                        to_be_transformed = self.post_process(
                            test_result_array)
                        if epoch % output_epoch == 0:
                            self.output_img(to_be_transformed, test_data.space,
                                            epoch)
                        if epoch == 0:
                            mask_img = ST.GetImageFromArray(
                                np.transpose(array_mask, [2, 1, 0]))
                            mask_img.SetSpacing(test_data.space)
                            ST.WriteImage(mask_img,
                                          './test_result/test_mask.vtk')
                        test_IOU = 2 * np.sum(
                            to_be_transformed * array_mask) / (
                                np.sum(to_be_transformed) + np.sum(array_mask))
                        test_summary = sess.run(
                            test_merge_op, feed_dict={total_acc: test_IOU})
                        sum_write_test.add_summary(test_summary,
                                                   global_step=epoch)
                        print "IOU accuracy: ", test_IOU
                        time_end = time.time()
                        print '******************** time of full testing: ' + str(
                            time_end - time_begin) + 's ********************'
                    data.shuffle_X_Y_pairs()
                    total_train_batch_num = data.total_train_batch_num
                    print "total_train_batch_num:", total_train_batch_num
                    for i in range(total_train_batch_num):
                        X_train_batch, Y_train_batch = data.load_X_Y_voxel_train_next_batch(
                        )
                        # calculate loss value
                        # print "calculate begin"
                        loss_c = sess.run(
                            [cross_loss],
                            feed_dict={
                                X: X_train_batch,
                                Y: Y_train_batch,
                                training: False,
                                w: weight_for,
                                threshold: upper_threshold
                            })
                        # print "calculate ended"
                        if epoch % decay_step == 0 and epoch > epoch_walked and i == 0:
                            learning_rate_g = learning_rate_g * power
                        sess.run(
                            [optim],
                            feed_dict={
                                X: X_train_batch,
                                threshold: upper_threshold,
                                Y: Y_train_batch,
                                lr: learning_rate_g,
                                training: True,
                                w: weight_for
                            })
                        # print "training ended"
                        global_step += 1
                        # output some results
                        if i % show_step == 0:
                            print "epoch:", epoch, " i:", i, " train loss:", loss_c, " gan g loss:", learning_rate_g
                        if i % block_test_step == 0 and epoch % 1 == 0:
                            try:
                                X_test_batch, Y_test_batch = data.load_X_Y_voxel_test_next_batch(
                                    fix_sample=False)
                                Y_test_pred, Y_test_modi, Y_test_pred_nosig ,loss_t= \
                                    sess.run([ Y_pred, Y_pred_modi, Y_pred_nosig,cross_loss],
                                             feed_dict={X: X_test_batch,
                                                        threshold: upper_threshold + test_extra_threshold,
                                                        Y: Y_test_batch, training: False, w: weight_for})
                                predict_result = np.float32(Y_test_modi > 0.01)
                                predict_result = np.reshape(
                                    predict_result, [
                                        batch_size, input_shape[0],
                                        input_shape[1], input_shape[2]
                                    ])
                                print np.max(Y_test_pred)
                                print np.min(Y_test_pred)
                                # IOU
                                predict_probablity = np.float32(
                                    (Y_test_modi - 0.01) > 0)
                                predict_probablity = np.reshape(
                                    predict_probablity, [
                                        batch_size, input_shape[0],
                                        input_shape[1], input_shape[2]
                                    ])
                                accuracy = 2 * np.sum(
                                    np.abs(predict_probablity *
                                           Y_test_batch)) / np.sum(
                                               np.abs(predict_result) +
                                               np.abs(Y_test_batch))
                                print "epoch:", epoch, " global step: ", global_step, "\nIOU accuracy: ", accuracy, "\ntest ae loss:", loss_t
                                print "weight of foreground : ", weight_for
                                print "upper threshold of testing", (
                                    upper_threshold + test_extra_threshold)
                                train_summary = sess.run(
                                    train_merge_op,
                                    feed_dict={
                                        block_acc: accuracy,
                                        X: X_test_batch,
                                        threshold:
                                        upper_threshold + test_extra_threshold,
                                        Y: Y_test_batch,
                                        training: False,
                                        w: weight_for
                                    })
                                sum_writer_train.add_summary(
                                    train_summary, global_step=global_step)
                            except Exception, e:
                                print e
                        #### model saving
                        if i % model_save_step == 0 and epoch % 1 == 0:
                            saver.save(sess,
                                       save_path=self.train_models_dir +
                                       'model.cptk')
                            print "epoch:", epoch, " i:", i, "regular model saved!"
                else:
                    print "bad data , next epoch", epoch
コード例 #6
0
import os
import shutil
import tensorflow as tf
import scipy.io
import tools
import numpy as np
import time
import test
import SimpleITK as ST
from dicom_read import read_dicoms
import gc

input_shape = [64, 64, 128]
test_dir = './FU_LI_JUN/'

origin_dir = read_dicoms(test_dir + "original1")
test_data = tools.Test_data(origin_dir, input_shape, 'vtk_data')
test_data.output_origin()
print "end"
コード例 #7
0
    def train(self):
        flags = self.FLAGS
        block_shape = self.block_shape
        record_dir = self.record_dir
        record_dir_test = self.record_dir_test
        batch_size_train = self.batch_size_train
        batch_size_test = self.batch_size_test
        test_step = self.test_step
        threashold = flags.accept_threshold
        LEARNING_RATE_BASE = flags.training_rate_base
        LEARNING_RATE_DECAY = flags.training_rate_decay
        weight_vec = tf.constant([
            flags.airway_weight, flags.artery_weight, flags.back_ground_weight
        ], tf.float32)
        X = tf.placeholder(dtype=tf.float32,
                           shape=[
                               batch_size_train, block_shape[0],
                               block_shape[1], block_shape[2]
                           ])
        training = tf.placeholder(tf.bool)
        with tf.variable_scope('network'):
            seg_pred = self.Dense_Net(X, training, flags.batch_size_train,
                                      flags.accept_threshold)

        # lost function
        '''
        lable vector: [airway,artery,background]
        '''
        lables = tf.placeholder(dtype=tf.float32,
                                shape=[
                                    batch_size_train, block_shape[0],
                                    block_shape[1], block_shape[2], 3
                                ])
        weight_map = tf.reduce_sum(tf.multiply(lables, weight_vec), 4)
        loss_origin = tf.nn.softmax_cross_entropy_with_logits(logits=seg_pred,
                                                              labels=lables)
        loss_weighted = weight_map * loss_origin
        loss = tf.reduce_mean(loss_weighted)
        tf.summary.scalar('loss', loss)

        # accuracy
        # predict_softmax = tf.nn.softmax(seg_pred)
        pred_map = tf.argmax(seg_pred, axis=-1)
        pred_map_bool = tf.equal(pred_map, 1)
        artery_pred_mask = tf.cast(pred_map_bool, tf.float32)
        artery_lable = tf.cast(lables[:, :, :, :, 0], tf.float32)
        artery_acc = 2 * tf.reduce_sum(artery_lable * artery_pred_mask) / (
            tf.reduce_sum(artery_lable + artery_pred_mask))
        tf.summary.scalar('airway_block_acc', artery_acc)

        # data part
        records = ut.get_records(record_dir)
        records_processor = TF_Records(records, block_shape)
        single_blocks = records_processor.read_records()
        queue = tf.RandomShuffleQueue(capacity=8,
                                      min_after_dequeue=4,
                                      dtypes=(
                                          single_blocks['airway'].dtype,
                                          single_blocks['artery'].dtype,
                                          single_blocks['lung'].dtype,
                                          single_blocks['original'].dtype,
                                      ))
        enqueue_op = queue.enqueue((
            single_blocks['airway'],
            single_blocks['artery'],
            single_blocks['lung'],
            single_blocks['original'],
        ))
        (airway_block, artery_block, lung_block,
         original_block) = queue.dequeue()
        qr = tf.train.QueueRunner(queue, [enqueue_op] * 2)

        # test data part
        records_test = ut.get_records(record_dir_test)
        records_processor_test = TF_Records(records_test, block_shape)
        single_blocks_test = records_processor_test.read_records()
        queue_test = tf.RandomShuffleQueue(
            capacity=8,
            min_after_dequeue=4,
            dtypes=(
                single_blocks_test['airway'].dtype,
                single_blocks_test['artery'].dtype,
                single_blocks_test['lung'].dtype,
                single_blocks_test['original'].dtype,
            ))
        enqueue_op_test = queue_test.enqueue((
            single_blocks_test['airway'],
            single_blocks_test['artery'],
            single_blocks_test['lung'],
            single_blocks_test['original'],
        ))
        (airway_block_test, artery_block_test, lung_block_test,
         original_block_test) = queue_test.dequeue()
        qr_test = tf.train.QueueRunner(queue, [enqueue_op_test] * 2)

        global_step = tf.Variable(0, trainable=False)
        learning_rate = tf.maximum(
            tf.train.exponential_decay(LEARNING_RATE_BASE,
                                       global_step,
                                       13500 * 5 / flags.batch_size_train,
                                       LEARNING_RATE_DECAY,
                                       staircase=True), 1e-9)
        train_op = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                          beta1=0.9,
                                          beta2=0.999,
                                          epsilon=1e-8).minimize(
                                              loss, global_step)
        # merge operation for tensorboard summary
        merge_summary_op = tf.summary.merge_all()

        saver = tf.train.Saver(max_to_keep=1)
        config = tf.ConfigProto(allow_soft_placement=True)

        with tf.Session(config=config) as sess:

            # load variables if saved before
            if len(os.listdir(self.train_models_dir)) > 0:
                print "load saved model"
                sess.run(
                    tf.group(tf.global_variables_initializer(),
                             tf.local_variables_initializer()))
                saver.restore(sess,
                              self.train_models_dir + "train_models.ckpt")
            else:
                sess.run(
                    tf.group(tf.global_variables_initializer(),
                             tf.local_variables_initializer()))

            # coord for the reading threads
            coord = tf.train.Coordinator()
            enqueue_threads = qr.create_threads(sess, coord=coord, start=True)
            enqueue_threads_test = qr_test.create_threads(sess,
                                                          coord=coord,
                                                          start=True)
            tf.train.start_queue_runners(sess=sess)

            summary_writer_test = tf.summary.FileWriter(
                self.test_sum_dir, sess.graph)
            summary_writer_train = tf.summary.FileWriter(
                self.train_sum_dir, sess.graph)

            # main train loop
            # for i in range(flags.max_iteration_num):
            for i in range(flags.max_iteration_num):
                # organize a batch of data for training
                lable_np = np.zeros([
                    batch_size_train, block_shape[0], block_shape[1],
                    block_shape[2], 3
                ], np.int16)
                original_np = np.zeros([
                    batch_size_train, block_shape[0], block_shape[1],
                    block_shape[2]
                ], np.int16)

                # store values into data block
                for m in range(flags.batch_size_train):
                    '''
                    lable vector: [airway,artery,background]
                    '''
                    artery_data, airway_data, original_data = \
                        sess.run([artery_block, airway_block, original_block])
                    airway_array = airway_data
                    artery_array = artery_data
                    back_ground_array = np.int16((airway_array +
                                                  artery_array) == 0)
                    check_array = airway_array + artery_array + back_ground_array
                    while not np.max(check_array) == np.min(check_array) == 1:
                        artery_data, airway_data, original_data = \
                            sess.run([artery_block, airway_block, original_block])
                        airway_array = airway_data
                        artery_array = artery_data
                        back_ground_array = np.int16((airway_array +
                                                      artery_array) == 0)
                        check_array = airway_array + artery_array + back_ground_array
                    lable_np[m, :, :, :, 0] += airway_array
                    lable_np[m, :, :, :, 1] += artery_array
                    lable_np[m, :, :, :, 2] += back_ground_array
                    original_np[m, :, :, :] += original_data
                train_, step_num = sess.run([train_op, global_step],
                                            feed_dict={
                                                X: original_np,
                                                lables: lable_np,
                                                training: True
                                            })
                if step_num % flags.full_test_step == 0:
                    #     full testing
                    print "****************************full testing******************************"
                    data_type = "dicom_data"
                    test_dicom_dir = '/opt/Multi-Task-data-process/multi_task_data_test/FU_LI_JUN/original1'
                    test_mask_dir = '/opt/Multi-Task-data-process/multi_task_data_test/FU_LI_JUN/artery'
                    test_mask = ut.read_dicoms(test_mask_dir)
                    test_mask_array = np.transpose(
                        ST.GetArrayFromImage(test_mask), [2, 1, 0])
                    test_data = tools.Test_data(test_dicom_dir, block_shape,
                                                data_type)
                    test_data.organize_blocks()
                    block_numbers = test_data.blocks.keys()
                    blocks_num = len(block_numbers)
                    print "block count: ", blocks_num
                    time1 = time.time()
                    sys.stdout.write("\r>>>deep learning calculating : %f" %
                                     (0.0) + "%")
                    sys.stdout.flush()
                    for m in range(0, blocks_num, batch_size_train):
                        batch_numbers = []
                        if m + batch_size_train < blocks_num:
                            temp_batch_size = batch_size_train
                        else:
                            temp_batch_size = blocks_num - m
                        temp_input = np.zeros([
                            batch_size_train, block_shape[0], block_shape[1],
                            block_shape[2]
                        ])
                        for j in range(temp_batch_size):
                            temp_num = block_numbers[m + j]
                            temp_block = test_data.blocks[temp_num]
                            batch_numbers.append(temp_num)
                            block_array = temp_block.load_data()
                            data_block_shape = np.shape(block_array)
                            temp_input[j, 0:data_block_shape[0],
                                       0:data_block_shape[1],
                                       0:data_block_shape[2]] += block_array
                        artery_predict = sess.run(artery_pred_mask,
                                                  feed_dict={
                                                      X: temp_input,
                                                      training: False
                                                  })
                        for j in range(temp_batch_size):
                            test_data.upload_result(batch_numbers[j],
                                                    artery_predict[j, :, :, :])
                        if (m) % (batch_size_train * 10) == 0:
                            sys.stdout.write(
                                "\r>>>deep learning calculating : %f" %
                                ((1.0 * m) * 100 / blocks_num) + "%")
                            sys.stdout.flush()

                    sys.stdout.write("\r>>>deep learning calculating : %f" %
                                     (100.0) + "%")
                    sys.stdout.flush()
                    time2 = time.time()
                    print "\ndeep learning time consume : ", str(time2 - time1)
                    time3 = time.time()
                    test_result_array = test_data.get_result()
                    test_result_array = np.float32(test_result_array >= 2)
                    print "result shape: ", np.shape(test_result_array)
                    r_s = np.shape(test_result_array)  # result shape
                    e_t = 10  # edge thickness
                    to_be_transformed = np.zeros(r_s, np.float32)
                    to_be_transformed[e_t:r_s[0] - e_t, e_t:r_s[1] - e_t,
                                      0:r_s[2] -
                                      e_t] += test_result_array[e_t:r_s[0] -
                                                                e_t,
                                                                e_t:r_s[1] -
                                                                e_t,
                                                                0:r_s[2] - e_t]
                    print "maximum value in mask: ", np.max(to_be_transformed)
                    print "minimum value in mask: ", np.min(to_be_transformed)
                    final_img = ST.GetImageFromArray(
                        np.transpose(to_be_transformed, [2, 1, 0]))
                    final_img.SetSpacing(test_data.space)
                    time4 = time.time()
                    print "post processing time consume : ", str(time4 - time3)
                    print "writing final testing result"
                    if not os.path.exists('./test_result'):
                        os.makedirs('./test_result')
                    print './test_result/test_result_' + str(step_num) + '.vtk'
                    ST.WriteImage(
                        final_img,
                        './test_result/test_result_' + str(step_num) + '.vtk')
                    total_accuracy = 2 * np.sum(
                        1.0 * test_mask_array * to_be_transformed) / np.sum(
                            1.0 * (test_mask_array + to_be_transformed))
                    print "total IOU accuracy : ", total_accuracy
                    if i == 0:
                        mask_img = ST.GetImageFromArray(
                            np.transpose(test_mask_array, [2, 1, 0]))
                        mask_img.SetSpacing(test_data.space)
                        ST.WriteImage(mask_img, './test_result/mask_img.vtk')
                    print "***********************full testing end*******************************"
                if i % 10 == 0:
                    sum_train,\
                    l_val \
                        = sess.run([merge_summary_op,
                                    loss],
                                   feed_dict={X: original_np,
                                              lables: lable_np, training: False})
                    summary_writer_train.add_summary(sum_train,
                                                     global_step=int(step_num))
                    print "train :\nstep %d , loss = %f\n =====================" \
                          % (int(step_num), l_val)
                if i % test_step == 0 and i > 0:
                    lable_np_test = np.zeros([
                        batch_size_train, block_shape[0], block_shape[1],
                        block_shape[2], 3
                    ], np.int16)
                    original_np_test = np.zeros([
                        batch_size_train, block_shape[0], block_shape[1],
                        block_shape[2]
                    ], np.int16)
                    for m in range(flags.batch_size_train):
                        '''
                        lable vector: [airway,artery,background]
                        '''
                        artery_data, airway_data, original_data = \
                            sess.run([artery_block_test, airway_block_test, original_block_test])
                        airway_array = airway_data
                        artery_array = artery_data
                        back_ground_array = np.int16((airway_array +
                                                      artery_array) == 0)
                        check_array = airway_array + artery_array + back_ground_array
                        while not np.max(check_array) == np.min(
                                check_array) == 1:
                            artery_data, airway_data, original_data = \
                                sess.run([artery_block, airway_block, original_block])
                            airway_array = airway_data
                            artery_array = artery_data
                            back_ground_array = np.int16((airway_array +
                                                          artery_array) == 0)
                            check_array = airway_array + artery_array + back_ground_array
                        lable_np_test[m, :, :, :, 0] += airway_array
                        lable_np_test[m, :, :, :, 1] += artery_array
                        lable_np_test[m, :, :, :, 2] += back_ground_array
                        original_np_test[m, :, :, :] += original_data
                    sum_test, accuracy_artery, l_val, predict_array = \
                        sess.run([merge_summary_op,artery_acc,loss,pred_map],
                            feed_dict={X: original_np_test,lables: lable_np_test, training: False})
                    summary_writer_test.add_summary(sum_test,
                                                    global_step=int(step_num))
                    print "\ntest :\nstep %d , artery loss = %f \n\t artery block accuracy = %f\n=====================" \
                          % (int(step_num), l_val, accuracy_artery)
                    print "artery percentage : ", str(
                        np.float32(
                            np.sum(np.float32(lable_np_test[:, :, :, :, 1])) /
                            (flags.batch_size_train * block_shape[0] *
                             block_shape[1] * block_shape[2])))
                    # print "prediction of airway : maximum = ",np.max(airway_np_sig)," minimum = ",np.min(airway_np_sig)
                    print "prediction : maximum = ", np.max(
                        predict_array), " minimum = ", np.min(predict_array)
                if i % 100 == 0:
                    saver.save(sess,
                               self.train_models_dir + "train_models.ckpt")
                    print "regular model saved! step count : ", step_num

            coord.request_stop()
            coord.join(enqueue_threads)
            coord.join(enqueue_threads_test)
コード例 #8
0
 def full_testing(self, sess, epoch):
     print '********************** FULL TESTING ********************************'
     time_begin = time.time()
     origin_data = read_dicoms(test_dir + "original1")
     mask_dir = test_dir + "artery"
     test_batch_size = batch_size
     test_data = tools.Test_data(origin_data, input_shape, 'vtk_data')
     test_data.organize_blocks()
     test_mask = read_dicoms(mask_dir)
     array_mask = ST.GetArrayFromImage(test_mask)
     array_mask = np.transpose(array_mask, (2, 1, 0))
     print "mask shape: ", np.shape(array_mask)
     block_numbers = test_data.blocks.keys()
     for i in range(0, len(block_numbers), test_batch_size):
         batch_numbers = []
         if i + test_batch_size < len(block_numbers):
             temp_input = np.zeros([
                 test_batch_size, input_shape[0], input_shape[1],
                 input_shape[2]
             ])
             for j in range(test_batch_size):
                 temp_num = block_numbers[i + j]
                 temp_block = test_data.blocks[temp_num]
                 batch_numbers.append(temp_num)
                 block_array = temp_block.load_data()
                 block_shape = np.shape(block_array)
                 temp_input[j, 0:block_shape[0], 0:block_shape[1],
                            0:block_shape[2]] += block_array
             pred_unsoft, softmax_pred, argmax_label = \
             sess.run([self.pred_unsoft, self.softmax_pred, self.argmax_label],
                 feed_dict={self.X: temp_input,
                            self.training: False})
             for j in range(test_batch_size):
                 test_data.upload_result_multiclass(
                     batch_numbers[j], argmax_label[j, :, :, :], mask_type)
         else:
             temp_batch_size = len(block_numbers) - i
             temp_input = np.zeros([
                 temp_batch_size, input_shape[0], input_shape[1],
                 input_shape[2]
             ])
             for j in range(temp_batch_size):
                 temp_num = block_numbers[i + j]
                 temp_block = test_data.blocks[temp_num]
                 batch_numbers.append(temp_num)
                 block_array = temp_block.load_data()
                 block_shape = np.shape(block_array)
                 temp_input[j, 0:block_shape[0], 0:block_shape[1],
                            0:block_shape[2]] += block_array
             X_temp = tf.placeholder(shape=[
                 temp_batch_size, input_shape[0], input_shape[1],
                 input_shape[2]
             ],
                                     dtype=tf.float32)
             with tf.variable_scope("segment", reuse=True):
                 temp_unsoft, softmax_temp, argmax_temp = self.Segmentor(
                     X_temp, self.training, batch_size)
                 pred_unsoft_temp, softmax_pred_temp, argmax_label_temp = \
                     sess.run([temp_unsoft, softmax_temp,argmax_temp],feed_dict={X_temp:temp_input,
                                                                                 self.training:False})
                 for j in range(temp_batch_size):
                     test_data.upload_result_multiclass(
                         batch_numbers[j], argmax_label_temp[j, :, :, :],
                         mask_type)
     test_result_array = test_data.get_result_()
     print "result shape: ", np.shape(test_result_array)
     to_be_transformed = self.post_process(test_result_array)
     if epoch == 0:
         mask_img = ST.GetImageFromArray(np.transpose(
             array_mask, [2, 1, 0]))
         mask_img.SetSpacing(test_data.space)
         ST.WriteImage(mask_img, './test_result/test_mask.vtk')
     test_IOU = 2 * np.sum(to_be_transformed * array_mask) / (
         np.sum(to_be_transformed) + np.sum(array_mask))
     test_summary = sess.run(self.test_merge_op,
                             feed_dict={self.total_acc: test_IOU})
     self.sum_write_test.add_summary(test_summary, global_step=epoch)
     print "IOU accuracy: ", test_IOU
     time_end = time.time()
     print '******************** time of full testing: ' + str(
         time_end - time_begin) + 's ********************'