Example #1
    def demo(self):

        d = tools.Data(config)

        if not os.path.exists(self.demo_dir + 'depth/'):
            print('Demo depth folder not present!!!')
            return

        filenames = glob.glob(self.demo_dir + 'depth/*')

        if len(filenames) == 0:
            print('No files found in depth folder!!')
            return

        if not os.path.exists(self.demo_dir + 'voxel/'):
            os.makedirs(self.demo_dir + 'voxel/')

        if len(filenames) % self.batch_size != 0:
            print('Number of images should be a multiple of batch size ({})'.
                  format(self.batch_size))
            return

        for i in range(len(filenames) // self.batch_size):
            X_data_files = filenames[self.batch_size * i:self.batch_size *
                                     (i + 1)]

            X_test_batch = d.load_X_Y_voxel_grids(X_data_files)

            Y_pred_batch = self.sess.run(self.Y_pred,
                                         feed_dict={self.X: X_test_batch})

            for j, filename in enumerate(X_data_files):
                # save each predicted grid next to its source depth map,
                # mirroring the directory layout and swapping the extension
                # (inner index renamed so it no longer shadows the batch index i)
                np.save(
                    filename.replace('/depth/',
                                     '/voxel/').replace('.png', '.npy'),
                    Y_pred_batch[j, :, :, :, :])
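The demo loop above maps every depth map under depth/ to a predicted voxel grid saved under voxel/, in batches of batch_size. A minimal standalone sketch of just that path mapping and batching, assuming only the standard library (chunked, voxel_path, and the directory names are illustrative, not part of the original code):

import glob


def chunked(items, size):
    # yield consecutive, equally sized slices of items
    for start in range(0, len(items), size):
        yield items[start:start + size]


def voxel_path(depth_path):
    # mirror demo_dir/depth/x.png -> demo_dir/voxel/x.npy
    return depth_path.replace('/depth/', '/voxel/').replace('.png', '.npy')


batch_size = 4
filenames = sorted(glob.glob('demo/depth/*.png'))
if len(filenames) % batch_size != 0:
    raise ValueError('file count must be a multiple of batch size')
for batch in chunked(filenames, batch_size):
    targets = [voxel_path(f) for f in batch]  # one .npy per depth map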
Example #2
                        ae_loss_t, gan_g_loss_t, gan_d_loss_t, Y_test_pred, Y_test_pred_nosig = sess.run(
                            [ae_loss, gan_g_loss, gan_d_loss, Y_pred, Y_pred_nosig],
                            feed_dict={X: X_test_batch, Y: Y_test_batch})

                        to_save = {
                            'X_test': X_test_batch,
                            'Y_test_pred': Y_test_pred,
                            'Y_test_true': Y_test_batch
                        }
                        scipy.io.savemat(self.test_results_dir + 'X_Y_pred_' +
                                         str(epoch).zfill(2) + '_' +
                                         str(i).zfill(4) + '.mat',
                                         to_save,
                                         do_compression=True)
                        print "epoch:", epoch, " i:", i, " test ae loss:", ae_loss_t, " gan g loss:", gan_g_loss_t, " gan d loss:", gan_d_loss_t

                    #### full testing
                    # ...

                    #### model saving
                    if i % 500 == 0 and i > 0:  # "epoch % 1 == 0" was always true
                        saver.save(sess,
                                   save_path=self.train_models_dir +
                                   'model.cptk')
                        print "epoch:", epoch, " i:", i, " model saved!"


if __name__ == "__main__":
    data = tools.Data(config)
    net = Network()
    net.train(data)
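Example #2 snapshots each test batch to a compressed MATLAB file with scipy.io.savemat. A minimal, self-contained sketch of that save/load round trip (the dictionary keys match the example; the shapes are illustrative):

import numpy as np
import scipy.io

to_save = {
    'X_test': np.zeros((2, 4, 4, 4), np.float32),
    'Y_test_pred': np.zeros((2, 4, 4, 4), np.float32),
    'Y_test_true': np.zeros((2, 4, 4, 4), np.float32),
}
# same keying and compression as the training loop above
scipy.io.savemat('X_Y_pred_00_0000.mat', to_save, do_compression=True)
loaded = scipy.io.loadmat('X_Y_pred_00_0000.mat')
print(loaded['Y_test_pred'].shape)  # (2, 4, 4, 4)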
Example #3
    def train(self, configure):
        data = tools.Data(configure, epoch_walked)
        best_acc = 0
        # X = tf.placeholder(shape=[batch_size, input_shape[0], input_shape[1], input_shape[2]], dtype=tf.float32)
        X = tf.placeholder(
            shape=[batch_size, input_shape[0], input_shape[1], input_shape[2]],
            dtype=tf.float32)
        # Y = tf.placeholder(shape=[batch_size, output_shape[0], output_shape[1], output_shape[2]], dtype=tf.float32)
        Y = tf.placeholder(shape=[
            batch_size, output_shape[0], output_shape[1], output_shape[2]
        ],
                           dtype=tf.float32)
        print X.get_shape()
        lr = tf.placeholder(tf.float32)
        training = tf.placeholder(tf.bool)
        threshold = tf.placeholder(tf.float32)
        with tf.variable_scope('ae'):
            Y_pred, Y_pred_modi, Y_pred_nosig = self.ae_u(
                X, training, batch_size, threshold)

        with tf.variable_scope('dis'):
            XY_real_pair = self.dis(X, Y, training)
        with tf.variable_scope('dis', reuse=True):
            XY_fake_pair = self.dis(X, Y_pred, training)

        with tf.device('/gpu:' + GPU0):
            ################################ ae loss
            Y_ = tf.reshape(Y, shape=[batch_size, -1])
            Y_pred_modi_ = tf.reshape(Y_pred_modi, shape=[batch_size, -1])
            w = tf.placeholder(
                tf.float32)  # power of foreground against background
            ae_loss = tf.reduce_mean(
                -tf.reduce_mean(w * Y_ * tf.log(Y_pred_modi_ + 1e-8),
                                reduction_indices=[1]) -
                tf.reduce_mean((1 - w) *
                               (1 - Y_) * tf.log(1 - Y_pred_modi_ + 1e-8),
                               reduction_indices=[1]))
            sum_ae_loss = tf.summary.scalar('ae_loss', ae_loss)

            ################################ wgan loss
            gan_g_loss = -tf.reduce_mean(XY_fake_pair)
            gan_d_loss = tf.reduce_mean(XY_fake_pair) - tf.reduce_mean(
                XY_real_pair)
            sum_gan_g_loss = tf.summary.scalar('gan_g_loss', gan_g_loss)
            sum_gan_d_loss = tf.summary.scalar('gan_d_loss', gan_d_loss)
            alpha = tf.random_uniform(shape=[
                batch_size, input_shape[0] * input_shape[1] * input_shape[2]
            ],
                                      minval=0.0,
                                      maxval=1.0)

            Y_pred_ = tf.reshape(Y_pred, shape=[batch_size, -1])
            differences_ = Y_pred_ - Y_
            interpolates = Y_ + alpha * differences_
            with tf.variable_scope('dis', reuse=True):
                XY_fake_intep = self.dis(X, interpolates, training)
            gradients = tf.gradients(XY_fake_intep, [interpolates])[0]
            slopes = tf.sqrt(
                tf.reduce_sum(tf.square(gradients), reduction_indices=[1]))
            gradient_penalty = tf.reduce_mean((slopes - 1.0)**2)
            gan_d_loss += 10 * gradient_penalty

            #################################  ae + gan loss
            gan_g_w = 5
            ae_w = 100 - gan_g_w
            ae_gan_g_loss = ae_w * ae_loss + gan_g_w * gan_g_loss

        with tf.device('/gpu:' + GPU0):
            ae_var = [
                var for var in tf.trainable_variables()
                if var.name.startswith('ae')
            ]
            dis_var = [
                var for var in tf.trainable_variables()
                if var.name.startswith('dis')
            ]
            ae_g_optim = tf.train.AdamOptimizer(learning_rate=lr,
                                                beta1=0.9,
                                                beta2=0.999,
                                                epsilon=1e-8).minimize(
                                                    ae_gan_g_loss,
                                                    var_list=ae_var)
            dis_optim = tf.train.AdamOptimizer(learning_rate=lr,
                                               beta1=0.9,
                                               beta2=0.999,
                                               epsilon=1e-8).minimize(
                                                   gan_d_loss,
                                                   var_list=dis_var)

        print tools.Ops.variable_count()
        sum_merged = tf.summary.merge_all()

        saver = tf.train.Saver(max_to_keep=1)
        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.visible_device_list = GPU0
        with tf.Session(config=config) as sess:
            # if os.path.exists(self.train_models_dir):
            #     try:
            #         saver.restore(sess,self.train_models_dir+'model.cptk')
            #     except Exception,e:
            #         saver.restore(sess,'./regular/'+'model.cptk')
            sum_writer_train = tf.summary.FileWriter(self.train_sum_dir,
                                                     sess.graph)
            sum_write_test = tf.summary.FileWriter(self.test_sum_dir)

            if os.path.isfile(self.train_models_dir +
                              'model.cptk.data-00000-of-00001'):
                print "restoring saved model"
                saver.restore(sess, self.train_models_dir + 'model.cptk')
            else:
                sess.run(tf.global_variables_initializer())

            learning_rate_g = ori_lr * pow(power, (epoch_walked / 4))
            for epoch in range(epoch_walked, 15000):
                # data.shuffle_X_Y_files(label='train')
                #### re-sample the training data every 2 epochs
                if epoch % 2 == 0 and epoch > 0:
                    del data
                    gc.collect()
                    data = tools.Data(configure, epoch)
                #### full testing
                # ...
                train_amount = len(data.train_numbers)
                test_amount = len(data.test_numbers)
                if (train_amount >= test_amount and train_amount > 0
                        and test_amount > 0 and data.total_train_batch_num > 0
                        and data.total_test_seq_batch > 0):
                    weight_for = 0.35 * (1 - epoch * 1.0 / 15000) + 0.5
                    if epoch % 4 == 0:
                        print '********************** FULL TESTING ********************************'
                        time_begin = time.time()
                        lung_img = ST.ReadImage('./WANG_REN/lung_img.vtk')
                        mask_dir = "./WANG_REN/airway"
                        test_batch_size = batch_size
                        # test_data = tools.Test_data(dicom_dir,input_shape)
                        test_data = tools.Test_data(lung_img, input_shape,
                                                    'vtk_data')
                        test_data.organize_blocks()
                        test_mask = read_dicoms(mask_dir)
                        array_mask = ST.GetArrayFromImage(test_mask)
                        array_mask = np.transpose(array_mask, (2, 1, 0))
                        print "mask shape: ", np.shape(array_mask)
                        time1 = time.time()
                        block_numbers = test_data.blocks.keys()
                        for i in range(0, len(block_numbers), test_batch_size):
                            batch_numbers = []
                            if i + test_batch_size < len(block_numbers):
                                temp_input = np.zeros([
                                    test_batch_size, input_shape[0],
                                    input_shape[1], input_shape[2]
                                ])
                                for j in range(test_batch_size):
                                    temp_num = block_numbers[i + j]
                                    temp_block = test_data.blocks[temp_num]
                                    batch_numbers.append(temp_num)
                                    block_array = temp_block.load_data()
                                    block_shape = np.shape(block_array)
                                    temp_input[j, 0:block_shape[0],
                                               0:block_shape[1],
                                               0:block_shape[2]] += block_array
                                Y_temp_pred, Y_temp_modi, Y_temp_pred_nosig = sess.run(
                                    [Y_pred, Y_pred_modi, Y_pred_nosig],
                                    feed_dict={
                                        X: temp_input,
                                        training: False,
                                        w: weight_for,
                                        threshold: upper_threshold
                                    })
                                for j in range(test_batch_size):
                                    test_data.upload_result(
                                        batch_numbers[j],
                                        Y_temp_modi[j, :, :, :])
                            else:
                                temp_batch_size = len(block_numbers) - i
                                temp_input = np.zeros([
                                    temp_batch_size, input_shape[0],
                                    input_shape[1], input_shape[2]
                                ])
                                for j in range(temp_batch_size):
                                    temp_num = block_numbers[i + j]
                                    temp_block = test_data.blocks[temp_num]
                                    batch_numbers.append(temp_num)
                                    block_array = temp_block.load_data()
                                    block_shape = np.shape(block_array)
                                    temp_input[j, 0:block_shape[0],
                                               0:block_shape[1],
                                               0:block_shape[2]] += block_array
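                                # note: building a fresh placeholder and a
                                # reused ae_u tower here, inside the epoch
                                # loop, appends new nodes to the graph on
                                # every full test pass; constructing the
                                # leftover-sized tower once, outside the
                                # loop, would keep the graph size fixed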
                                X_temp = tf.placeholder(shape=[
                                    temp_batch_size, input_shape[0],
                                    input_shape[1], input_shape[2]
                                ],
                                                        dtype=tf.float32)
                                with tf.variable_scope('ae', reuse=True):
                                    Y_pred_temp, Y_pred_modi_temp, Y_pred_nosig_temp = self.ae_u(
                                        X_temp, training, temp_batch_size,
                                        threshold)
                                Y_temp_pred, Y_temp_modi, Y_temp_pred_nosig = sess.run(
                                    [
                                        Y_pred_temp, Y_pred_modi_temp,
                                        Y_pred_nosig_temp
                                    ],
                                    feed_dict={
                                        X_temp: temp_input,
                                        training: False,
                                        w: weight_for,
                                        threshold: upper_threshold
                                    })
                                for j in range(temp_batch_size):
                                    test_data.upload_result(
                                        batch_numbers[j],
                                        Y_temp_modi[j, :, :, :])
                        test_result_array = test_data.get_result()
                        print "result shape: ", np.shape(test_result_array)
                        r_s = np.shape(test_result_array)  # result shape
                        e_t = 10  # edge thickness
                        to_be_transformed = np.zeros(r_s, np.float32)
                        to_be_transformed[
                            e_t:r_s[0] - e_t, e_t:r_s[1] - e_t, e_t:r_s[2] -
                            e_t] += test_result_array[e_t:r_s[0] - e_t,
                                                      e_t:r_s[1] - e_t,
                                                      e_t:r_s[2] - e_t]
                        print np.max(to_be_transformed)
                        print np.min(to_be_transformed)
                        final_img = ST.GetImageFromArray(
                            np.transpose(to_be_transformed, [2, 1, 0]))
                        final_img.SetSpacing(test_data.space)
                        print "writing full testing result"
                        print '/usr/analyse_airway/test_result/test_result' + str(
                            epoch) + '.vtk'
                        ST.WriteImage(
                            final_img,
                            '/usr/analyse_airway/test_result/test_result' +
                            str(epoch) + '.vtk')
                        if epoch == 0:
                            mask_img = ST.GetImageFromArray(
                                np.transpose(array_mask, [2, 1, 0]))
                            mask_img.SetSpacing(test_data.space)
                            ST.WriteImage(
                                mask_img,
                                '/usr/analyse_airway/test_result/test_mask.vtk'
                            )
                        # note: 2*|A.B| / (|A|+|B|) is the Dice coefficient,
                        # although the logs label it "IOU"
                        test_IOU = 2 * np.sum(
                            to_be_transformed * array_mask) / (
                                np.sum(to_be_transformed) + np.sum(array_mask))
                        print "IOU accuracy: ", test_IOU
                        time_end = time.time()
                        print '******************** time of full testing: ' + str(
                            time_end - time_begin) + 's ********************'
                    data.shuffle_X_Y_pairs()
                    total_train_batch_num = data.total_train_batch_num
                    # train_files=data.X_train_files
                    # test_files=data.X_test_files
                    # total_train_batch_num = 500
                    print "total_train_batch_num:", total_train_batch_num
                    for i in range(total_train_batch_num):

                        #### training
                        X_train_batch, Y_train_batch = data.load_X_Y_voxel_train_next_batch(
                        )
                        # X_train_batch, Y_train_batch = data.load_X_Y_voxel_grids_train_next_batch()
                        # Y_train_batch=np.reshape(Y_train_batch,[batch_size, output_shape[0], output_shape[1], output_shape[2], 1])
                        gan_d_loss_c, = sess.run(
                            [gan_d_loss],
                            feed_dict={
                                X: X_train_batch,
                                Y: Y_train_batch,
                                training: False,
                                w: weight_for,
                                threshold: upper_threshold
                            })
                        ae_loss_c, gan_g_loss_c, sum_train = sess.run(
                            [ae_loss, gan_g_loss, sum_merged],
                            feed_dict={
                                X: X_train_batch,
                                Y: Y_train_batch,
                                training: False,
                                w: weight_for,
                                threshold: upper_threshold
                            })
                        if epoch % 4 == 0 and epoch > 0 and i == 0:
                            learning_rate_g = learning_rate_g * power
                        sess.run(
                            [ae_g_optim],
                            feed_dict={
                                X: X_train_batch,
                                threshold: upper_threshold,
                                Y: Y_train_batch,
                                lr: learning_rate_g,
                                training: True,
                                w: weight_for
                            })
                        # the original if/elif/else over epoch <= 5 / <= 20 /
                        # else ran the exact same discriminator step in every
                        # branch, so it collapses to a single call
                        sess.run(
                            [dis_optim],
                            feed_dict={
                                X: X_train_batch,
                                threshold: upper_threshold,
                                Y: Y_train_batch,
                                lr: learning_rate_g,
                                training: True,
                                w: weight_for
                            })

                        sum_writer_train.add_summary(
                            sum_train, epoch * total_train_batch_num + i)
                        if i % 2 == 0:
                            print "epoch:", epoch, " i:", i, " train ae loss:", ae_loss_c, " gan g loss:", gan_g_loss_c, " gan d loss:", gan_d_loss_c, " learning rate: ", learning_rate_g
                        #### testing
                        if i % 20 == 0:  # "epoch % 1 == 0" was always true
                            try:
                                X_test_batch, Y_test_batch = data.load_X_Y_voxel_test_next_batch(
                                    fix_sample=False)
                                # Y_test_batch = np.reshape(Y_test_batch,[batch_size, output_shape[0], output_shape[1], output_shape[2], 1])
                                ae_loss_t, gan_g_loss_t, gan_d_loss_t, Y_test_pred, Y_test_modi, Y_test_pred_nosig = sess.run(
                                    [ae_loss, gan_g_loss, gan_d_loss, Y_pred, Y_pred_modi, Y_pred_nosig],
                                    feed_dict={X: X_test_batch,
                                               threshold: upper_threshold,
                                               Y: Y_test_batch,
                                               training: False,
                                               w: weight_for})
                                predict_result = np.float32(Y_test_modi > 0.01)
                                predict_result = np.reshape(
                                    predict_result, [
                                        batch_size, input_shape[0],
                                        input_shape[1], input_shape[2]
                                    ])
                                # Foreground
                                # if np.sum(Y_test_batch)>0:
                                #     accuracy_for = np.sum(predict_result*Y_test_batch)/np.sum(Y_test_batch)
                                # Background
                                # accuracy_bac = np.sum((1-predict_result)*(1-Y_test_batch))/(np.sum(1-Y_test_batch))
                                # Dice score (logged as "IOU"); the
                                # "probability" below is just the same 0.01
                                # threshold written differently
                                predict_probability = np.float32(
                                    (Y_test_modi - 0.01) > 0)
                                predict_probability = np.reshape(
                                    predict_probability, [
                                        batch_size, input_shape[0],
                                        input_shape[1], input_shape[2]
                                    ])
                                accuracy = 2 * np.sum(
                                    np.abs(predict_probability *
                                           Y_test_batch)) / np.sum(
                                               np.abs(predict_result) +
                                               np.abs(Y_test_batch))
                                # if epoch%30==0 and epoch>0:
                                #     to_save = {'X_test': X_test_batch, 'Y_test_pred': Y_test_pred,'Y_test_true': Y_test_batch}
                                #     scipy.io.savemat(self.test_results_dir + 'X_Y_pred_' + str(epoch).zfill(2) + '_' + str(i).zfill(4) + '.mat', to_save, do_compression=True)
                                print "epoch:", epoch, " i:", "\nIOU accuracy: ", accuracy, "\ntest ae loss:", ae_loss_t, " gan g loss:", gan_g_loss_t, " gan d loss:", gan_d_loss_t
                                if accuracy > best_acc:
                                    saver.save(
                                        sess,
                                        save_path=self.train_models_dir +
                                        'model.cptk')
                                    print "epoch:", epoch, " i:", i, "best model saved!"
                                    best_acc = accuracy
                            except Exception as e:
                                print e
                        #### model saving
                        if i % 30 == 0:  # "epoch % 1 == 0" was always true
                            # regular_train_dir = "./regular/"
                            # if not os.path.exists(regular_train_dir):
                            #     os.makedirs(regular_train_dir)
                            saver.save(sess,
                                       save_path=self.train_models_dir +
                                       'model.cptk')
                            print "epoch:", epoch, " i:", i, "regular model saved!"
                else:
                    print "bad data , next epoch", epoch
Example #4
    def train(self, configure):
        # data
        data = tools.Data(configure, epoch_walked / re_example_epoch)
        # network
        X = tf.placeholder(
            shape=[batch_size, input_shape[0], input_shape[1], input_shape[2]],
            dtype=tf.float32)
        Y = tf.placeholder(shape=[
            batch_size, output_shape[0], output_shape[1], output_shape[2]
        ],
                           dtype=tf.float32)
        lr = tf.placeholder(tf.float32)
        training = tf.placeholder(tf.bool)
        threshold = tf.placeholder(tf.float32)
        with tf.variable_scope('generator'):
            Y_pred, Y_pred_modi, Y_pred_nosig = self.ae_u(
                X, training, batch_size, threshold)
        with tf.variable_scope('discriminator'):
            XY_real_pair = self.dis(X, Y, training)
        with tf.variable_scope('discriminator', reuse=True):
            XY_fake_pair = self.dis(X, Y_pred, training)

        # loss function
        # generator loss
        Y_ = tf.reshape(Y, shape=[batch_size, -1])
        Y_pred_modi_ = tf.reshape(Y_pred_modi, shape=[batch_size, -1])
        w = tf.placeholder(tf.float32)  # foreground weight
        g_loss = tf.reduce_mean(-tf.reduce_mean(
            w * Y_ * tf.log(Y_pred_modi_ + 1e-8), reduction_indices=[1]) -
                                tf.reduce_mean((1 - w) * (1 - Y_) *
                                               tf.log(1 - Y_pred_modi_ + 1e-8),
                                               reduction_indices=[1]))
        g_loss_sum = tf.summary.scalar("generator cross entropy", g_loss)
        # discriminator loss
        gan_d_loss = tf.reduce_mean(XY_fake_pair) - tf.reduce_mean(
            XY_real_pair)
        alpha = tf.random_uniform(shape=[
            batch_size, input_shape[0] * input_shape[1] * input_shape[2]
        ],
                                  minval=0.0,
                                  maxval=1.0)
        Y_pred_ = tf.reshape(Y_pred, shape=[batch_size, -1])
        differences_ = Y_pred_ - Y_
        interpolates = Y_ + alpha * differences_
        with tf.variable_scope('discriminator', reuse=True):
            XY_fake_intep = self.dis(X, interpolates, training)
        gradients = tf.gradients(XY_fake_intep, [interpolates])[0]
        slopes = tf.sqrt(
            tf.reduce_sum(tf.square(gradients), reduction_indices=[1]))
        gradient_penalty = tf.reduce_mean((slopes - 1.0)**2)
        gan_d_loss += 10 * gradient_penalty
        gan_d_loss_sum = tf.summary.scalar("total loss of discriminator",
                                           gan_d_loss)

        # generator loss with gan loss
        gan_g_loss = -tf.reduce_mean(XY_fake_pair)
        gan_g_w = 5
        ae_w = 100 - gan_g_w
        ae_gan_g_loss = ae_w * g_loss + gan_g_w * gan_g_loss
        ae_g_loss_sum = tf.summary.scalar("total loss of generator",
                                          ae_gan_g_loss)

        # trainers
        ae_var = [
            var for var in tf.trainable_variables()
            if var.name.startswith('generator')
        ]
        dis_var = [
            var for var in tf.trainable_variables()
            if var.name.startswith('discriminator')
        ]
        ae_g_optim = tf.train.AdamOptimizer(learning_rate=lr,
                                            beta1=0.9,
                                            beta2=0.999,
                                            epsilon=1e-8).minimize(
                                                ae_gan_g_loss, var_list=ae_var)
        dis_optim = tf.train.AdamOptimizer(learning_rate=lr,
                                           beta1=0.9,
                                           beta2=0.999,
                                           epsilon=1e-8).minimize(
                                               gan_d_loss, var_list=dis_var)

        # accuracy
        block_acc = tf.placeholder(tf.float32)
        total_acc = tf.placeholder(tf.float32)
        train_sum = tf.summary.scalar("train_block_accuracy", block_acc)
        test_sum = tf.summary.scalar("total_test_accuracy", total_acc)
        train_merge_op = tf.summary.merge(
            [train_sum, ae_g_loss_sum, gan_d_loss_sum, g_loss_sum])
        test_merge_op = tf.summary.merge([test_sum])

        saver = tf.train.Saver(max_to_keep=1)
        # config = tf.ConfigProto(allow_soft_placement=True)
        # config.gpu_options.visible_device_list = GPU0

        with tf.Session() as sess:
            # define tensorboard writer
            sum_writer_train = tf.summary.FileWriter(self.train_sum_dir,
                                                     sess.graph)
            sum_write_test = tf.summary.FileWriter(self.test_sum_dir,
                                                   sess.graph)
            # load model data if pre-trained
            sess.run(
                tf.group(tf.global_variables_initializer(),
                         tf.local_variables_initializer()))
            if os.path.isfile(self.train_models_dir +
                              'model.cptk.data-00000-of-00001'):
                print "restoring saved model"
                saver.restore(sess, self.train_models_dir + 'model.cptk')
            learning_rate_g = ori_lr * pow(power, (epoch_walked / decay_step))
            # start training loop
            global_step = step_walked
            for epoch in range(epoch_walked, MAX_EPOCH):
                if epoch % re_example_epoch == 0 and epoch > 0:
                    del data
                    gc.collect()
                    data = tools.Data(configure, epoch / re_example_epoch)
                train_amount = len(data.train_numbers)
                test_amount = len(data.test_numbers)
                if (train_amount >= test_amount and train_amount > 0
                        and test_amount > 0 and data.total_train_batch_num > 0
                        and data.total_test_seq_batch > 0):
                    # actual foreground weight
                    weight_for = 0.5 + (1 - 1.0 * epoch / MAX_EPOCH) * 0.35
                    if epoch % total_test_epoch == 0:
                        self.full_testing(sess, X, w, threshold, test_merge_op,
                                          sum_write_test, training, weight_for,
                                          total_acc, Y_pred, Y_pred_modi,
                                          Y_pred_nosig, epoch)
                        # (the commented-out full-testing body previously
                        # inlined here was folded into self.full_testing;
                        # Example #3 shows the equivalent inline version)
                    data.shuffle_X_Y_pairs()
                    total_train_batch_num = data.total_train_batch_num
                    print "total_train_batch_num:", total_train_batch_num
                    for i in range(total_train_batch_num):
                        X_train_batch, Y_train_batch = data.load_X_Y_voxel_train_next_batch(
                        )
                        # calculate loss value
                        # print "calculate begin"
                        gan_d_loss_c, = sess.run(
                            [gan_d_loss],
                            feed_dict={
                                X: X_train_batch,
                                Y: Y_train_batch,
                                training: False,
                                w: weight_for,
                                threshold: upper_threshold
                            })
                        g_loss_c, gan_g_loss_c = sess.run(
                            [g_loss, ae_gan_g_loss],
                            feed_dict={
                                X: X_train_batch,
                                Y: Y_train_batch,
                                training: False,
                                w: weight_for,
                                threshold: upper_threshold
                            })
                        # print "calculate ended"
                        if epoch % decay_step == 0 and epoch > epoch_walked and i == 0:
                            learning_rate_g = learning_rate_g * power
                        sess.run(
                            [ae_g_optim],
                            feed_dict={
                                X: X_train_batch,
                                threshold: upper_threshold,
                                Y: Y_train_batch,
                                lr: learning_rate_g,
                                training: True,
                                w: weight_for
                            })
                        sess.run(
                            [dis_optim],
                            feed_dict={
                                X: X_train_batch,
                                threshold: upper_threshold,
                                Y: Y_train_batch,
                                lr: learning_rate_g,
                                training: True,
                                w: weight_for
                            })
                        # print "training ended"
                        global_step += 1
                        # output some results
                        if i % show_step == 0:
                            print "epoch:", epoch, " i:", i, " train ae loss:", g_loss_c, " gan g loss:", gan_g_loss_c, " gan d loss:", gan_d_loss_c, " learning rate: ", learning_rate_g
                        if i % block_test_step == 0:  # "epoch % 1 == 0" was always true
                            try:
                                X_test_batch, Y_test_batch = data.load_X_Y_voxel_test_next_batch(
                                    fix_sample=False)
                                g_loss_t, gan_g_loss_t, gan_d_loss_t, Y_test_pred, Y_test_modi, Y_test_pred_nosig = \
                                    sess.run([g_loss, ae_gan_g_loss, gan_d_loss, Y_pred, Y_pred_modi, Y_pred_nosig],
                                             feed_dict={X: X_test_batch,
                                                        threshold: upper_threshold + test_extra_threshold,
                                                        Y: Y_test_batch, training: False, w: weight_for})
                                predict_result = np.float32(Y_test_modi > 0.01)
                                predict_result = np.reshape(
                                    predict_result, [
                                        batch_size, input_shape[0],
                                        input_shape[1], input_shape[2]
                                    ])
                                print np.max(Y_test_pred)
                                print np.min(Y_test_pred)
                                # Dice score (logged as "IOU")
                                predict_probability = np.float32(
                                    (Y_test_modi - 0.01) > 0)
                                predict_probability = np.reshape(
                                    predict_probability, [
                                        batch_size, input_shape[0],
                                        input_shape[1], input_shape[2]
                                    ])
                                accuracy = 2 * np.sum(
                                    np.abs(predict_probability *
                                           Y_test_batch)) / np.sum(
                                               np.abs(predict_result) +
                                               np.abs(Y_test_batch))
                                print "epoch:", epoch, " global step: ", global_step, "\nIOU accuracy: ", accuracy, "\ntest ae loss:", g_loss_t, " gan g loss:", gan_g_loss_t, " gan d loss:", gan_d_loss_t
                                print "weight of foreground : ", weight_for
                                print "upper threshold of testing", (
                                    upper_threshold + test_extra_threshold)
                                train_summary = sess.run(
                                    train_merge_op,
                                    feed_dict={
                                        block_acc: accuracy,
                                        X: X_test_batch,
                                        threshold:
                                        upper_threshold + test_extra_threshold,
                                        Y: Y_test_batch,
                                        training: False,
                                        w: weight_for
                                    })
                                sum_writer_train.add_summary(
                                    train_summary, global_step=global_step)
                            except Exception as e:
                                print e
                        #### model saving
                        if i % model_save_step == 0:  # "epoch % 1 == 0" was always true
                            saver.save(sess,
                                       save_path=self.train_models_dir +
                                       'model.cptk')
                            print "epoch:", epoch, " i:", i, "regular model saved!"
                else:
                    print "bad data , next epoch", epoch
Example #5
    def train(self, configure):
        # data
        data = tools.Data(configure, epoch_walked / re_example_epoch)
        # network
        X = tf.placeholder(
            shape=[batch_size, input_shape[0], input_shape[1], input_shape[2]],
            dtype=tf.float32)
        Y = tf.placeholder(shape=[
            batch_size, output_shape[0], output_shape[1], output_shape[2]
        ],
                           dtype=tf.float32)
        lr = tf.placeholder(tf.float32)
        training = tf.placeholder(tf.bool)
        threshold = tf.placeholder(tf.float32)
        with tf.variable_scope('segmenter'):
            Y_pred, Y_pred_modi, Y_pred_nosig = self.ae_u(
                X, training, batch_size, threshold)

        # loss function
        Y_ = tf.reshape(Y, shape=[batch_size, -1])
        Y_pred_modi_ = tf.reshape(Y_pred_modi, shape=[batch_size, -1])
        w = tf.placeholder(tf.float32)  # foreground weight
        cross_loss = tf.reduce_mean(
            -tf.reduce_mean(w * Y_ * tf.log(Y_pred_modi_ + 1e-8),
                            reduction_indices=[1]) -
            tf.reduce_mean((1 - w) *
                           (1 - Y_) * tf.log(1 - Y_pred_modi_ + 1e-8),
                           reduction_indices=[1]))
        loss_sum = tf.summary.scalar("cross entropy", cross_loss)

        # trainers
        optim = tf.train.AdamOptimizer(learning_rate=lr,
                                       beta1=0.9,
                                       beta2=0.999,
                                       epsilon=1e-8).minimize(cross_loss)

        # accuracy
        block_acc = tf.placeholder(tf.float32)
        total_acc = tf.placeholder(tf.float32)
        train_sum = tf.summary.scalar("train_block_accuracy", block_acc)
        test_sum = tf.summary.scalar("total_test_accuracy", total_acc)
        train_merge_op = tf.summary.merge([train_sum, loss_sum])
        test_merge_op = tf.summary.merge([test_sum])

        saver = tf.train.Saver(max_to_keep=1)
        # config = tf.ConfigProto(allow_soft_placement=True)
        # config.gpu_options.visible_device_list = GPU0

        with tf.Session() as sess:
            # define tensorboard writer
            sum_writer_train = tf.summary.FileWriter(self.train_sum_dir,
                                                     sess.graph)
            sum_write_test = tf.summary.FileWriter(self.test_sum_dir,
                                                   sess.graph)
            # load model data if pre-trained
            sess.run(
                tf.group(tf.global_variables_initializer(),
                         tf.local_variables_initializer()))
            if os.path.isfile(self.train_models_dir +
                              'model.cptk.data-00000-of-00001'):
                print "restoring saved model"
                saver.restore(sess, self.train_models_dir + 'model.cptk')
            learning_rate_g = ori_lr * pow(power, (epoch_walked / decay_step))
            # start training loop
            global_step = step_walked
            for epoch in range(epoch_walked, MAX_EPOCH):
                if epoch % re_example_epoch == 0 and epoch > 0:
                    del data
                    gc.collect()
                    data = tools.Data(configure, epoch / re_example_epoch)
                train_amount = len(data.train_numbers)
                test_amount = len(data.test_numbers)
                if (train_amount >= test_amount and train_amount > 0
                        and test_amount > 0 and data.total_train_batch_num > 0
                        and data.total_test_seq_batch > 0):
                    # actual foreground weight
                    weight_for = 0.5 + (1 - 1.0 * epoch / MAX_EPOCH) * 0.35
                    if epoch % total_test_epoch == 0 and epoch > 0:
                        print '********************** FULL TESTING ********************************'
                        time_begin = time.time()
                        origin_dir = read_dicoms(test_dir + "original1")
                        mask_dir = test_dir + "artery"
                        test_batch_size = batch_size
                        # test_data = tools.Test_data(dicom_dir,input_shape)
                        test_data = tools.Test_data(origin_dir, input_shape,
                                                    'vtk_data')
                        test_data.organize_blocks()
                        test_mask = read_dicoms(mask_dir)
                        array_mask = ST.GetArrayFromImage(test_mask)
                        array_mask = np.transpose(array_mask, (2, 1, 0))
                        print "mask shape: ", np.shape(array_mask)
                        block_numbers = test_data.blocks.keys()
                        for i in range(0, len(block_numbers), test_batch_size):
                            batch_numbers = []
                            if i + test_batch_size < len(block_numbers):
                                temp_input = np.zeros([
                                    test_batch_size, input_shape[0],
                                    input_shape[1], input_shape[2]
                                ])
                                for j in range(test_batch_size):
                                    temp_num = block_numbers[i + j]
                                    temp_block = test_data.blocks[temp_num]
                                    batch_numbers.append(temp_num)
                                    block_array = temp_block.load_data()
                                    block_shape = np.shape(block_array)
                                    temp_input[j, 0:block_shape[0],
                                               0:block_shape[1],
                                               0:block_shape[2]] += block_array
                                Y_temp_pred, Y_temp_modi, Y_temp_pred_nosig = sess.run(
                                    [Y_pred, Y_pred_modi, Y_pred_nosig],
                                    feed_dict={
                                        X: temp_input,
                                        training: False,
                                        w: weight_for,
                                        threshold:
                                        upper_threshold + test_extra_threshold
                                    })
                                for j in range(test_batch_size):
                                    test_data.upload_result(
                                        batch_numbers[j],
                                        Y_temp_modi[j, :, :, :])
                            else:
                                temp_batch_size = len(block_numbers) - i
                                temp_input = np.zeros([
                                    temp_batch_size, input_shape[0],
                                    input_shape[1], input_shape[2]
                                ])
                                for j in range(temp_batch_size):
                                    temp_num = block_numbers[i + j]
                                    temp_block = test_data.blocks[temp_num]
                                    batch_numbers.append(temp_num)
                                    block_array = temp_block.load_data()
                                    block_shape = np.shape(block_array)
                                    temp_input[j, 0:block_shape[0],
                                               0:block_shape[1],
                                               0:block_shape[2]] += block_array
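                                # note: as in Example #3, this leftover-batch
                                # tower is rebuilt inside the epoch loop and
                                # grows the graph on every full test pass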
                                X_temp = tf.placeholder(shape=[
                                    temp_batch_size, input_shape[0],
                                    input_shape[1], input_shape[2]
                                ],
                                                        dtype=tf.float32)
                                with tf.variable_scope('segmenter',
                                                       reuse=True):
                                    Y_pred_temp, Y_pred_modi_temp, Y_pred_nosig_temp = self.ae_u(
                                        X_temp, training, temp_batch_size,
                                        threshold)
                                Y_temp_pred, Y_temp_modi, Y_temp_pred_nosig = sess.run(
                                    [
                                        Y_pred_temp, Y_pred_modi_temp,
                                        Y_pred_nosig_temp
                                    ],
                                    feed_dict={
                                        X_temp: temp_input,
                                        training: False,
                                        w: weight_for,
                                        threshold:
                                        upper_threshold + test_extra_threshold
                                    })
                                for j in range(temp_batch_size):
                                    test_data.upload_result(
                                        batch_numbers[j],
                                        Y_temp_modi[j, :, :, :])
                        test_result_array = test_data.get_result()
                        print "result shape: ", np.shape(test_result_array)
                        to_be_transformed = self.post_process(
                            test_result_array)
                        if epoch % output_epoch == 0:
                            self.output_img(to_be_transformed, test_data.space,
                                            epoch)
                        if epoch == 0:
                            mask_img = ST.GetImageFromArray(
                                np.transpose(array_mask, [2, 1, 0]))
                            mask_img.SetSpacing(test_data.space)
                            ST.WriteImage(mask_img,
                                          './test_result/test_mask.vtk')
                        # note: the Dice coefficient again, logged as "IOU"
                        test_IOU = 2 * np.sum(
                            to_be_transformed * array_mask) / (
                                np.sum(to_be_transformed) + np.sum(array_mask))
                        test_summary = sess.run(
                            test_merge_op, feed_dict={total_acc: test_IOU})
                        sum_write_test.add_summary(test_summary,
                                                   global_step=epoch)
                        print "IOU accuracy: ", test_IOU
                        time_end = time.time()
                        print '******************** time of full testing: ' + str(
                            time_end - time_begin) + 's ********************'
                    data.shuffle_X_Y_pairs()
                    total_train_batch_num = data.total_train_batch_num
                    print "total_train_batch_num:", total_train_batch_num
                    for i in range(total_train_batch_num):
                        X_train_batch, Y_train_batch = data.load_X_Y_voxel_train_next_batch(
                        )
                        # calculate loss value
                        # print "calculate begin"
                        loss_c, = sess.run(
                            [cross_loss],
                            feed_dict={
                                X: X_train_batch,
                                Y: Y_train_batch,
                                training: False,
                                w: weight_for,
                                threshold: upper_threshold
                            })
                        # print "calculate ended"
                        if epoch % decay_step == 0 and epoch > epoch_walked and i == 0:
                            learning_rate_g = learning_rate_g * power
                        sess.run(
                            [optim],
                            feed_dict={
                                X: X_train_batch,
                                threshold: upper_threshold,
                                Y: Y_train_batch,
                                lr: learning_rate_g,
                                training: True,
                                w: weight_for
                            })
                        # print "training ended"
                        global_step += 1
                        # output some results
                        if i % show_step == 0:
                            print "epoch:", epoch, " i:", i, " train loss:", loss_c, " gan g loss:", learning_rate_g
                        if i % block_test_step == 0:  # "epoch % 1 == 0" was always true
                            try:
                                X_test_batch, Y_test_batch = data.load_X_Y_voxel_test_next_batch(
                                    fix_sample=False)
                                Y_test_pred, Y_test_modi, Y_test_pred_nosig, loss_t = sess.run(
                                    [Y_pred, Y_pred_modi, Y_pred_nosig, cross_loss],
                                    feed_dict={X: X_test_batch,
                                               threshold: upper_threshold + test_extra_threshold,
                                               Y: Y_test_batch, training: False, w: weight_for})
                                predict_result = np.float32(Y_test_modi > 0.01)
                                predict_result = np.reshape(
                                    predict_result, [
                                        batch_size, input_shape[0],
                                        input_shape[1], input_shape[2]
                                    ])
                                print "Y_test_pred range:", np.min(Y_test_pred), "to", np.max(Y_test_pred)
                                # Dice-style overlap score (logged below as "IOU")
                                predict_probability = np.float32(
                                    (Y_test_modi - 0.01) > 0)
                                predict_probability = np.reshape(
                                    predict_probability, [
                                        batch_size, input_shape[0],
                                        input_shape[1], input_shape[2]
                                    ])
                                accuracy = 2 * np.sum(
                                    np.abs(predict_probability *
                                           Y_test_batch)) / np.sum(
                                               np.abs(predict_result) +
                                               np.abs(Y_test_batch))
                                print "epoch:", epoch, " global step: ", global_step, "\nIOU accuracy: ", accuracy, "\ntest ae loss:", loss_t
                                print "weight of foreground : ", weight_for
                                print "upper threshold of testing", (
                                    upper_threshold + test_extra_threshold)
                                train_summary = sess.run(
                                    train_merge_op,
                                    feed_dict={
                                        block_acc: accuracy,
                                        X: X_test_batch,
                                        threshold:
                                        upper_threshold + test_extra_threshold,
                                        Y: Y_test_batch,
                                        training: False,
                                        w: weight_for
                                    })
                                sum_writer_train.add_summary(
                                    train_summary, global_step=global_step)
                            except Exception as e:
                                print e
                        #### model saving
                        if i % model_save_step == 0:
                            saver.save(sess,
                                       save_path=self.train_models_dir +
                                       'model.cptk')
                            print "epoch:", epoch, " i:", i, "regular model saved!"
                else:
                    print "bad data , next epoch", epoch
Example No. 6
    def train(self, configure):
        # data
        data = tools.Data(configure, epoch_walked / re_example_epoch)
        # network
        X = tf.placeholder(
            shape=[batch_size, input_shape[0], input_shape[1], input_shape[2]],
            dtype=tf.float32)
        Y = tf.placeholder(shape=[
            batch_size, output_shape[0], output_shape[1], output_shape[2]
        ],
                           dtype=tf.float32)
        lr = tf.placeholder(tf.float32)
        training = tf.placeholder(tf.bool)
        threshold = tf.placeholder(tf.float32)
        with tf.variable_scope('generator'):
            Y_pred, Y_pred_modi, Y_pred_nosig = self.ae_u(
                X, training, batch_size, threshold)
        with tf.variable_scope('discriminator'):
            XY_real_pair = self.dis(X, Y, training)
        with tf.variable_scope('discriminator', reuse=True):
            XY_fake_pair = self.dis(X, Y_pred, training)

        # loss function
        # generator loss
        Y_ = tf.reshape(Y, shape=[batch_size, -1])
        Y_pred_modi_ = tf.reshape(Y_pred_modi, shape=[batch_size, -1])
        w = tf.placeholder(tf.float32)  # foreground weight
        g_loss = tf.reduce_mean(-tf.reduce_mean(
            w * Y_ * tf.log(Y_pred_modi_ + 1e-8), axis=[1]) -
                                tf.reduce_mean((1 - w) * (1 - Y_) *
                                               tf.log(1 - Y_pred_modi_ + 1e-8),
                                               axis=[1]))
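        # weighted binary cross-entropy: w weights foreground voxels and
        # (1 - w) background, averaged per sample and then over the batch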
        g_loss_sum = tf.summary.scalar("generator cross entropy", g_loss)
        # discriminator loss
        gan_d_loss = tf.reduce_mean(XY_fake_pair) - tf.reduce_mean(
            XY_real_pair)
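        # WGAN-GP gradient penalty (Gulrajani et al., 2017): drive the critic's
        # gradient norm toward 1 on random interpolates; note that alpha is
        # drawn per voxel here rather than once per sample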
        alpha = tf.random_uniform(shape=[
            batch_size, input_shape[0] * input_shape[1] * input_shape[2]
        ],
                                  minval=0.0,
                                  maxval=1.0)
        Y_pred_ = tf.reshape(Y_pred, shape=[batch_size, -1])
        differences_ = Y_pred_ - Y_
        interpolates = Y_ + alpha * differences_
        with tf.variable_scope('discriminator', reuse=True):
            XY_fake_intep = self.dis(X, interpolates, training)
        gradients = tf.gradients(XY_fake_intep, [interpolates])[0]
        slopes = tf.sqrt(
            tf.reduce_sum(tf.square(gradients), axis=[1]))
        gradient_penalty = tf.reduce_mean((slopes - 1.0)**2)
        gan_d_loss += 10 * gradient_penalty
        gan_d_loss_sum = tf.summary.scalar("total loss of discriminator",
                                           gan_d_loss)

        # generator loss with gan loss
        gan_g_loss = -tf.reduce_mean(XY_fake_pair)
        gan_g_w = 5
        ae_w = 100 - gan_g_w
        ae_gan_g_loss = ae_w * g_loss + gan_g_w * gan_g_loss
        ae_g_loss_sum = tf.summary.scalar("total loss of generator",
                                          ae_gan_g_loss)

        # trainers
        ae_var = [
            var for var in tf.trainable_variables()
            if var.name.startswith('generator')
        ]
        dis_var = [
            var for var in tf.trainable_variables()
            if var.name.startswith('discriminator')
        ]
        ae_g_optim = tf.train.AdamOptimizer(learning_rate=lr,
                                            beta1=0.9,
                                            beta2=0.999,
                                            epsilon=1e-8).minimize(
                                                ae_gan_g_loss, var_list=ae_var)
        dis_optim = tf.train.AdamOptimizer(learning_rate=lr,
                                           beta1=0.9,
                                           beta2=0.999,
                                           epsilon=1e-8).minimize(
                                               gan_d_loss, var_list=dis_var)

        # accuracy
        block_acc = tf.placeholder(tf.float32)
        total_acc = tf.placeholder(tf.float32)
        train_sum = tf.summary.scalar("train_block_accuracy", block_acc)
        test_sum = tf.summary.scalar("total_test_accuracy", total_acc)
        train_merge_op = tf.summary.merge(
            [train_sum, ae_g_loss_sum, gan_d_loss_sum, g_loss_sum])
        test_merge_op = tf.summary.merge([test_sum])

        saver = tf.train.Saver(max_to_keep=1)
        # config = tf.ConfigProto(allow_soft_placement=True)
        # config.gpu_options.visible_device_list = GPU0

        with tf.Session() as sess:
            # define tensorboard writer
            sum_writer_train = tf.summary.FileWriter(self.train_sum_dir,
                                                     sess.graph)
            sum_write_test = tf.summary.FileWriter(self.test_sum_dir,
                                                   sess.graph)
            # load model data if pre-trained
            sess.run(
                tf.group(tf.global_variables_initializer(),
                         tf.local_variables_initializer()))
            if os.path.isfile(self.train_models_dir +
                              'model.cptk.data-00000-of-00001'):
                print "restoring saved model"
                saver.restore(sess, self.train_models_dir + 'model.cptk')
            learning_rate_g = ori_lr * pow(power, (epoch_walked / decay_step))
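            # stepwise exponential decay: lr starts at ori_lr * power^(epoch_walked / decay_step)
            # and is multiplied by `power` again every decay_step epochs in the loop below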
            # start training loop
            global_step = step_walked
            for epoch in range(epoch_walked, MAX_EPOCH):
                if epoch % re_example_epoch == 0 and epoch > 0:
                    del data
                    gc.collect()
                    data = tools.Data(configure, epoch / re_example_epoch)
                train_amount = len(data.train_numbers)
                test_amount = len(data.test_numbers)
                if train_amount >= test_amount and train_amount > 0 and test_amount > 0 and data.total_train_batch_num > 0 and data.total_test_seq_batch > 0:
                    # foreground weight, annealed linearly from 0.85 down to 0.5
                    weight_for = 0.5 + (1 - 1.0 * epoch / MAX_EPOCH) * 0.35
                    if epoch % total_test_epoch == 0:
                        self.full_testing(sess, X, w, threshold, test_merge_op,
                                          sum_write_test, training, weight_for,
                                          total_acc, Y_pred, Y_pred_modi,
                                          Y_pred_nosig, epoch)
                    data.shuffle_X_Y_pairs()
                    total_train_batch_num = data.total_train_batch_num
                    print "total_train_batch_num:", total_train_batch_num
                    for i in range(total_train_batch_num):
                        X_train_batch, Y_train_batch = data.load_X_Y_voxel_train_next_batch(
                        )
                        # calculate loss value
                        # print "calculate begin"
                        gan_d_loss_c, = sess.run(
                            [gan_d_loss],
                            feed_dict={
                                X: X_train_batch,
                                Y: Y_train_batch,
                                training: False,
                                w: weight_for,
                                threshold: upper_threshold
                            })
                        g_loss_c, gan_g_loss_c = sess.run(
                            [g_loss, ae_gan_g_loss],
                            feed_dict={
                                X: X_train_batch,
                                Y: Y_train_batch,
                                training: False,
                                w: weight_for,
                                threshold: upper_threshold
                            })
                        # print "calculate ended"
                        if epoch % decay_step == 0 and epoch > epoch_walked and i == 0:
                            learning_rate_g = learning_rate_g * power
                        sess.run(
                            [ae_g_optim],
                            feed_dict={
                                X: X_train_batch,
                                threshold: upper_threshold,
                                Y: Y_train_batch,
                                lr: learning_rate_g,
                                training: True,
                                w: weight_for
                            })
                        sess.run(
                            [dis_optim],
                            feed_dict={
                                X: X_train_batch,
                                threshold: upper_threshold,
                                Y: Y_train_batch,
                                lr: learning_rate_g,
                                training: True,
                                w: weight_for
                            })
                        # print "training ended"
                        global_step += 1
                        # output some results
                        if i % show_step == 0:
                            print "epoch:", epoch, " i:", i, " train ae loss:", g_loss_c, " gan g loss:", gan_g_loss_c, " gan d loss:", gan_d_loss_c, " learning rate: ", learning_rate_g
                        if i % block_test_step == 0:
                            try:
                                X_test_batch, Y_test_batch = data.load_X_Y_voxel_test_next_batch(
                                    fix_sample=False)
                                g_loss_t, gan_g_loss_t, gan_d_loss_t, Y_test_pred, Y_test_modi, Y_test_pred_nosig = \
                                    sess.run([g_loss, ae_gan_g_loss, gan_d_loss, Y_pred, Y_pred_modi, Y_pred_nosig],
                                             feed_dict={X: X_test_batch,
                                                        threshold: upper_threshold + test_extra_threshold,
                                                        Y: Y_test_batch, training: False, w: weight_for})
                                predict_result = np.float32(Y_test_modi > 0.01)
                                predict_result = np.reshape(
                                    predict_result, [
                                        batch_size, input_shape[0],
                                        input_shape[1], input_shape[2]
                                    ])
                                print "Y_test_pred range:", np.min(Y_test_pred), "to", np.max(Y_test_pred)
                                # Dice-style overlap score (logged below as "IOU")
                                predict_probability = np.float32(
                                    (Y_test_modi - 0.01) > 0)
                                predict_probability = np.reshape(
                                    predict_probability, [
                                        batch_size, input_shape[0],
                                        input_shape[1], input_shape[2]
                                    ])
                                accuracy = 2 * np.sum(
                                    np.abs(predict_probability *
                                           Y_test_batch)) / np.sum(
                                               np.abs(predict_result) +
                                               np.abs(Y_test_batch))
                                print "epoch:", epoch, " global step: ", global_step, "\nIOU accuracy: ", accuracy, "\ntest ae loss:", g_loss_t, " gan g loss:", gan_g_loss_t, " gan d loss:", gan_d_loss_t
                                print "weight of foreground : ", weight_for
                                print "upper threshold of testing", (
                                    upper_threshold + test_extra_threshold)
                                train_summary = sess.run(
                                    train_merge_op,
                                    feed_dict={
                                        block_acc: accuracy,
                                        X: X_test_batch,
                                        threshold:
                                        upper_threshold + test_extra_threshold,
                                        Y: Y_test_batch,
                                        training: False,
                                        w: weight_for
                                    })
                                sum_writer_train.add_summary(
                                    train_summary, global_step=global_step)
                            except Exception as e:
                                print e
                        #### model saving
                        if i % model_save_step == 0:
                            saver.save(sess,
                                       save_path=self.train_models_dir +
                                       'model.cptk')
                            print "epoch:", epoch, " i:", i, "regular model saved!"
                else:
                    print "bad data , next epoch", epoch
Example No. 7
import tools as t

if __name__ == '__main__':
    ds_v = t.Data('../data/sets/r1/', 'v.npy')
    ds_v.load_dataset()
    ds_v.prepare_dataset()
    # ds_v.prepare_dataset()

    # t.display_data([ds_v], 0, 5000)

    # ds_p = t.Data('../data/sets/r0/', 'p1.npy')
    # ds_t = t.Data('../data/sets/r0/', 't1.npy')
    #
    # ds_p.load_dataset()
    # ds_p.prepare_dataset(type='sfa-p')
    #
    # ds_t.load_dataset()
    # ds_t.prepare_dataset()

    cut_position = 0
    time_series_length = 500
    time_series_count = 1
    row_length = time_series_length * time_series_count

    sfa_p = t.SFAAlg(ds_v.pd)

    sfa_p.transform(cut_position, row_length, time_series_length,
                    time_series_count)

    # print(sfa_p.transformed.todense())
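
The transform() call above is driven by a cut position and window-shape parameters. A hypothetical reading of how such parameters slice a series into rows, purely for illustration (window_series and its semantics are assumptions, not t.SFAAlg's actual API):

import numpy as np

def window_series(series, cut_position, time_series_length, time_series_count):
    # take row_length samples starting at cut_position and reshape them
    # into time_series_count rows of time_series_length samples each
    row_length = time_series_length * time_series_count
    chunk = np.asarray(series)[cut_position:cut_position + row_length]
    return chunk.reshape(time_series_count, time_series_length)

windows = window_series(np.sin(np.linspace(0, 50, 1000)), 0, 500, 1)
print(windows.shape)  # (1, 500)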
Example No. 8
import tools as t
import classifier as c
import configuration as conf
import analitics as a
import numpy as np

if __name__ == '__main__':
    normal = t.Data(conf.path, conf.data_set_type)
    normal.add_dataset_from_file('r2/normal/', 'normal', 'csv')

    timeout = t.Data(conf.path, conf.data_set_type)
    timeout.add_dataset_from_file('r2/normal/', 'timeout', 'csv')

    ds_sizes = 5000
    ds_pos = 0
    ds_power = []
    ds_power = np.append(
        ds_power,
        t.prepare_dataset(normal.ods['Power'], ds_pos, ds_sizes, 'cut'))
    ds_power = np.append(
        ds_power,
        t.prepare_dataset(timeout.ods['Power'], ds_pos, ds_sizes, 'cut'))

    ds_temperature = []
    ds_temperature = np.append(
        ds_temperature,
        t.prepare_dataset(normal.ods['Temp'], ds_pos, ds_sizes, 'diff'))
    ds_temperature = np.append(
        ds_temperature,
        t.prepare_dataset(timeout.ods['Temp'], ds_pos, ds_sizes, 'diff'))
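
The 'cut' and 'diff' arguments suggest two preprocessing modes: a raw window of the signal versus its first differences. A sketch of that reading, with the caveat that prepare_dataset here is a hypothetical stand-in and t.prepare_dataset's real behavior may differ:

import numpy as np

def prepare_dataset(series, pos, size, mode):
    # slice a window of `size` samples starting at `pos`
    window = np.asarray(series, dtype=np.float64)[pos:pos + size]
    if mode == 'cut':
        return window
    if mode == 'diff':
        return np.diff(window)  # first differences, e.g. for temperature drift
    raise ValueError('unknown mode: %s' % mode)

print(prepare_dataset(np.arange(10.0), 0, 5, 'diff'))  # -> [1. 1. 1. 1.]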
Example No. 9
    def execute(self):
        """
        Make the plot!
        return the Figure object to the user (they can edit it if they please)
        """
        fig = plt.figure()
        gs = gridspec.GridSpec(2, 1, height_ratios=[3, 1], hspace=0.0)
        self.ax1 = fig.add_subplot(gs[0])
        self.ax2 = fig.add_subplot(gs[1], sharex=self.ax1)
        plt.setp(self.ax1.get_xticklabels(), visible=False)

        # organize data for plotting
        data2plot = None  # data points (only one data source is supported in a data/MC plot)
        bckg2plot = []  # backgrounds
        signal2plot = []  # signal distribution

        for e in self.data2plot:
            d2p = self.data2plot[e]
            sample_type = d2p.draw_type

            if sample_type == 'data': data2plot = d2p
            elif sample_type == 'background': bckg2plot.append(d2p)
            elif sample_type == 'signal': signal2plot.append(d2p)

        ##  Data points
        if data2plot is not None:
            if self.asimov:
                data2plot = self.make_asimov(bckg2plot)

            if self.blind_data is not None:
                data2plot = self.make_blind(data2plot)

            data2plot.draw_type = 'errorbar'
            data2plot.kwargs["zorder"] = 125
            tmp_data = self.plotErrorbar(data2plot)
            data2plot = tmp_data
        else:
            data2plot = PlotterData('data')
            data2plot.data = tools.Data()
            data2plot.data.content = np.array(
                [np.nan for _ in bckg2plot[0].data.content])

            if self.asimov:
                # Use the total bckg to plot 'asimov' data
                data2plot = self.make_asimov(bckg2plot)
                data2plot.draw_type = 'errorbar'
                data2plot.kwargs["zorder"] = 125
                tmp_data = self.plotErrorbar(data2plot)
                data2plot = tmp_data
        self.data2plot[data2plot.name] = data2plot  # update the dictionary

        ##  Background samples
        bottom = None  # 'bottom' for stacking histograms
        bckg_unc = None
        for n, hist2plot in enumerate(bckg2plot):
            hist2plot.draw_type = 'stepfilled'
            hist2plot.kwargs["zorder"] = 100 + n
            hist2plot.kwargs[
                "bottom"] = bottom  # stack the background contributions

            tmp_hist2plot = self.plotHistogram(
                hist2plot, uncertainty=hist2plot.uncertainty)
            bckg2plot[n] = tmp_hist2plot  # update data

            if bottom is None:
                bottom = hist2plot.plotData.copy()
            else:
                bottom += hist2plot.plotData.copy()  # grow the stack

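            # background errors are accumulated in quadrature; the square
            # root is taken once the loop has seen every contribution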
            if bckg_unc is None:
                bckg_unc = np.square(hist2plot.data.error.copy())
            else:
                bckg_unc += np.square(hist2plot.data.error.copy())

        # store the total background prediction and uncertainty
        prediction = deepcopy(bckg2plot[0])  # copy attributes from bckg2plot
        prediction.plotData = bottom.copy()  # copy data from sum of all bckgs
        prediction.data.content = bottom.copy()
        prediction.data.error = np.sqrt(bckg_unc)
        self.data2plot['total_bckg'] = prediction

        for i in bckg2plot:
            self.data2plot[i.name] = i  # update the dictionary

        ##  Signal distributions (designed for BSM, but could support a SM signal)
        for n, hist2plot in enumerate(signal2plot):
            hist2plot.draw_type = 'stepfilled' if self.stack_signal else 'step'
            hist2plot.kwargs["zorder"] = 150 + n
            hist2plot.kwargs["bottom"] = bottom if self.stack_signal else None

            tmp_hist2plot = self.plotHistogram(hist2plot,
                                               uncertainty=self.uncertainty)
            signal2plot[n] = tmp_hist2plot  # update data

            if bottom is None:
                bottom = hist2plot.plotData.copy()
            else:
                bottom += hist2plot.plotData.copy()  # stack signal on top of background
        for i in signal2plot:
            self.data2plot[i.name] = i  # update the dictionary

        ## ratio plot [data/mc (mc=total background)]
        self.ratio.Add(numerator=data2plot.name,
                       denominator=self.datamc_denominator)

        #  ratio plot uncertainty band (uncertainty on the prediction)
        #  there are no data points to draw, we just want the uncertainty (centered on 1.)
        self.drawPredictionUncertainty()

        self.plotRatio()

        ## Axis ticks/labels
        self.set_xaxis(self.ax2)
        self.set_yaxis()

        ## CMS label
        self.text_labels()

        ## Legend
        self.drawLegend()

        return fig
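
For reference, the stacking in execute() is the standard manual pattern: each background is drawn with an increasing zorder and a bottom equal to the running sum of everything below it. A self-contained sketch with made-up data and bins:

import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(0)
edges = np.linspace(0, 10, 21)
centers = 0.5 * (edges[:-1] + edges[1:])
backgrounds = [np.histogram(rng.normal(5, s, 1000), bins=edges)[0]
               for s in (1.0, 2.0)]

fig, ax = plt.subplots()
bottom = None
for n, counts in enumerate(backgrounds):
    # stack this contribution on top of the ones already drawn
    ax.bar(centers, counts, width=np.diff(edges), bottom=bottom,
           zorder=100 + n, label='bckg %d' % n)
    bottom = counts if bottom is None else bottom + counts
ax.legend()
fig.savefig('stacked.png')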