Example #1

def main(_):

    config = process_config(FLAGS.config_path)
    print(config)

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        deploy_config = deploy.DeploymentConfig(num_clones=1)

        global_step = tf.Variable(0, trainable=False, name='global_step')

        # select model and build net
        net = tdr2n2.Unet(config)

        # create batch dataset
        with tf.device(deploy_config.inputs_device()):
            data = DataGenerator(config.input)

            x_test, y_test = data.get_eval_data()
            x_test = tf.expand_dims(x_test, -1)
            x_test.set_shape([
                None, config.input.img_out_shape[0],
                config.input.img_out_shape[1], config.input.img_out_shape[2]
            ])
            y_test.set_shape([
                None, config.input.mask_out_shape[0],
                config.input.mask_out_shape[1]
            ])
            y_test = tf.cast(y_test, tf.int32)
            y_test_hot = tf.one_hot(y_test,
                                    depth=config.network.num_classes,
                                    axis=-1)

        f_score, end_points = net.net(x_test)
        f_score_img = tf.expand_dims(
            tf.cast(tf.argmax(f_score, axis=-1), tf.float32) * 50., -1)
        y_test_img = tf.expand_dims(
            tf.cast(tf.argmax(y_test_hot, axis=-1), tf.float32) * 50., -1)

        ## add precision and recall
        f_score = tf.cast(tf.argmax(f_score, -1), tf.int32)
        #f_score = tf.image.resize_bilinear(f_score, (config.input.img_out_shape[0]))
        f_score = tf.one_hot(f_score,
                             depth=config.network.num_classes,
                             axis=-1)
        pred = tf.reduce_sum(f_score * y_test_hot, axis=(0, 1, 2))
        all_pred = tf.reduce_sum(f_score, axis=(0, 1, 2)) + 1e-5
        all_true = tf.reduce_sum(y_test_hot, axis=(0, 1, 2)) + 1e-5
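        # Per-class counts: `pred` holds true positives, `all_pred` is
        # TP + FP (all pixels predicted as the class) and `all_true` is
        # TP + FN (all pixels labelled as the class); the 1e-5 epsilon
        # avoids division by zero for classes absent from the batch.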

        # Variables to restore: moving avg. or normal weights.
        if config.train.moving_average_decay:
            variable_averages = tf.train.ExponentialMovingAverage(
                config.train.moving_average_decay, global_step)
            variables_to_restore = variable_averages.variables_to_restore(
                slim.get_model_variables())
            variables_to_restore[global_step.op.name] = global_step
        else:
            variables_to_restore = slim.get_variables_to_restore()
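        # When EMA is enabled the shadow variables (suffixed
        # "ExponentialMovingAverage") are restored in place of the raw
        # weights, which is the mapping variables_to_restore() builds.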

        saver = None
        if variables_to_restore is not None:
            saver = tf_saver.Saver(variables_to_restore)

        # =================================================================== #
        # Evaluation loop.
        # =================================================================== #
        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=config.deploy.gpu_memory_fraction)
        configproto = tf.ConfigProto(
            gpu_options=gpu_options,
            log_device_placement=False,
            allow_soft_placement=True,
        )

        sum_writer = tf.summary.FileWriter(logdir=config.summary.test_dir)

        for checkpoint_path in evaluation.checkpoints_iterator(
                config.finetune.eval_checkpoint_dir):
            with tf.Session(config=configproto) as session:
                session.run(tf.global_variables_initializer())
                session.run(data.get_iterator(is_train=False).initializer)
                saver.restore(session, checkpoint_path)

                logging.info('Starting evaluation at ' +
                             time.strftime('%Y-%m-%d-%H:%M:%S', time.gmtime()))
                k = 1
                tp = []
                tp_fp = []
                tp_fn = []
                imgs = []
                while True:
                    try:
                        pred_, all_pred_, all_true_, pred_img, true_img, g_step = session.run(
                            [
                                pred, all_pred, all_true, f_score_img,
                                y_test_img, global_step
                            ])
                        tp.append(np.expand_dims(pred_, 0))
                        tp_fp.append(np.expand_dims(all_pred_, 0))
                        tp_fn.append(np.expand_dims(all_true_, 0))
                        #img = util.merge_pics(pred_img, true_img)

                        print("deal with {} images".format(
                            k * config.input.batch_size))
                        k += 1
                    except tf.errors.OutOfRangeError:
                        tp_ = np.sum(np.concatenate(tp, 0), 0)
                        tp_fn_ = np.sum(np.concatenate(tp_fn, 0), 0)
                        tp_fp_ = np.sum(np.concatenate(tp_fp, 0), 0)
                        precision = tp_ / tp_fp_
                        recall = tp_ / tp_fn_
                        dice = 2 * tp_ / (tp_fp_ + tp_fn_)

                        print(precision)
                        print(recall)
                        print(dice)
                        summary = tf.Summary()
                        for i in range(recall.shape[0]):
                            summary.value.add(
                                tag='evaluation/{}th_class_precision'.format(
                                    i),
                                simple_value=precision[i])
                            summary.value.add(
                                tag='evaluation/{}th_class_recall'.format(i),
                                simple_value=recall[i])
                            summary.value.add(
                                tag='evaluation/{}th_class_dice'.format(i),
                                simple_value=dice[i])
                        sum_writer.add_summary(summary, g_step)

                        break
                logging.info('Finished evaluation at ' +
                             time.strftime('%Y-%m-%d-%H:%M:%S', time.gmtime()))
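
# --------------------------------------------------------------------------- #
# A minimal NumPy sketch (hypothetical helper, not part of the original
# script) of the per-class precision/recall/Dice aggregation that the
# evaluation loop above performs with its tp / tp_fp / tp_fn accumulators.
# --------------------------------------------------------------------------- #
def _metrics_sketch(num_classes=3):
    import numpy as np

    # Random one-hot predictions and labels shaped [batch, H, W, num_classes].
    pred_hot = np.eye(num_classes)[np.random.randint(num_classes,
                                                     size=(2, 4, 4))]
    true_hot = np.eye(num_classes)[np.random.randint(num_classes,
                                                     size=(2, 4, 4))]
    tp = (pred_hot * true_hot).sum(axis=(0, 1, 2))  # per-class true positives
    tp_fp = pred_hot.sum(axis=(0, 1, 2)) + 1e-5     # predicted positives
    tp_fn = true_hot.sum(axis=(0, 1, 2)) + 1e-5     # ground-truth positives
    precision = tp / tp_fp
    recall = tp / tp_fn
    dice = 2 * tp / (tp_fp + tp_fn)
    return precision, recall, dice
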
Example #2

def main(_):
    # capture the config path from the run arguments,
    # then process the JSON configuration file
    config = process_config(FLAGS.config_path)
    print(config)

    tf.logging.set_verbosity(tf.logging.DEBUG)

    with tf.Graph().as_default():
        ######################
        # Config model_deploy#
        ######################
        deploy_config = deploy.DeploymentConfig(
            num_clones=config.deploy.num_clone)

        # Create global_step
        with tf.device(deploy_config.variables_device()):
            global_step = tf.Variable(0, trainable=False, name='global_step')

        # select model and build net
        net = u_net.Unet(config)

        # create batch dataset
        with tf.device(deploy_config.inputs_device()):
            data = DataGenerator(config.input)
            x_train, y_train = data.get_train_data()
            x_train = tf.expand_dims(x_train, -1)
            x_train.set_shape([
                None, config.input.img_out_shape[0],
                config.input.img_out_shape[1], config.input.img_out_shape[2]
            ])
            y_train.set_shape([
                None, config.input.mask_out_shape[0],
                config.input.mask_out_shape[1]
            ])
            y_train = tf.cast(y_train, tf.int32)
            y_train_hot = tf.one_hot(y_train,
                                     depth=config.network.num_classes,
                                     axis=-1)

            batch_queue = [x_train, y_train_hot]

        # =================================================================== #
        # Define the model running on every GPU.
        # =================================================================== #
        def clone_fn(batch_queue):
            x_train, y_train_hot = batch_queue
            print(x_train)
            print(y_train_hot)
            f_score, end_points = net.net(x_train)
            # Add loss function.
            net.loss(f_score, y_train_hot, type=config.network.loss_type)

            return f_score, end_points, x_train, y_train_hot

        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = deploy.create_clones(deploy_config, clone_fn, [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
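
        # create_clones has replicated clone_fn once per configured clone (one
        # GPU each); the update ops and extra losses below are collected from
        # the first clone's name scope only.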

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        for loss in tf.get_collection('EXTRA_LOSSES', first_clone_scope):
            summaries.add(tf.summary.scalar(loss.op.name, loss))

        f_score, _, x_train, y_train_hot = clones[0].outputs
        f_score_img = tf.expand_dims(
            tf.cast(tf.argmax(f_score, axis=-1), tf.float32), -1)
        y_train_img = tf.argmax(y_train_hot, axis=-1)
        summaries.add(tf.summary.image("Images/Original_image", x_train, 2))
        summaries.add(
            tf.summary.image(
                "Images/Ground_truth",
                tf.expand_dims(tf.cast(y_train_img, tf.float32), -1), 2))
        summaries.add(tf.summary.image("Images/Predict_", f_score_img, 2))

        ## add precision and recall
        f_score = tf.cast(tf.argmax(f_score, -1), tf.int32)
        f_score = tf.one_hot(f_score,
                             depth=config.network.num_classes,
                             axis=-1)
        pred = tf.reduce_sum(f_score * y_train_hot, axis=(0, 1, 2))
        all_pred = tf.reduce_sum(f_score, axis=(0, 1, 2)) + 1e-5
        all_true = tf.reduce_sum(y_train_hot, axis=(0, 1, 2)) + 1e-5
        precision = pred / all_pred  # TP / (TP + FP)
        recall = pred / all_true  # TP / (TP + FN)
        dice = pred * 2 / (all_true + all_pred)
        with tf.variable_scope('evaluation'):
            for i in range(config.network.num_classes):
                summaries.add(
                    tf.summary.scalar('{}th_class_precision'.format(i),
                                      precision[i]))
                summaries.add(
                    tf.summary.scalar('{}th_class_recall'.format(i),
                                      recall[i]))
                summaries.add(
                    tf.summary.scalar('{}th_class_dice'.format(i), dice[i]))

        #################################
        # Configure the moving averages #
        #################################
        if config.train.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                config.train.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = tf_utils.configure_learning_rate(
                config, global_step)
            optimizer = tf_utils.configure_optimizer(config.train,
                                                     learning_rate)

        if config.train.moving_average_decay:
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = tf_utils.get_variables_to_train(config.finetune)

        # optimize_clones sums the clone losses and computes one set of
        # gradients over variables_to_train.
        total_loss, clones_gradients = deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)

        if config.train.clip_gradient_norm > 0:
            with ops.name_scope('clip_grads'):
                clones_gradients = slim.learning.clip_gradient_norms(
                    clones_gradients, config.train.clip_gradient_norm)
        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        train_tensor = control_flow_ops.with_dependencies([update_op],
                                                          total_loss,
                                                          name='train_op')
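        # Evaluating train_tensor returns total_loss, but with_dependencies
        # forces every op in update_ops to run first: the gradient update, BN
        # statistics updates, and the EMA update when enabled.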

        # train_tensor = slim.learning.create_train_op(total_loss, optimizer, gradient_multipliers=gradient_multipliers)
        summaries.add(tf.summary.scalar('learning_rate', learning_rate))
        summaries.add(tf.summary.scalar('total_loss', total_loss))
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # =================================================================== #
        # Kicks off the training.
        # =================================================================== #
        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=config.deploy.gpu_memory_fraction)
        configproto = tf.ConfigProto(
            gpu_options=gpu_options,
            log_device_placement=False,
            allow_soft_placement=True,
        )

        saver = tf.train.Saver(max_to_keep=100)
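
        # The Scaffold threads the custom init_fn, summary_op and saver into
        # the MonitoredTrainingSession below; its local_init_op also runs the
        # dataset iterator initializer when the session is created.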

        scaffold = tf.train.Scaffold(
            init_op=None,
            init_feed_dict=None,
            init_fn=tf_utils.get_init_fn(config),
            ready_op=None,
            ready_for_local_init_op=None,
            local_init_op=[data.get_iterator().initializer],
            summary_op=summary_op,
            saver=saver,
            copy_from_scaffold=None)

        ckpt_hook = tf.train.CheckpointSaverHook(
            checkpoint_dir=config.summary.train_dir,
            save_secs=config.summary.save_checkpoint_secs,
            save_steps=config.summary.save_checkpoint_steps,
            saver=None,
            checkpoint_basename='model.ckpt',
            scaffold=scaffold,
            listeners=None)
        sum_writer = tf.summary.FileWriter(logdir=config.summary.train_dir)
        sum_hook = tf.train.SummarySaverHook(
            save_steps=None,
            save_secs=config.summary.save_summaries_secs,
            output_dir=config.summary.train_dir,
            summary_writer=sum_writer,
            scaffold=None,
            summary_op=summary_op,
        )

        with tf.train.MonitoredTrainingSession(
                master='',
                is_chief=True,
                checkpoint_dir=config.summary.train_dir,
                scaffold=scaffold,
                hooks=[ckpt_hook, sum_hook],
                save_checkpoint_secs=None,
                save_summaries_steps=None,
                save_summaries_secs=None,
                config=configproto,
                log_step_count_steps=config.summary.log_every_n_steps) as sess:
            while not sess.should_stop():
                _, loss, g_step = sess.run(
                    [train_tensor, total_loss, global_step])
                print("{} step loss is {}".format(g_step, loss))
Example #3
# Assumes `dataset_train` is a DataGenerator-style object exposing
# get_train_data()/get_iterator(), and that `global_step` already exists
# in the default graph.
import datetime

import tensorflow as tf

x_train, y_train = dataset_train.get_train_data()

sess = tf.train.MonitoredTrainingSession(
        master='',
        is_chief=True,
        checkpoint_dir=None,
        scaffold=None,
        hooks=None,
        chief_only_hooks=None,
        save_checkpoint_secs=600,
        save_summaries_steps=100,
        save_summaries_secs=None,
        config=None,
        stop_grace_period_secs=120,
        log_step_count_steps=100
)

step = 0
sess.run(dataset_train.get_iterator().initializer)
while not sess.should_stop():
    start_time = datetime.datetime.now()
    g_step = sess.run(global_step)
    elapsed = (datetime.datetime.now() - start_time).total_seconds()
    print('step {} took {:.3f}s'.format(g_step, elapsed))
    step += 1
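
# Note: this loop only reads global_step; a real run would also execute a
# training op here (e.g. sess.run([train_op, global_step]), with a `train_op`
# assumed to exist in the graph), otherwise the step count never advances.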




Example #4
x_test.set_shape([
    None, config.input.img_out_shape[0], config.input.img_out_shape[1],
    config.input.img_out_shape[2]
])
y_test.set_shape(
    [None, config.input.mask_out_shape[0], config.input.mask_out_shape[1]])

print(x_train)
#print(y_test)

scaffold = tf.train.Scaffold(
    init_op=None,
    init_feed_dict=None,
    init_fn=None,
    ready_op=None,
    ready_for_local_init_op=None,
    local_init_op=[
        dataset.get_iterator().initializer,
        dataset.get_iterator(is_train=False).initializer
    ],
    summary_op=None,
    saver=None,
    copy_from_scaffold=None)

sess = tf.train.MonitoredTrainingSession(master='',
                                         is_chief=True,
                                         checkpoint_dir=None,
                                         scaffold=scaffold,
                                         hooks=None,
                                         chief_only_hooks=None,
                                         save_checkpoint_secs=600,
                                         save_summaries_steps=100,
                                         save_summaries_secs=None,
                                         config=None,
                                         stop_grace_period_secs=120,
                                         log_step_count_steps=100)
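
# The scaffold's local_init_op initializes both iterators when the session is
# created, so the fragment would typically continue with a driving loop like
# Example #3's `while not sess.should_stop(): ...`.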