def main():
    """Train a U-Net on the data sets named on the command line.

    Parses arguments with ``build_parser()``, builds the network, optionally
    restores weights from ``args.checkpoint_path``, loads the training and
    validation sets with ``read_data`` and runs ``Unet.train``. The TensorFlow
    session is always closed, even if training raises.
    """
    args = build_parser().parse_args()
    image_size = [args.img_height, args.img_width]

    # To cap GPU memory, build a tf.ConfigProto, set
    # gpu_options.per_process_gpu_memory_fraction and pass it as
    # tf.Session(config=...).
    sess = tf.Session()
    try:
        unet = Unet(input_shape=image_size, sess=sess,
                    filter_num=args.filter_num, batch_norm=args.batch_norm)
        unet.build_net()
        if args.checkpoint_path:
            # Start from previously saved weights rather than from scratch.
            unet.load_weights(args.checkpoint_path)

        images, masks = read_data(args.train_dir, args.train_mask_dir,
                                  n_images=args.n_images,
                                  image_size=image_size)
        # The validation set is requested at a quarter of the training size.
        # NOTE(review): if read_data treats n_images <= 0 as "all images",
        # small positive values (1-3) become 0 here -- confirm intent.
        val_images, val_masks = read_data(args.val_dir, args.val_mask_dir,
                                          n_images=args.n_images // 4,
                                          image_size=image_size)

        unet.train(images=images, masks=masks,
                   val_images=val_images, val_masks=val_masks,
                   epochs=args.epochs, batch_size=args.batch_size,
                   learning_rate=args.learning_rate,
                   dice_loss=args.dice_loss,
                   always_save=args.always_save)
    finally:
        # Release the session's resources whether or not training succeeded.
        sess.close()
def main():
    """Run a trained U-Net over a directory of test images.

    Requires ``args.checkpoint_path``. The first ``args.n_images`` files of
    ``args.test_dir`` (sorted by name) are processed; ``n_images <= 0`` means
    all of them. When ``args.mask_dir`` contains files, masks are paired with
    images by sorted file order. In that case a Dice coefficient is printed
    per image and their average is printed at the end. When
    ``args.result_dir`` is set, each prediction is written there under the
    image's file name.
    """
    parser = build_parser()
    args = parser.parse_args()
    # Validate explicitly: a bare ``assert`` is stripped under ``python -O``.
    if not args.checkpoint_path:
        parser.error('checkpoint_path is required for prediction')

    result_dir = args.result_dir
    checkpoint_path = args.checkpoint_path
    test_dir = args.test_dir
    image_size = [args.img_height, args.img_width]

    img_names = sorted(os.listdir(test_dir))
    mask_names = sorted(os.listdir(args.mask_dir)) if args.mask_dir else []
    # An empty mask directory behaves like no mask directory (predict only).
    use_masks = bool(mask_names)

    # n_images <= 0 means "every image"; never index past the files present.
    n_imgs = args.n_images
    if n_imgs <= 0 or n_imgs > len(img_names):
        n_imgs = len(img_names)
    if use_masks and len(mask_names) < n_imgs:
        parser.error('mask_dir has %d files but %d images are to be evaluated'
                     % (len(mask_names), n_imgs))

    total_dice = 0.0
    sess = tf.Session()
    try:
        unet = Unet(input_shape=image_size, sess=sess,
                    filter_num=args.filter_num, batch_norm=args.batch_norm)
        unet.build_net(is_train=False)
        unet.load_weights(checkpoint_path)

        for i, img_name in enumerate(img_names[:n_imgs]):
            print('%s %d/%d' % (img_name, i, n_imgs))
            img_mat = read_car_img(os.path.join(test_dir, img_name),
                                   image_size=image_size)
            # Add a leading axis of size 1 (presumably the batch axis).
            img_mat = np.expand_dims(img_mat, axis=0)

            if use_masks:
                mask_mat = read_mask_img(
                    os.path.join(args.mask_dir, mask_names[i]),
                    image_size=image_size)
                mask_mat = np.expand_dims(mask_mat, axis=0)
                res, dice = unet.predict_test(img_mat, mask_mat)
                dice = np.mean(dice)
                print('Dice coefficient:%.6f' % dice)
                total_dice += dice
            else:
                res = unet.predict(img_mat)

            if result_dir:
                res = res.reshape(image_size)
                # NOTE(review): scipy.misc.imsave was removed in SciPy 1.2;
                # on newer SciPy this needs e.g. imageio.imwrite instead.
                misc.imsave(os.path.join(result_dir, img_name), res)
    finally:
        # Release the session's resources even if prediction fails.
        sess.close()

    # Decide on use_masks rather than on the sum's truthiness, so that an
    # average Dice of exactly 0 is still reported.
    if use_masks and n_imgs > 0:
        print('Average Dice coefficient:%.6f' % (total_dice / n_imgs))