Ejemplo n.º 1
0
def main():
    """Train a U-Net from command-line options, validating on a held-out set.

    Builds the network described by the parsed arguments and, when
    ``checkpoint_path`` is given, loads those weights before training.
    It then reads ``n_images`` training pairs from ``train_dir`` /
    ``train_mask_dir`` and ``n_images // 4`` validation pairs from
    ``val_dir`` / ``val_mask_dir``, and calls ``Unet.train`` with the
    remaining hyper-parameters.

    The TensorFlow session is closed once training returns or raises.
    """
    args = build_parser().parse_args()
    image_size = [args.img_height, args.img_width]

    # A tf.ConfigProto (e.g. gpu_options.per_process_gpu_memory_fraction)
    # can be passed as ``config=`` here to tune GPU memory usage.
    sess = tf.Session()
    try:
        unet = Unet(input_shape=image_size,
                    sess=sess,
                    filter_num=args.filter_num,
                    batch_norm=args.batch_norm)
        unet.build_net()
        if args.checkpoint_path:
            # Optionally start from previously saved weights.
            unet.load_weights(args.checkpoint_path)

        images, masks = read_data(args.train_dir,
                                  args.train_mask_dir,
                                  n_images=args.n_images,
                                  image_size=image_size)
        # The validation request is a quarter of the training request.
        # NOTE(review): how read_data interprets a non-positive n_images is
        # not visible here; confirm n_images // 4 is sensible in that case.
        val_images, val_masks = read_data(args.val_dir,
                                          args.val_mask_dir,
                                          n_images=args.n_images // 4,
                                          image_size=image_size)
        unet.train(images=images,
                   masks=masks,
                   val_images=val_images,
                   val_masks=val_masks,
                   epochs=args.epochs,
                   batch_size=args.batch_size,
                   learning_rate=args.learning_rate,
                   dice_loss=args.dice_loss,
                   always_save=args.always_save)
    finally:
        # Release the session's resources even if training fails.
        sess.close()
Ejemplo n.º 2
0
def main():
    """Run a trained U-Net over test images, optionally scoring and saving.

    Loads weights from ``checkpoint_path``, which is required. It then
    predicts a mask for each of the first ``n_images`` files of
    ``test_dir`` in sorted order. A non-positive ``n_images``, or one
    larger than the number of files, means "all of them".

    * If ``mask_dir`` is given and non-empty, each image is paired with
      the mask at the same sorted position. The per-image Dice
      coefficient is printed, and the average over the processed images
      is printed at the end.
    * If ``result_dir`` is given, it is created if missing. Each
      prediction is written there under the input image's file name.

    The TensorFlow session is closed when the run finishes or fails.
    """
    parser = build_parser()
    args = parser.parse_args()
    # Validate explicitly: an ``assert`` is stripped when Python runs with -O.
    if not args.checkpoint_path:
        parser.error('checkpoint_path is required for prediction')

    result_dir = args.result_dir
    checkpoint_path = args.checkpoint_path
    test_dir = args.test_dir
    image_size = [args.img_height, args.img_width]

    img_names = sorted(os.listdir(test_dir))
    # NOTE(review): images and masks are paired purely by sorted position,
    # so mask_dir is assumed to hold one mask per test image in the same
    # order -- confirm against the dataset's naming scheme.
    mask_names = sorted(os.listdir(args.mask_dir)) if args.mask_dir else []
    score_masks = bool(mask_names)

    n_imgs = args.n_images
    # Non-positive means "all"; never index past the end of the listing.
    if n_imgs <= 0 or n_imgs > len(img_names):
        n_imgs = len(img_names)

    if result_dir and not os.path.isdir(result_dir):
        os.makedirs(result_dir)

    total_dice = 0.0
    sess = tf.Session()
    try:
        unet = Unet(input_shape=image_size,
                    sess=sess,
                    filter_num=args.filter_num,
                    batch_norm=args.batch_norm)
        unet.build_net(is_train=False)
        unet.load_weights(checkpoint_path)

        for i, img_name in enumerate(img_names[:n_imgs]):
            # Report 1-based progress: first image is "1/N", last is "N/N".
            print('%s %d/%d' % (img_name, i + 1, n_imgs))
            img_mat = read_car_img(os.path.join(test_dir, img_name),
                                   image_size=image_size)
            # Prepend a leading axis of size 1 (single-item batch).
            img_mat = np.expand_dims(img_mat, axis=0)
            if score_masks:
                mask_mat = read_mask_img(os.path.join(args.mask_dir,
                                                      mask_names[i]),
                                         image_size=image_size)
                mask_mat = np.expand_dims(mask_mat, axis=0)
                res, dice = unet.predict_test(img_mat, mask_mat)
                dice = np.mean(dice)
                print('Dice coefficient:%.6f' % dice)
                total_dice += dice
            else:
                res = unet.predict(img_mat)

            if result_dir:
                res = res.reshape(image_size)
                # NOTE(review): scipy.misc.imsave is deprecated and absent
                # from newer SciPy releases; confirm the pinned version or
                # switch to an image-writing library.
                misc.imsave(os.path.join(result_dir, img_name), res)
    finally:
        sess.close()

    # Print the average whenever masks were scored, even when it is 0.
    # Skip it when no images were processed, to avoid dividing by zero.
    if score_masks and n_imgs > 0:
        print('Average Dice coefficient:%.6f' % (total_dice / n_imgs))