Example #1
def test():
    with tf.name_scope('test'):
        image = tf.placeholder(dtype=tf.int32, shape = [None, None, 3])
        image_shape = tf.placeholder(dtype = tf.int32, shape = [3, ])
        processed_image = segmentation_preprocessing.segmentation_preprocessing(image, None, out_shape=[256, 256],
                                                                                is_training=False)
        b_image = tf.expand_dims(processed_image, axis = 0)
        net = UNet.UNet(b_image, None, is_training=False, decoder=FLAGS.decoder)
        global_step = slim.get_or_create_global_step()

    sess_config = tf.ConfigProto(log_device_placement = False, allow_soft_placement = True)
    if FLAGS.gpu_memory_fraction < 0:
        sess_config.gpu_options.allow_growth = True
    elif FLAGS.gpu_memory_fraction > 0:
        sess_config.gpu_options.per_process_gpu_memory_fraction = FLAGS.gpu_memory_fraction
    
    checkpoint_dir = util.io.get_dir(FLAGS.checkpoint_path)
    logdir = util.io.join_path(checkpoint_dir, 'test', FLAGS.dataset_name + '_' +FLAGS.dataset_split_name)

    # Variables to restore: moving avg. or normal weights.
    if FLAGS.using_moving_average:
        variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay)
        variables_to_restore = variable_averages.variables_to_restore()
        variables_to_restore[global_step.op.name] = global_step
    else:
        variables_to_restore = slim.get_variables_to_restore()
    
    saver = tf.train.Saver(var_list = variables_to_restore)

    case_names = util.io.ls(FLAGS.dataset_dir)
    case_names.sort()
    
    checkpoint = FLAGS.checkpoint_path
    checkpoint_name = util.io.get_filename(str(checkpoint))
    IoUs = []
    dices = []
    with tf.Session(config = sess_config) as sess:
        saver.restore(sess, checkpoint)
        centers = sess.run(net.centers)

        for iter, case_name in enumerate(case_names):
            image_data = util.img.imread(
                glob(util.io.join_path(FLAGS.dataset_dir, case_name, 'images', '*.png'))[0], rgb=True)

            pixel_cls_scores, pixel_recovery_feature_map = sess.run(
                [net.pixel_cls_scores, net.pixel_recovery_features[-1]],
                feed_dict = {
                    image: image_data
            })
            print('%d/%d: %s' % (iter + 1, len(case_names), case_name), np.shape(pixel_cls_scores))
            pos_score = np.asarray(pixel_cls_scores > 0.5, np.uint8)[0, :, :, 1]
            pred = cv2.resize(pos_score, tuple(np.shape(image_data)[:2][::-1]), interpolation=cv2.INTER_NEAREST)
            gt = util.img.imread(util.io.join_path(FLAGS.dataset_dir, case_name, 'weakly_label_whole_mask.png'))[:, :, 1]
            intersection = np.sum(np.logical_and(gt != 0, pred != 0))
            # 'union' actually holds the sum of the two areas; the true union is (area sum - intersection),
            # so IoU = I / (A_gt + A_pred - I) and Dice = 2I / (A_gt + A_pred), both reported in percent.
            union = np.sum(gt != 0) + np.sum(pred != 0)
            IoU = (1.0 * intersection) / (1.0 * union - 1.0 * intersection) * 100
            dice = (2.0 * intersection) / (1.0 * union) * 100
            IoUs.append(IoU)
            dices.append(dice)
            cv2.imwrite(util.io.join_path(FLAGS.pred_path, case_name + '.png'), pred)
            cv2.imwrite(util.io.join_path(FLAGS.pred_vis_path, case_name + '.png'),
                        np.asarray(pred * 200))
            assign_label = get_assign_label(centers, pixel_recovery_feature_map)[0]
            assign_label = assign_label * pos_score
            assign_label += 1
            cv2.imwrite(util.io.join_path(FLAGS.pred_assign_label_path, case_name + '.png'),
                        np.asarray(assign_label * 100, np.uint8))
    print('total mean of IoU is ', np.mean(IoUs))
    print('total mean of dice is ', np.mean(dices))
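A quick standalone check of the metric arithmetic used above, as a minimal sketch on toy masks (the iou_dice helper is illustrative and not part of the original script):

import numpy as np

def iou_dice(gt, pred):
    # gt / pred: binary masks of the same shape
    intersection = np.sum(np.logical_and(gt != 0, pred != 0))
    area_sum = np.sum(gt != 0) + np.sum(pred != 0)
    iou = 100.0 * intersection / (area_sum - intersection)
    dice = 200.0 * intersection / area_sum
    return iou, dice

gt = np.array([[1, 1, 0], [0, 1, 0]])
pred = np.array([[1, 0, 0], [0, 1, 1]])
print(iou_dice(gt, pred))  # intersection=2, areas=3+3 -> IoU=50.0, Dice~66.7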
Example #2
def generate_recovery_image_feature_map():
    with tf.name_scope('test'):
        image = tf.placeholder(dtype=tf.uint8, shape=[None, None, 3])
        image_shape_placeholder = tf.placeholder(tf.int32, shape=[2])
        processed_image = segmentation_preprocessing.segmentation_preprocessing(image, None,
                                                                                out_shape=[FLAGS.eval_image_width,
                                                                                           FLAGS.eval_image_height],
                                                                                is_training=False)
        b_image = tf.expand_dims(processed_image, axis=0)
        print('the decoder is ', FLAGS.decoder)
        net = UNet.UNet(b_image, None, is_training=False, decoder=FLAGS.decoder, update_center_flag=FLAGS.update_center,
                        batch_size=1)
        # print slim.get_variables_to_restore()
        global_step = slim.get_or_create_global_step()

    sess_config = tf.ConfigProto(log_device_placement=False, allow_soft_placement=True)
    if FLAGS.gpu_memory_fraction < 0:
        sess_config.gpu_options.allow_growth = True
    elif FLAGS.gpu_memory_fraction > 0:
        sess_config.gpu_options.per_process_gpu_memory_fraction = FLAGS.gpu_memory_fraction

    # Variables to restore: moving avg. or normal weights.
    if FLAGS.using_moving_average:
        variable_averages = tf.train.ExponentialMovingAverage(
            FLAGS.moving_average_decay)
        variables_to_restore = variable_averages.variables_to_restore()
        variables_to_restore[global_step.op.name] = global_step
    else:
        variables_to_restore = slim.get_variables_to_restore()

    # Pass var_list so the moving-average restore mapping built above is actually used.
    saver = tf.train.Saver(var_list=variables_to_restore)

    case_names = util.io.ls(FLAGS.dataset_dir)
    case_names.sort()

    checkpoint = FLAGS.checkpoint_path
    checkpoint_name = util.io.get_filename(str(checkpoint))
    IoUs = []
    dices = []

    pixel_recovery_features = tf.image.resize_images(net.pixel_recovery_features, image_shape_placeholder)
    with tf.Session(config=sess_config) as sess:
        saver.restore(sess, checkpoint)
        centers = sess.run(net.centers)
        print('sum of centers is ', np.sum(centers))

        for iter, case_name in enumerate(case_names):
            image_data = util.img.imread(
                glob(util.io.join_path(FLAGS.dataset_dir, case_name, 'images', '*.png'))[0], rgb=True)
            mask = cv2.imread(glob(util.io.join_path(FLAGS.dataset_dir, case_name, 'whole_mask.png'))[0])[
                   :, :, 0]

            pixel_cls_scores, recovery_img, recovery_feature_map, b_image_v, global_step_v = sess.run(
                [net.pixel_cls_scores, net.pixel_recovery_value, pixel_recovery_features, b_image, global_step],
                feed_dict={
                    image: image_data,
                    image_shape_placeholder: np.shape(image_data)[:2]
                })
            print(global_step_v)
            print('%d / %d: %s' % (iter + 1, len(case_names), case_name), np.shape(pixel_cls_scores),
                  np.max(pixel_cls_scores[:, :, :, 1]), np.min(pixel_cls_scores[:, :, :, 1]),
                  np.shape(recovery_img), np.max(recovery_img), np.min(recovery_img),
                  np.max(b_image_v), np.min(b_image_v), np.shape(b_image_v))
            print(np.shape(recovery_feature_map), np.shape(mask))
            pred_vis_path = util.io.join_path(FLAGS.pred_vis_dir, case_name + '.png')
            pred_path = util.io.join_path(FLAGS.pred_dir, case_name + '.png')
            pos_score = np.asarray(pixel_cls_scores > 0.5, np.uint8)[0, :, :, 1]
            pred = cv2.resize(pos_score, tuple(np.shape(image_data)[:2][::-1]), interpolation=cv2.INTER_NEAREST)
            cv2.imwrite(pred_vis_path, np.asarray(pred * 200, np.uint8))
            cv2.imwrite(pred_path, np.asarray(pred, np.uint8))
            recovery_img_path = util.io.join_path(FLAGS.recovery_img_dir, case_name + '.png')
            cv2.imwrite(recovery_img_path, np.asarray(recovery_img[0] * 255, np.uint8))
            recovery_feature_map_path = util.io.join_path(FLAGS.recovery_feature_map_dir, case_name + '.npy')

            xs, ys = np.where(mask == 1)
            features = recovery_feature_map[0][xs, ys, :]
            print('the size of feature map is ', np.shape(np.asarray(features, np.float32)))

            np.save(recovery_feature_map_path, np.asarray(features, np.float32))
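The per-case feature arrays saved above can then be clustered offline to produce centers. A hedged sketch of that step (the directory layout, the build_kmeans_centers helper, and the number of clusters are assumptions; the original clustering code is not shown):

import numpy as np
from glob import glob
from sklearn.cluster import KMeans

def build_kmeans_centers(feature_dir, num_centers=8):
    # Stack the (N_i, C) arrays written by generate_recovery_image_feature_map
    features = np.concatenate(
        [np.load(p) for p in glob(feature_dir + '/*.npy')], axis=0)
    kmeans = KMeans(n_clusters=num_centers).fit(features)
    return np.asarray(kmeans.cluster_centers_, np.float32)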
Example #3
def create_clones(batch_queue):
    with tf.device('/cpu:0'):
        global_step = slim.create_global_step()
        learning_rate = tf.constant(FLAGS.learning_rate, name='learning_rate')
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
        # optimizer = tf.train.MomentumOptimizer(learning_rate,
        #                                        momentum=FLAGS.momentum, name='Momentum')

        tf.summary.scalar('learning_rate', learning_rate)
    # place clones
    pixel_link_loss = 0  # for summary only
    gradients = []
    for clone_idx, gpu in enumerate(config.gpus):
        do_summary = clone_idx == 0  # only summary on the first clone
        reuse = clone_idx > 0
        with tf.variable_scope(tf.get_variable_scope(), reuse=reuse):
            with tf.name_scope(config.clone_scopes[clone_idx]) as clone_scope:
                with tf.device(gpu) as clone_device:
                    b_image, b_mask_image, b_liver_mask = batch_queue.dequeue()
                    # net = pixel_link_symbol.PixelLinkNetModify(b_image, b_mask_image, is_training=True)
                    # net = UNet.UNet(b_image, None, is_training=False, decoder=FLAGS.decoder)
                    if FLAGS.attention_flag:
                        print('use the UNet Attention')
                        net = UNetAttention.UNet(
                            b_image,
                            b_mask_image,
                            is_training=True,
                            decoder=FLAGS.decoder,
                            update_center_flag=FLAGS.update_center,
                            batch_size=FLAGS.batch_size,
                            update_center_strategy=FLAGS.update_center_strategy
                        )
                    else:
                        if FLAGS.center_block_flag:
                            print('use the UNet Blocks')
                            net = UNetBlocks.UNet(
                                b_image,
                                b_mask_image,
                                b_liver_mask,
                                is_training=True,
                                decoder=FLAGS.decoder,
                                update_center_flag=FLAGS.update_center,
                                batch_size=FLAGS.batch_size,
                                update_center_strategy=FLAGS.update_center_strategy,
                                num_centers_k=FLAGS.num_centers_k,
                                full_annotation_flag=FLAGS.full_annotation_flag
                            )
                        else:
                            print('use the UNet')
                            net = UNet.UNet(
                                b_image,
                                b_mask_image,
                                is_training=True,
                                decoder=FLAGS.decoder,
                                update_center_flag=FLAGS.update_center,
                                batch_size=FLAGS.batch_size,
                                update_center_strategy=FLAGS.update_center_strategy)
                    net.build_loss(do_summary=do_summary)
                    if FLAGS.update_center:
                        if FLAGS.update_center_strategy == 1:
                            update_center_op = net.update_centers(alpha=0.5)
                        elif FLAGS.update_center_strategy == 2:
                            # update_center_op, kernels_ring_masks = net.update_centers_V2()
                            print('update_center_strategy is 2')
                        else:
                            print('the update_center_strategy is not supported!')
                            assert False

                    # gather losses
                    losses = tf.get_collection(tf.GraphKeys.LOSSES,
                                               clone_scope)
                    print('losses are ', losses)
                    # binary cross entropy, dice, mse, center loss
                    if FLAGS.full_annotation_flag:
                        assert len(losses) == 2
                    elif FLAGS.update_center:
                        assert len(losses) == 4
                    else:
                        assert len(losses) == 3

                    total_clone_loss = tf.add_n(losses) / config.num_clones
                    pixel_link_loss += total_clone_loss

                    # gather regularization loss and add to clone_0 only
                    if clone_idx == 0:
                        regularization_loss = tf.add_n(
                            tf.get_collection(
                                tf.GraphKeys.REGULARIZATION_LOSSES))
                        total_clone_loss = total_clone_loss + regularization_loss

                    # compute clone gradients
                    clone_gradients = optimizer.compute_gradients(
                        total_clone_loss)
                    gradients.append(clone_gradients)

    tf.summary.scalar('pixel_link_loss', pixel_link_loss)
    tf.summary.scalar('regularization_loss', regularization_loss)
    # if FLAGS.update_center_strategy == 2:
    # with tf.name_scope('kernel_size'):
    #     for idx in range(len(net.kernel_sizes)):
    #         with tf.name_scope(str(net.kernel_sizes[idx])):
    #             ring_masks = kernels_ring_masks[idx]
    #             for i in range(5):
    #                 with tf.name_scope('mask_id_' + str(i)):
    #                     tf.summary.image('ring_area_mask',
    #                                      tf.expand_dims(
    #                                          tf.expand_dims(tf.cast(ring_masks[i], tf.float32),
    #                                                         axis=0),
    #                                          axis=3))
    # add all gradients together
    # note that the gradients do not need to be averaged, because the average operation has been done on loss.
    averaged_gradients = sum_gradients(gradients)
    # The if/else on FLAGS.update_center used identical empty control-dependency
    # lists, so the gradients are applied directly in both cases.
    apply_grad_op = optimizer.apply_gradients(averaged_gradients,
                                              global_step=global_step)

    train_ops = [apply_grad_op]

    bn_update_op = util.tf.get_update_op()
    if bn_update_op is not None:
        train_ops.append(bn_update_op)

    # moving average
    if FLAGS.using_moving_average:
        tf.logging.info('using moving average in training, with decay = %f' %
                        FLAGS.moving_average_decay)
        ema = tf.train.ExponentialMovingAverage(FLAGS.moving_average_decay)
        ema_op = ema.apply(tf.trainable_variables())
        with tf.control_dependencies([apply_grad_op]):  # ema after updating
            train_ops.append(tf.group(ema_op))

    train_op = control_flow_ops.with_dependencies(train_ops,
                                                  pixel_link_loss,
                                                  name='train_op')
    return train_op
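sum_gradients is not included in these snippets. A minimal sketch of a per-variable summing helper that is consistent with the comment above (each clone loss is already divided by config.num_clones, so summing the per-clone gradients amounts to averaging); the actual implementation may differ:

import tensorflow as tf

def sum_gradients(clone_grads):
    # clone_grads: one [(grad, var), ...] list per clone, as returned by compute_gradients
    summed = []
    for grad_and_vars in zip(*clone_grads):
        grads = [g for g, _ in grad_and_vars if g is not None]
        var = grad_and_vars[0][1]
        if grads:
            summed.append((tf.add_n(grads), var))
    return summed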
Example #4
def create_clones(batch_queue):
    with tf.device('/cpu:0'):
        global_step = slim.create_global_step()
        learning_rate = tf.constant(FLAGS.learning_rate, name='learning_rate')
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
        # optimizer = tf.train.MomentumOptimizer(learning_rate,
        #                                        momentum=FLAGS.momentum, name='Momentum')

        tf.summary.scalar('learning_rate', learning_rate)
    # place clones
    pixel_link_loss = 0 # for summary only
    gradients = []
    for clone_idx, gpu in enumerate(config.gpus):
        do_summary = clone_idx == 0 # only summary on the first clone
        reuse = clone_idx > 0
        with tf.variable_scope(tf.get_variable_scope(), reuse = reuse):
            with tf.name_scope(config.clone_scopes[clone_idx]) as clone_scope:
                with tf.device(gpu) as clone_device:
                    b_image, b_mask_image, b_liver_mask, b_pixel_cls_weight = batch_queue.dequeue()
                    # random resize
                    random_scale_idx = tf.random_uniform([], minval=0, maxval=len(scales), dtype=tf.int32)
                    scales_tensor = tf.convert_to_tensor(np.asarray(scales, np.int32), tf.int32)
                    random_scale = scales_tensor[random_scale_idx]
                    tf.summary.scalar('random_scale', random_scale[0])
                    b_image = tf.image.resize_images(b_image, random_scale)
                    b_mask_image = tf.image.resize_images(b_mask_image, random_scale,
                                                          tf.image.ResizeMethod.NEAREST_NEIGHBOR)
                    b_liver_mask = tf.image.resize_images(b_liver_mask, random_scale,
                                                          tf.image.ResizeMethod.NEAREST_NEIGHBOR)
                    b_pixel_cls_weight = tf.image.resize_images(b_pixel_cls_weight, random_scale)

                    # net = pixel_link_symbol.PixelLinkNetModify(b_image, b_mask_image, is_training=True)
                    # net = UNet.UNet(b_image, None, is_training=False, decoder=FLAGS.decoder)

                    net = UNet.UNet(b_image, b_mask_image, b_liver_mask, is_training=True, decoder=FLAGS.decoder,
                                    update_center_flag=FLAGS.update_center, batch_size=FLAGS.batch_size,
                                    update_center_strategy=FLAGS.update_center_strategy,
                                    num_centers_k=FLAGS.num_centers_k,
                                    full_annotation_flag=FLAGS.full_annotation_flag,
                                    output_shape_tensor=random_scale, pixel_cls_weight=b_pixel_cls_weight)
                    net.build_loss(do_summary=do_summary)
                    if FLAGS.update_center:
                        if FLAGS.update_center_strategy == 1:
                            update_center_op = net.update_centers(alpha=0.5)
                        elif FLAGS.update_center_strategy == 2:
                            print('update_center_strategy is 2')
                        else:
                            print('the update_center_strategy is not supported!')
                            assert False
                    
                    # gather losses
                    losses = tf.get_collection(tf.GraphKeys.LOSSES, clone_scope)
                    print('losses are ', losses)
                    # binary cross entropy, dice, mse, center loss
                    if FLAGS.full_annotation_flag:
                        assert len(losses) == 2
                    elif FLAGS.update_center:
                        assert len(losses) == 4
                    else:
                        assert len(losses) == 3

                    total_clone_loss = tf.add_n(losses) / config.num_clones
                    pixel_link_loss += total_clone_loss

                    # gather regularization loss and add to clone_0 only
                    if clone_idx == 0:
                        regularization_loss = tf.add_n(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
                        total_clone_loss = total_clone_loss + regularization_loss
                    
                    # compute clone gradients
                    clone_gradients = optimizer.compute_gradients(total_clone_loss)
                    gradients.append(clone_gradients)
                    
    tf.summary.scalar('pixel_link_loss', pixel_link_loss)
    tf.summary.scalar('regularization_loss', regularization_loss)

    # add all gradients together
    # note that the gradients do not need to be averaged, because the average operation has been done on loss.
    averaged_gradients = sum_gradients(gradients)
    if FLAGS.update_center:
        with tf.control_dependencies([update_center_op]):
            apply_grad_op = optimizer.apply_gradients(averaged_gradients, global_step=global_step)
    else:
        with tf.control_dependencies([]):
            apply_grad_op = optimizer.apply_gradients(averaged_gradients, global_step=global_step)
    
    train_ops = [apply_grad_op]
    
    bn_update_op = util.tf.get_update_op()
    if bn_update_op is not None:
        train_ops.append(bn_update_op)
    
    # moving average
    if FLAGS.using_moving_average:
        tf.logging.info('using moving average in training, with decay = %f' %
                        FLAGS.moving_average_decay)
        ema = tf.train.ExponentialMovingAverage(FLAGS.moving_average_decay)
        ema_op = ema.apply(tf.trainable_variables())
        with tf.control_dependencies([apply_grad_op]):  # ema after updating
            train_ops.append(tf.group(ema_op))
         
    train_op = control_flow_ops.with_dependencies(train_ops, pixel_link_loss, name='train_op')
    return train_op
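The random multi-scale resize in Example #4 reads a module-level scales list that the snippet does not define. A hedged sketch of a plausible definition, with the selection logic shown in isolation (the concrete resolutions are assumptions):

import numpy as np
import tensorflow as tf

# hypothetical set of (height, width) training resolutions
scales = [[224, 224], [256, 256], [288, 288]]

scales_tensor = tf.convert_to_tensor(np.asarray(scales, np.int32), tf.int32)
# maxval is exclusive, so the random index always stays inside the list
random_scale_idx = tf.random_uniform([], minval=0, maxval=len(scales), dtype=tf.int32)
random_scale = scales_tensor[random_scale_idx]  # shape [2], fed to tf.image.resize_images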
Example #5
def create_clones(batch_queue):
    with tf.device('/cpu:0'):
        global_step = slim.create_global_step()
        learning_rate = tf.constant(FLAGS.learning_rate, name='learning_rate')
        # optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
        optimizer = tf.train.MomentumOptimizer(learning_rate,
                                               momentum=FLAGS.momentum,
                                               name='Momentum')

        tf.summary.scalar('learning_rate', learning_rate)
    # place clones
    pixel_link_loss = 0  # for summary only
    gradients = []
    for clone_idx, gpu in enumerate(config.gpus):
        do_summary = clone_idx == 0  # only summary on the first clone
        reuse = clone_idx > 0
        with tf.variable_scope(tf.get_variable_scope(), reuse=reuse):
            with tf.name_scope(config.clone_scopes[clone_idx]) as clone_scope:
                with tf.device(gpu) as clone_device:

                    b_image, b_mask_image = batch_queue.dequeue()
                    if FLAGS.update_center:
                        init_center_value = load_kmeans_centers(
                            FLAGS.kmeans_path)
                        print('init_center_value is ',
                              np.shape(init_center_value))
                    else:
                        init_center_value = None
                        print('init_center_value is None')
                    net = UNet.UNet(b_image,
                                    b_mask_image,
                                    is_training=True,
                                    decoder=FLAGS.decoder,
                                    update_center_flag=FLAGS.update_center,
                                    batch_size=FLAGS.batch_size,
                                    init_center_value=init_center_value)
                    net.build_loss(do_summary=do_summary)
                    if FLAGS.update_center:
                        update_center_op = net.update_centers(alpha=0.5)

                    # gather losses
                    losses = tf.get_collection(tf.GraphKeys.LOSSES,
                                               clone_scope)

                    # binary cross entropy, dice, mse, center loss
                    if FLAGS.update_center:
                        assert len(losses) == 4
                    else:
                        assert len(losses) == 3

                    total_clone_loss = tf.add_n(losses) / config.num_clones
                    pixel_link_loss += total_clone_loss

                    # gather regularization loss and add to clone_0 only
                    if clone_idx == 0:
                        regularization_loss = tf.add_n(
                            tf.get_collection(
                                tf.GraphKeys.REGULARIZATION_LOSSES))
                        total_clone_loss = total_clone_loss + regularization_loss

                    # compute clone gradients
                    clone_gradients = optimizer.compute_gradients(
                        total_clone_loss)
                    gradients.append(clone_gradients)

    tf.summary.scalar('pixel_link_loss', pixel_link_loss)
    tf.summary.scalar('regularization_loss', regularization_loss)

    # add all gradients together
    # note that the gradients do not need to be averaged, because the average operation has been done on loss.
    averaged_gradients = sum_gradients(gradients)
    if FLAGS.update_center:
        with tf.control_dependencies([update_center_op]):
            apply_grad_op = optimizer.apply_gradients(averaged_gradients,
                                                      global_step=global_step)
    else:
        with tf.control_dependencies([]):
            apply_grad_op = optimizer.apply_gradients(averaged_gradients,
                                                      global_step=global_step)

    train_ops = [apply_grad_op]

    bn_update_op = util.tf.get_update_op()
    if bn_update_op is not None:
        train_ops.append(bn_update_op)

    # moving average
    if FLAGS.using_moving_average:
        tf.logging.info('using moving average in training, with decay = %f' %
                        FLAGS.moving_average_decay)
        ema = tf.train.ExponentialMovingAverage(FLAGS.moving_average_decay)
        ema_op = ema.apply(tf.trainable_variables())
        with tf.control_dependencies([apply_grad_op]):  # ema after updating
            train_ops.append(tf.group(ema_op))

    train_op = control_flow_ops.with_dependencies(train_ops,
                                                  pixel_link_loss,
                                                  name='train_op')
    return train_op
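load_kmeans_centers in Example #5 is likewise not included. A minimal sketch assuming FLAGS.kmeans_path points at a saved numpy array of precomputed centers (for instance one produced by the k-means sketch after Example #2):

import numpy as np

def load_kmeans_centers(kmeans_path):
    # Expecting a (num_centers, feature_dim) array saved with np.save
    centers = np.load(kmeans_path)
    return np.asarray(centers, np.float32)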