Example #1
def IOU_(y_pred, y_true):
    """Returns a (approx) IOU score

    intesection = y_pred.flatten() * y_true.flatten()
    Then, IOU = 2 * intersection / (y_pred.sum() + y_true.sum() + 1e-7) + 1e-7

    Args:
        y_pred (4-D array): (N, H, W, 1)
        y_true (4-D array): (N, H, W, 1)

    Returns:
        float: IOU score
    """
    segmetric = SegMetric(1)
    for i in range(y_pred.shape[0]):
        segmetric.add_image_pair(y_pred[i, :, :, 0], y_true[i, :, :, 0])

    return segmetric.mean_IU()
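
The docstring above describes a dice-style approximation of the IOU, while the function itself delegates to SegMetric. Below is a minimal, standalone NumPy sketch of that approximate formula; approx_iou is a hypothetical helper for illustration only, not the project's SegMetric-based score.

import numpy as np

def approx_iou(y_pred, y_true, eps=1e-7):
    # Dice-style approximate IOU as written in the docstring of IOU_ above;
    # illustration only, not the SegMetric-based mean IU returned by IOU_.
    y_pred = np.asarray(y_pred, dtype=np.float64).flatten()
    y_true = np.asarray(y_true, dtype=np.float64).flatten()
    intersection = np.sum(y_pred * y_true)
    return 2.0 * intersection / (y_pred.sum() + y_true.sum() + eps) + eps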
Example #2
def main(flags):
    IMG_MEAN = np.zeros(3)
    image_std = [1.0, 1.0, 1.0]
    # parameters of building data set
    citylist = [
        'Norfolk', 'Arlington', 'Atlanta', 'Austin', 'Seekonk', 'NewHaven'
    ]
    image_mean_list = {
        'Norfolk': [127.07435926, 129.40160709, 128.28713284],
        'Arlington': [88.30304996, 94.97338776, 93.21268212],
        'Atlanta': [101.997014375, 108.42171833, 110.044871],
        'Austin': [97.0896012682, 102.94697026, 100.7540157],
        'Seekonk': [86.67800904, 93.31221168, 92.1328146],
        'NewHaven': [106.7092798, 111.4314, 110.74903832]
    }  # BGR mean for the training data for each city

    # set training data
    if flags.training_data == 'SP':
        IMG_MEAN = np.array(
            (121.68045527, 132.14961763, 129.30317439),
            dtype=np.float32)  # mean of solar panel data in BGR order

    elif flags.training_data in citylist:
        print("Training on {} data".format(flags.training_data))
        IMG_MEAN = image_mean_list[flags.training_data]
        # if flags.unit_std:
        #     image_std = image_std_list[flags.training_data]
    elif 'all_but' in flags.training_data:
        print("Training on all(excludes Seekonk) but {} data".format(
            flags.training_data))
        except_city_name = flags.training_data.split('_')[2]
        for cityname in citylist:
            if cityname != except_city_name and cityname != 'Seekonk':
                IMG_MEAN = IMG_MEAN + np.array(image_mean_list[cityname])
        IMG_MEAN = IMG_MEAN / 4

    elif flags.training_data == 'all':
        print("Training on data of all cities (excludes Seekonk)")
        for cityname in citylist:
            if cityname != 'Seekonk':
                IMG_MEAN = IMG_MEAN + np.array(image_mean_list[cityname])
        IMG_MEAN = IMG_MEAN / 5
    else:
        print("Wrong data option: {}".format(flags.data_option))

    # setup used GPU
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = flags.GPU

    # presetting
    input_size = (128, 128)
    tf.set_random_seed(1234)
    coord = tf.train.Coordinator()
    # img_mean = [127.07435926, 129.40160709, 128.28713284]
    with tf.name_scope("training_inputs"):
        training_reader = ImageReader(flags.training_data_list,
                                      input_size,
                                      random_scale=True,
                                      random_mirror=True,
                                      random_rotate=True,
                                      ignore_label=255,
                                      img_mean=IMG_MEAN,
                                      coord=coord)
    with tf.name_scope("validation_inputs"):
        validation_reader = ImageReader(
            flags.validation_data_list,
            input_size,
            random_scale=False,
            random_mirror=False,
            random_rotate=False,
            ignore_label=255,
            img_mean=IMG_MEAN,
            coord=coord,
        )
    X_batch_op, y_batch_op = training_reader.shuffle_dequeue(flags.batch_size)
    X_test_op, y_test_op = validation_reader.shuffle_dequeue(flags.batch_size *
                                                             2)

    train = pd.read_csv(flags.training_data_list, header=0)
    n_train = train.shape[0] + 1

    test = pd.read_csv(flags.validation_data_list, header=0)
    n_test = test.shape[0] + 1

    current_time = time.strftime("%m_%d/%H_%M")

    # tf.reset_default_graph()
    X = tf.placeholder(tf.float32, shape=[None, 128, 128, 3], name="X")
    y = tf.placeholder(tf.float32, shape=[None, 128, 128, 1], name="y")
    mode = tf.placeholder(tf.bool, name="mode")

    pred_raw = make_unet(X, mode)
    pred = tf.nn.sigmoid(pred_raw)
    tf.add_to_collection("inputs", X)
    tf.add_to_collection("inputs", mode)
    tf.add_to_collection("outputs", pred)

    tf.summary.histogram("Predicted Mask", pred)
    # tf.summary.image("Predicted Mask", pred)

    global_step = tf.Variable(0,
                              dtype=tf.int64,
                              trainable=False,
                              name='global_step')

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    learning_rate = tf.train.exponential_decay(
        flags.learning_rate,
        global_step,
        tf.cast(n_train / flags.batch_size * flags.decay_step, tf.int32),
        flags.decay_rate,
        staircase=True)

    IOU_op = IOU_(pred, y)
    # sigmoid_cross_entropy_with_logits expects the raw (pre-sigmoid) output
    cross_entropy = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(labels=y, logits=pred_raw))
    tf.summary.scalar("loss/IOU_training", IOU_op)
    tf.summary.scalar("loss/cross_entropy_training", cross_entropy)

    learning_rate_summary = tf.summary.scalar(
        "learning_rate", learning_rate)  # summary recording learning rate

    #loss = cross_entropy
    if flags.is_loss_entropy:
        loss = cross_entropy
    else:
        loss = -IOU_op

    with tf.control_dependencies(update_ops):
        train_op = make_train_op(loss, global_step, learning_rate)
        # train_op = make_train_op(cross_entropy, global_step, learning_rate)

    summary_op = tf.summary.merge_all()

    valid_IoU = tf.placeholder(tf.float32, [])
    valid_IoU_summary_op = tf.summary.scalar("loss/IoU_validation", valid_IoU)
    valid_cross_entropy = tf.placeholder(tf.float32, [])
    valid_cross_entropy_summary_op = tf.summary.scalar(
        "loss/cross_entropy_validation", valid_cross_entropy)

    # original images for summary
    train_images = tf.placeholder(tf.uint8,
                                  shape=[None, 128, 128 * 3, 3],
                                  name="training_images")
    train_image_summary_op = tf.summary.image("Training_images_summary",
                                              train_images,
                                              max_outputs=10)
    valid_images = tf.placeholder(tf.uint8,
                                  shape=[None, 128, 128 * 3, 3],
                                  name="validation_images")
    valid_image_summary_op = tf.summary.image("Validation_images_summary",
                                              valid_images,
                                              max_outputs=10)

    # Set up TF session and initialize variables.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    with tf.Session(config=config) as sess:

        init = tf.global_variables_initializer()
        sess.run(init)

        saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=1)

        if os.path.exists(flags.ckdir) and tf.train.get_checkpoint_state(
                flags.ckdir):
            latest_check_point = tf.train.latest_checkpoint(flags.ckdir)
            saver.restore(sess, latest_check_point)

        # elif not os.path.exists(flags.ckdir):
        #     # try:
        #     #     os.rmdir(flags.ckdir)
        #     # except FileNotFoundError:
        #     #     pass
        #     os.mkdir(flags.ckdir)

        try:
            train_summary_writer = tf.summary.FileWriter(
                flags.ckdir, sess.graph)

            threads = tf.train.start_queue_runners(coord=coord, sess=sess)

            for epoch in range(flags.epochs):

                for step in range(0, n_train, flags.batch_size):

                    start_time = time.time()
                    X_batch, y_batch = sess.run([X_batch_op, y_batch_op])

                    _, global_step_value = sess.run([train_op, global_step],
                                                    feed_dict={
                                                        X: X_batch,
                                                        y: y_batch,
                                                        mode: True
                                                    })
                    if global_step_value % 100 == 0:
                        duration = time.time() - start_time
                        pred_train, step_iou, step_cross_entropy, step_summary, = sess.run(
                            [pred, IOU_op, cross_entropy, summary_op],
                            feed_dict={
                                X: X_batch,
                                y: y_batch,
                                mode: False
                            })
                        train_summary_writer.add_summary(
                            step_summary, global_step_value)

                        print(
                            'Epoch {:d} step {:d} \t cross entropy = {:.3f}, IOU = {:.3f} ({:.3f} sec/step)'
                            .format(epoch, global_step_value,
                                    step_cross_entropy, step_iou, duration))

                    # validation every 1000 steps
                    if global_step_value % 1000 == 0:
                        segmetric = SegMetric(1)
                        # for step in range(0, n_test, flags.batch_size):
                        X_test, y_test = sess.run([X_test_op, y_test_op])
                        pred_valid, valid_cross_entropy_value = sess.run(
                            [pred, cross_entropy],
                            feed_dict={
                                X: X_test,
                                y: y_test,
                                mode: False
                            })
                        iou_temp = myIOU(y_pred=pred_valid > 0.5,
                                         y_true=y_test,
                                         segmetric=segmetric)
                        print("Test IoU: {}  Cross_Entropy: {}".format(
                            segmetric.mean_IU(), valid_cross_entropy_value))

                        valid_IoU_summary = sess.run(
                            valid_IoU_summary_op,
                            feed_dict={valid_IoU: iou_temp})
                        train_summary_writer.add_summary(
                            valid_IoU_summary, global_step_value)
                        valid_cross_entropy_summary = sess.run(
                            valid_cross_entropy_summary_op,
                            feed_dict={
                                valid_cross_entropy: valid_cross_entropy_value
                            })
                        train_summary_writer.add_summary(
                            valid_cross_entropy_summary, global_step_value)

                        train_image_summary = sess.run(
                            train_image_summary_op,
                            feed_dict={
                                train_images:
                                image_summary(X_batch,
                                              y_batch,
                                              pred_train > 0.5,
                                              IMG_MEAN,
                                              num_classes=flags.num_classes)
                            })
                        train_summary_writer.add_summary(
                            train_image_summary, global_step_value)
                        valid_image_summary = sess.run(
                            valid_image_summary_op,
                            feed_dict={
                                valid_images:
                                image_summary(X_test,
                                              y_test,
                                              pred_valid > 0.5,
                                              IMG_MEAN,
                                              num_classes=flags.num_classes)
                            })
                        train_summary_writer.add_summary(
                            valid_image_summary, global_step_value)
                    # total_iou += step_iou * X_test.shape[0]
                    #
                    # test_summary_writer.add_summary(step_summary, (epoch + 1) * (step + 1))

                saver.save(sess,
                           "{}/model.ckpt".format(flags.ckdir),
                           global_step=global_step)

        finally:
            coord.request_stop()
            coord.join(threads)
            saver.save(sess,
                       "{}/model.ckpt".format(flags.ckdir),
                       global_step=global_step)
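
Example #2 above can optimize either sigmoid cross entropy or the negated IOU score (loss = -IOU_op). As an illustration of that second option, here is a minimal sketch of a differentiable soft-IoU loss in the same TF1 style; soft_iou_loss is a hypothetical helper, not the project's IOU_ op, and it assumes y_true and y_pred_prob are float tensors of identical shape.

import tensorflow as tf

def soft_iou_loss(y_true, y_pred_prob, eps=1e-7):
    # Soft (differentiable) IoU computed from probabilities, negated so that
    # minimizing the loss increases the IoU; hypothetical, for illustration.
    intersection = tf.reduce_sum(y_true * y_pred_prob)
    union = tf.reduce_sum(y_true) + tf.reduce_sum(y_pred_prob) - intersection
    return -(intersection + eps) / (union + eps)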
Example #3
def main():
    # get arguments
    args = get_arguments()

    # setup used GPU
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = args.GPU
    """Create the model and start the evaluation process."""

    # data reader.

    # input image
    input_img = tf.placeholder(tf.float32,
                               shape=[None, IMAGE_SIZE, IMAGE_SIZE, 3],
                               name="input_image")
    # img = tf.image.decode_jpeg(tf.read_file(args.img_path), channels=3)
    # Convert RGB to BGR.
    img_r, img_g, img_b = tf.split(axis=3,
                                   num_or_size_splits=3,
                                   value=input_img)
    img = tf.cast(tf.concat(axis=3, values=[img_b, img_g, img_r]),
                  dtype=tf.float32)
    # Extract mean.
    img -= IMG_MEAN

    img_upscale = tf.image.resize_bilinear(
        img, [IMAGE_SIZE * args.up_scale, IMAGE_SIZE * args.up_scale])

    # Create network.
    net = DeepLabResNetModel({'data': img},
                             is_training=False,
                             num_classes=args.num_classes)

    # Which variables to load.
    restore_var = tf.global_variables()

    # Predictions.
    res5c_relu = net.layers['res5c_relu']
    fc1_voc12_c0 = net.layers['fc1_voc12_c0']
    fc1_voc12_c1 = net.layers['fc1_voc12_c1']
    fc1_voc12_c2 = net.layers['fc1_voc12_c2']
    fc1_voc12_c3 = net.layers['fc1_voc12_c3']

    raw_output = net.layers['fc1_voc12']

    raw_output_up = tf.image.resize_bilinear(raw_output, tf.shape(img)[1:3, ])
    # raw_output_up_argmax = tf.argmax(raw_output_up, dimension=3)
    # pred = tf.expand_dims(raw_output_up_argmax, dim=3)
    pmap = tf.nn.softmax(raw_output_up, name="probability_map")

    # Set up TF session and initialize variables.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    init = tf.global_variables_initializer()

    sess.run(init)

    # Load weights.
    loader = tf.train.Saver(var_list=restore_var)

    if os.path.isdir(args.restore_from):
        # search checkpoint at given path
        ckpt = tf.train.get_checkpoint_state(args.restore_from)
        if ckpt and ckpt.model_checkpoint_path:
            # load checkpoint file
            load(loader, sess, ckpt.model_checkpoint_path)
            print("Model restored from {}".format(ckpt.model_checkpoint_path))
        else:
            print("No model found at{}".format(args.restore_from))
    elif os.path.isfile(args.restore_from):
        # load checkpoint file
        load(loader, sess, args.restore_from)
    else:
        print("No model found at{}".format(args.restore_from))
    '''Perform validation on large images.'''
    # preds, scoremap, pmap, cnn_out, fc0, fc1, fc2, fc3 = sess.run([pred, raw_output, raw_output_up, res5c_relu, fc1_voc12_c0, fc1_voc12_c1, fc1_voc12_c2, fc1_voc12_c3], feed_dict={input_img})

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    # gaussian weight kernel
    gfilter = gauss2D(shape=[IMAGE_SIZE, IMAGE_SIZE],
                      sigma=(IMAGE_SIZE - 1) / 4)

    seg_metric = SegMetric(1)

    for valid_file in valid_list:
        print("Validate image {}".format(valid_file[0:-2]))
        valid_image = misc.imread(
            os.path.join(args.img_path, valid_file.format('.png')))
        valid_truth = (misc.imread(
            os.path.join(args.img_path, valid_file.format('_truth.png'))) /
                       255).astype(np.uint8)
        image_shape = valid_truth.shape

        valid_patches = patchify(valid_image, IMAGE_SIZE, valid_stride)
        """divided patches into smaller batch for validation"""
        pred_pmap = valid_in_batch(valid_patches,
                                   sess,
                                   pmap,
                                   input_img,
                                   step=valid_batch_size)

        # pred_pmap = np.ones(valid_patches.shape[0:-1])

        print("Stiching patches")
        pred_pmap_weighted = pred_pmap * gfilter[None, :, :]
        pred_pmap_weighted_large = unpatchify(pred_pmap_weighted, image_shape,
                                              valid_stride)
        gauss_mask_large = unpatchify(
            np.ones(pred_pmap.shape) * gfilter[None, :, :], image_shape,
            valid_stride)
        pred_pmap_weighted_large_normalized = np.nan_to_num(
            pred_pmap_weighted_large / gauss_mask_large)
        pred_binary = (pred_pmap_weighted_large_normalized > 0.5).astype(
            np.uint8)

        # mean IoU
        seg_metric.add_image_pair(pred_binary, valid_truth)
        print("mean_IU: {:.4f}".format(mean_IU(pred_binary, valid_truth)))

        # print("Save validation prediction")
        misc.imsave(
            os.path.join(args.save_dir,
                         '{}_valid_pred.png'.format(valid_file[0:-2])),
            pred_binary)
        misc.imsave(
            os.path.join(args.save_dir,
                         '{}_valid_pred_255.png'.format(valid_file[0:-2])),
            pred_binary * 255)
        misc.toimage(pred_pmap_weighted_large_normalized.astype(np.float32),
                     high=1.0,
                     low=0.0,
                     cmin=0.0,
                     cmax=1.0,
                     mode='F').save(
                         os.path.join(
                             args.save_dir,
                             '{}_valid_pmap.tif'.format(valid_file[0:-2])))

        # # Plot PR curve
        # precision, recall, thresholds = precision_recall_curve(valid_truth.flatten(), pred_pmap_weighted_large_normalized.flatten(), 1)
        # plt.figure()
        # plt.plot(recall, precision, lw=2, color='navy',
        #          label='Precision-Recall curve')
        # plt.xlabel('Recall')
        # plt.ylabel('Precision')
        # plt.ylim([0.0, 1.05])
        # plt.xlim([0.0, 1.0])
        # plt.title('Precision-Recall')
        # # plt.legend(loc="lower left")
        # plt.savefig(os.path.join(args.save_dir, '{}_PR_curve.png'.format(valid_file[0:-2])))

    # msk = decode_labels(preds, num_classes=args.num_classes)
    # im = Image.fromarray(msk[0])

    # im.save(args.save_dir + 'pred.png')

    print("Overal mean IoU: {:.4f}".format(seg_metric.mean_IU()))
    print('The output file has been saved to {}'.format(args.save_dir))
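
Example #3 (and the later examples) stitch overlapping patch predictions back into a full image by weighting each patch with a Gaussian kernel and dividing by the stitched sum of weights. The 1-D NumPy sketch below illustrates that weighted overlap-and-normalize idea under simplified assumptions; blend_patches and its Hanning window are hypothetical stand-ins, not the repo's gauss2D/patchify/unpatchify helpers.

import numpy as np

def blend_patches(patch_preds, length, size, stride):
    # patch_preds[k] is the prediction for the patch starting at k * stride.
    weight = np.hanning(size) + 1e-3      # stand-in for the Gaussian kernel
    acc = np.zeros(length)                # weighted sum of patch predictions
    norm = np.zeros(length)               # sum of weights at each position
    for k, start in enumerate(range(0, length - size + 1, stride)):
        acc[start:start + size] += patch_preds[k] * weight
        norm[start:start + size] += weight
    return np.nan_to_num(acc / norm)      # normalize, as in the examples above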
Example #4
def main(flags):
    IMG_MEAN = np.zeros(3)

    # parameters of building data set
    citylist = ['Norfolk', 'Arlington', 'Atlanta', 'Austin', 'Seekonk', 'NewHaven']
    image_mean_list = {'Norfolk': [127.07435926, 129.40160709, 128.28713284],
                       'Arlington': [88.30304996, 94.97338776, 93.21268212],
                       'Atlanta': [101.997014375, 108.42171833, 110.044871],
                       'Austin': [97.0896012682, 102.94697026, 100.7540157],
                       'Seekonk': [86.67800904, 93.31221168, 92.1328146],
                       'NewHaven': [106.7092798, 111.4314,
                                    110.74903832]}  # BGR mean for the training data for each city

    num_samples = {'Norfolk': 3,
                   'Arlington': 3,
                   'Atlanta': 3,
                   'Austin': 3,
                   'Seekonk': 3,
                   'NewHaven': 2}  # number of samples for each city

    # set evaluation data
    if flags.training_data == 'SP':
        IMG_MEAN = np.array((121.68045527, 132.14961763, 129.30317439),
                            dtype=np.float32)  # mean of solar panel data in BGR order
        valid_list = ['11ska625680{}', '11ska610860{}', '11ska445890{}', '11ska520695{}', '11ska355800{}',
                      '11ska370755{}',
                      '11ska385710{}', '11ska550770{}', '11ska505740{}', '11ska385800{}', '11ska655770{}',
                      '11ska385770{}',
                      '11ska610740{}', '11ska550830{}', '11ska625830{}', '11ska535740{}', '11ska520815{}',
                      '11ska595650{}',
                      '11ska475665{}', '11ska520845{}']

    elif flags.training_data in citylist:
        IMG_MEAN = image_mean_list[flags.training_data]  # mean of building data in BGR order
        valid_list = ["{}_{:0>2}{{}}".format(flags.testing_data, i) for i in
                      range(1, num_samples[flags.testing_data] + 1)]


    elif 'all_but' in flags.training_data:
        except_city_name = flags.training_data.split('_')[2]
        for cityname in citylist:
            if cityname != except_city_name and cityname != 'Seekonk':
                IMG_MEAN = IMG_MEAN + np.array(image_mean_list[cityname])
        IMG_MEAN = IMG_MEAN / 4
        valid_list = ["{}_{:0>2}{{}}".format(flags.testing_data, i) for i in
                      range(1, num_samples[flags.testing_data] + 1)]

    elif flags.training_data == 'all':
        for cityname in citylist:
            if cityname != 'Seekonk':
                IMG_MEAN = IMG_MEAN + np.array(image_mean_list[cityname])
        IMG_MEAN = IMG_MEAN / 5
        valid_list = ["{}_{:0>2}{{}}".format(flags.testing_data, i) for i in
                      range(1, num_samples[flags.testing_data] + 1)]

    else:
        print("Wrong data option: {}".format(flags.data_option))

    IMG_MEAN = [IMG_MEAN[2], IMG_MEAN[1], IMG_MEAN[0]]  # convert to RGB order

    # setup used GPU
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = flags.GPU

    # presetting
    tf.set_random_seed(1234)

    # input image batch with zero mean
    image_batch = tf.placeholder(tf.float32, shape=[None, 128, 128, 3], name="image_batch")
    # Convert RGB to BGR.
    img_r, img_g, img_b = tf.split(axis=3, num_or_size_splits=3, value=image_batch)
    img_bgr = tf.cast(tf.concat(axis=3, values=[img_b, img_g, img_r]), dtype=tf.float32)

    prediction_batch = tf.placeholder(tf.float32, shape=[None, 128, 128, 1], name="prediction_batch")

    pred_raw = make_unet(img_bgr, training=False)
    pred = tf.nn.sigmoid(pred_raw)
    tf.add_to_collection("inputs", image_batch)
    tf.add_to_collection("outputs", pred)

    # Set up TF session and initialize variables.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    with tf.Session(config=config) as sess:

        init = tf.global_variables_initializer()
        sess.run(init)

        saver = tf.train.Saver(var_list=tf.global_variables())

        if os.path.exists(flags.restore_from) and tf.train.get_checkpoint_state(flags.restore_from):
            latest_check_point = tf.train.latest_checkpoint(flags.restore_from)
            print("Loading model: {}".format(latest_check_point))
            saver.restore(sess, latest_check_point)
        else:
            print("No model found at{}".format(flags.restore_from))
            sys.exit()

        if not os.path.exists(flags.save_dir):
            os.makedirs(flags.save_dir)

        # testing model on large images by running over patches
        gfilter = gauss2D(shape=[flags.image_size, flags.image_size], sigma=(flags.image_size - 1) / 4)
        seg_metric = SegMetric(1)
        valid_stride = int(flags.image_size / 2)

        print("Testing {} model on {} data {}{}".format(flags.training_data, flags.testing_data, "with transferring mean" if flags.is_mean_transfer else "", "with CORAL domain adaption" if flags.is_CORAL else ""))

        file = open(os.path.join(flags.restore_from, 'test_log.csv'), 'a')
        file.write("\nTest Model: {}\ntransfer_mean:{} CORAL domain adaption:{}\n".format(latest_check_point, flags.is_mean_transfer, flags.is_CORAL))

        for valid_file in valid_list:
            print("Testing image {}".format(valid_file[0:-2]))
            if flags.testing_data == 'SP':
                valid_image = misc.imread(os.path.join(flags.img_path, valid_file.format('.png')))
            else:
                valid_image = misc.imread(os.path.join(flags.img_path, flags.testing_data, valid_file.format('_RGB_1feet.png')))

            valid_truth = (misc.imread(os.path.join(flags.img_path, flags.testing_data, valid_file.format('_truth_1feet.png'))) / 255).astype(np.uint8)

            if flags.is_CORAL:
                train_image = misc.imread(os.path.join(flags.img_path, flags.training_data, '{}_01_RGB.png'.format(flags.training_data)))
                valid_image = image_adapt(valid_image, train_image, 1)

            valid_image = misc.imresize(valid_image, flags.resolution_ratio, interp='bilinear')
            valid_truth = misc.imresize(valid_truth, flags.resolution_ratio, interp='nearest')

            if flags.is_mean_transfer:
                IMG_MEAN = np.mean(valid_image, axis=(0, 1))  # Image mean of testing data

            valid_image = valid_image - IMG_MEAN  # subtract mean from image

            image_shape = valid_truth.shape

            valid_patches = patchify(valid_image, flags.image_size, valid_stride)
            """divided patches into smaller batch for evaluation"""
            pred_pmap = valid_in_batch(valid_patches, sess, pred, image_batch, step=flags.batch_size)

            # pred_pmap = np.ones(valid_patches.shape[0:-1])

            print("Stiching patches")
            pred_pmap_weighted = pred_pmap * gfilter[None, :, :]
            pred_pmap_weighted_large = unpatchify(pred_pmap_weighted, image_shape, valid_stride)
            gauss_mask_large = unpatchify(np.ones(pred_pmap.shape) * gfilter[None, :, :], image_shape, valid_stride)
            pred_pmap_weighted_large_normalized = np.nan_to_num(pred_pmap_weighted_large / gauss_mask_large)
            pred_binary = (pred_pmap_weighted_large_normalized > flags.pred_threshold).astype(np.uint8)

            # mean IoU
            seg_metric.add_image_pair(pred_binary, valid_truth)
            message_temp = "{}, {:.4f}".format(valid_file[0:-2], mean_IU(pred_binary, valid_truth))
            print(message_temp)
            file.write(message_temp + '\n')

            print("Saving evaluation prediction")

            # misc.imsave(os.path.join(flags.save_dir, '{}_{}pred.png'.format(valid_file[0:-2], 'NT_' if not flags.is_mean_transfer else '')), pred_binary)
            misc.imsave(os.path.join(flags.save_dir, '{}_pred_threshold_{}{}{}.png'.format(valid_file[0:-2],flags.pred_threshold, '_TM' if flags.is_mean_transfer else '', '_CORAL' if flags.is_CORAL else '')), pred_binary * 255)
            misc.toimage(pred_pmap_weighted_large_normalized.astype(np.float32), high=1.0, low=0.0, cmin=0.0, cmax=1.0, mode='F').save(os.path.join(flags.save_dir, '{}_pred_pmap{}{}.tif'.format(valid_file[0:-2], '_TM' if flags.is_mean_transfer else '', '_CORAL' if flags.is_CORAL else '')))



        message_overall = "Overall, {:.4f}".format(seg_metric.mean_IU())
        print(message_overall)
        file.write(message_overall + '\n')
        file.close()
        print('The output file has been saved to {}'.format(flags.save_dir))

        sess.close()
Example #5
def main(argv=None):
    # GPU setup
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = FLAGS.GPU

    # log folder
    log_dir = os.path.join(FLAGS.logs_dir,
                           "logs_batch{}/".format(FLAGS.batch_size))
    log_image_dir = log_dir + "images"

    print("Setting up dataset reader")
    # image_options = {'resize': True, 'resize_size': IMAGE_SIZE}
    # if (FLAGS.mode == 'train') | (FLAGS.mode == 'visualize'):
    if FLAGS.mode == 'train':
        train_dataset_reader = BatchReader(FLAGS.mat_dir, TRAINING_FILES)
        train_dataset_reader.shuffle_images()
    # if (FLAGS.mode == 'train')| (FLAGS.mode == 'visualize')| (FLAGS.mode == 'test'):
    validation_dataset_reader = BatchReader(FLAGS.mat_dir, VALIDATION_FILES)

    print("Setting up Graph")
    keep_probability = tf.placeholder(tf.float32, name="keep_probabilty")
    image = tf.placeholder(tf.float32,
                           shape=[None, IMAGE_SIZE, IMAGE_SIZE, 3],
                           name="input_image")
    annotation = tf.placeholder(tf.int32,
                                shape=[None, IMAGE_SIZE, IMAGE_SIZE, 1],
                                name="annotation")

    global_step = tf.Variable(0,
                              dtype=tf.int32,
                              trainable=False,
                              name='global_step')

    pred_annotation, logits, pmap = inference(
        image, keep_probability, validation_dataset_reader.mean_image)
    tf.summary.image("input_image", image, max_outputs=2)
    tf.summary.image("ground_truth",
                     tf.cast(annotation, tf.uint8),
                     max_outputs=2)
    tf.summary.image("pred_annotation",
                     tf.cast(pred_annotation, tf.uint8),
                     max_outputs=2)
    loss = tf.reduce_mean((tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits,
        labels=tf.squeeze(annotation, squeeze_dims=[3]),
        name="entropy")))
    tf.summary.scalar("entropy", loss)

    trainable_var = tf.trainable_variables()
    if FLAGS.debug:
        for var in trainable_var:
            utils.add_to_regularization_and_summary(var)
    train_op = train(loss, trainable_var, global_step)

    print("Setting up summary op...")
    summary_op = tf.summary.merge_all()

    # print("Setting up image reader...")
    # train_records, valid_records = scene_parsing.read_dataset(FLAGS.data_dir)
    # print(len(train_records))
    # print(len(valid_records))

    # GPU configuration to prevent TF from taking all GPU memory
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    sess = tf.Session(config=config)

    print("Setting up Saver...")
    saver = tf.train.Saver(max_to_keep=3)
    summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

    sess.run(tf.global_variables_initializer())
    ckpt = tf.train.get_checkpoint_state(log_dir)
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
        print("Model restored from {}".format(log_dir))

    if FLAGS.mode == "train":
        print("Start training with batch size {}, learning rate{}".format(
            FLAGS.batch_size, FLAGS.learning_rate))
        itr = int(0)
        while train_dataset_reader.epochs_completed < FLAGS.epochs:
            train_images, train_annotations = train_dataset_reader.next_batch(
                FLAGS.batch_size)
            feed_dict = {
                image: train_images,
                annotation: train_annotations,
                keep_probability: 0.85
            }

            _, itr = sess.run([train_op, global_step], feed_dict=feed_dict)

            if itr % 20 == 0:
                train_loss, summary_str = sess.run([loss, summary_op],
                                                   feed_dict=feed_dict)
                print("EPOCH %d Step: %d, Train_loss:%g" %
                      (train_dataset_reader.epochs_completed, itr, train_loss))
                summary_writer.add_summary(summary_str, itr)

            if itr % 200 == 0:
                valid_images, valid_annotations = validation_dataset_reader.get_random_batch(
                    FLAGS.batch_size)
                valid_loss = sess.run(loss,
                                      feed_dict={
                                          image: valid_images,
                                          annotation: valid_annotations,
                                          keep_probability: 1.0
                                      })
                print("%s ---> Validation_loss: %g" %
                      (datetime.datetime.now(), valid_loss))
                saver.save(sess, log_dir + "model.ckpt", itr)
                print("Checkpoint saved")
        saver.save(sess, log_dir + "model.ckpt", itr)

    elif FLAGS.mode == "visualize":
        print("visualize {} images".format(FLAGS.visualize_size))
        valid_images, valid_annotations = validation_dataset_reader.get_random_batch(
            FLAGS.visualize_size)
        # train_images, train_annotations = train_dataset_reader.next_batch(FLAGS.visualize_size)

        pred_valid, probability_map = sess.run([pred_annotation, pmap],
                                               feed_dict={
                                                   image: valid_images,
                                                   keep_probability: 1.0
                                               })
        # pred_train = sess.run(pred_annotation, feed_dict={image: train_images, annotation: train_annotations, keep_probability: 1.0})

        valid_annotations = np.squeeze(valid_annotations, axis=3)
        # train_annotations = np.squeeze(train_annotations, axis=3)
        pred_valid = np.squeeze(pred_valid, axis=3)
        # pred_train = np.squeeze(pred_train, axis=3)
        probability_map = probability_map[:, :, :, 1]
        """save images"""
        log_image_dir = log_dir + "images"
        if not os.path.isdir(log_image_dir):
            os.makedirs(log_image_dir)

        for itr in range(FLAGS.visualize_size):
            utils.save_image(valid_images[itr].astype(np.uint8),
                             log_image_dir,
                             name="inp_test" + str(itr))
            utils.save_image(valid_annotations[itr].astype(np.uint8),
                             log_image_dir,
                             name="gt_test" + str(itr))
            utils.save_image(pred_valid[itr].astype(np.uint8),
                             log_image_dir,
                             name="pred_test" + str(itr))
            utils.save_image(probability_map[itr].astype(np.double),
                             log_image_dir,
                             name="pmap_test" + str(itr))
            print("Saved image: %d" % itr)

            # for itr in range(FLAGS.visualize_size):
            #     utils.save_image(train_images[itr].astype(np.uint8), imageDir, name="inp_train" + str(itr))
            #     utils.save_image(train_annotations[itr].astype(np.uint8), imageDir, name="gt_train" + str(itr))
            #     utils.save_image(pred_train[itr].astype(np.uint8), imageDir, name="pred_train" + str(itr))
            #     print("Saved image: %d" % itr)

    elif FLAGS.mode == "validate":
        valid_stride = 20
        valid_batch_size = 1000
        valid_list = [
            '11ska595800{}', '11ska460755{}', '11ska580860{}', '11ska565845{}'
        ]

        gfilter = gauss2D(shape=[IMAGE_SIZE, IMAGE_SIZE],
                          sigma=(IMAGE_SIZE - 1) / 4)

        for valid_file in valid_list:
            print("Validate image {}".format(valid_file[0:-2]))
            valid_image = misc.imread(
                os.path.join(imgDir, valid_file.format('.png')))
            valid_annotation = misc.imread(
                os.path.join(imgDir, valid_file.format('_truth.png')))
            image_shape = valid_annotation.shape

            valid_patches = patchify(valid_image, IMAGE_SIZE, valid_stride)
            """divided patches into smaller batch for validation"""
            pred_pmap = test_in_batch(valid_patches,
                                      sess,
                                      pmap,
                                      image,
                                      keep_probability,
                                      step=valid_batch_size)

            pred_pmap_weighted = pred_pmap * gfilter[None, :, :]
            pred_weighted_rec = unpatchify(pred_pmap_weighted, image_shape,
                                           valid_stride)
            gauss_mask_rec = unpatchify(
                np.ones(pred_pmap.shape) * gfilter[None, :, :], image_shape,
                valid_stride)
            pred_weighted_normalized = np.nan_to_num(pred_weighted_rec /
                                                     gauss_mask_rec)

            print("Save validation prediction")
            utils.save_image(pred_weighted_normalized.astype(np.float32),
                             log_image_dir,
                             name="{}_valid_pred".format(valid_file[0:-2]))
            misc.toimage(pred_weighted_normalized.astype(np.float32),
                         high=1.0,
                         low=0.0,
                         cmin=0.0,
                         cmax=1.0,
                         mode='F').save(
                             os.path.join(
                                 log_image_dir,
                                 '{}_valid_pmap.tif'.format(valid_file[0:-2])))

            print("mean_IU: {}".format(
                mean_IU((pred_weighted_normalized > 0.5).astype(int),
                        valid_annotation)))

    elif FLAGS.mode == "test":
        """
        test on validation images one by one
        """
        # for itr in xrange(len(valid_records)):
        #     valid_images, valid_annotations = validation_dataset_reader.next_batch(1)
        #     pred_valid = sess.run(pred_annotation, feed_dict={image: valid_images, annotation: valid_annotations,
        #                                             keep_probability: 1.0})
        #     valid_annotations = np.squeeze(valid_annotations)
        #     pred_valid = np.squeeze(pred_valid)
        #     if itr == 0:
        #         valid_annotations_concatenate = valid_annotations
        #         pred_concatenate = pred_valid
        #     else:
        #         valid_annotations_concatenate = np.concatenate((valid_annotations_concatenate,valid_annotations),axis = 0)
        #         pred_concatenate = np.concatenate((pred_concatenate, pred_valid),axis = 0)
        #     print('test %d th image of %d validation image' %(itr+1,len(valid_records)))
        #
        # print(pixel_accuracy(valid_annotations_concatenate, pred_concatenate))
        # print(mean_accuracy(valid_annotations_concatenate, pred_concatenate))
        # print(mean_IU(valid_annotations_concatenate, pred_concatenate))
        # print(frequency_weighted_IU(valid_annotations_concatenate, pred_concatenate))
        """save prediction results on validation images"""
        print("testing validation images")
        seg_metric = SegMetric(NUM_OF_CLASSES)
        for itr in range(validation_dataset_reader.images.shape[0]):
            print("testing {}th image".format(itr + 1))
            valid_images, valid_annotations = validation_dataset_reader.next_batch(
                1)
            pred_valid = sess.run(pred_annotation,
                                  feed_dict={
                                      image: valid_images,
                                      annotation: valid_annotations,
                                      keep_probability: 1.0
                                  })
            valid_annotations = np.squeeze(valid_annotations)
            pred_valid = np.squeeze(pred_valid)
            seg_metric.add_image_pair(pred_valid, valid_annotations)
            if (itr + 1) % 1000 == 0:
                print("itr{}:".format(itr + 1))
                seg_metric.pixel_accuracy()
                seg_metric.mean_accuracy()
                seg_metric.mean_IU()
                seg_metric.frequency_weighted_IU()

        print("Final Accuracy:")
        seg_metric.pixel_accuracy()
        seg_metric.mean_accuracy()
        seg_metric.mean_IU()
        seg_metric.frequency_weighted_IU()
Example #6
truthDir = os.path.expanduser('~/Documents/data/igarssTrainingAndTestingData/imageFiles')
validDir = [os.path.join(os.path.realpath('../FCN/logs'), 'logs_batch%d' % batch_size, 'images')]
learning_rates = [1e-5, 1e-4, 1e-3]
weight_decay = 0.0005
batch_sizes = [20]
for learning_rate in learning_rates:
    for batch_size in batch_sizes:
        if learning_rate == 1e-4:
            batch_size = 40
        validDir.append(
            os.path.join(os.path.realpath('../../tensorflow-deeplab-resnet/snapshots/train_with_pretrained_model'),
                         'sc_all_batchsize{}_learningRate_{:.0e}_weight_decay_{}/'.format(batch_size, learning_rate,
                                                                                          weight_decay), 'images'))
pred_binary = []

seg_metric = SegMetric(1)

meanIoUs = []
mean_performance = []
variance_performance = []
patch_size = 20
for i in range(0, 1):
    image = misc.imread(os.path.join(truthDir, imageFileName[i].format('.png')))
    truth = (misc.imread(os.path.join(truthDir, imageFileName[i].format('_truth.png'))) / 255).astype(np.uint8)
    region_center = np.nan_to_num(center_of_mass(truth, label(truth)[0], range(label(truth)[1])))
    meanIoU = np.zeros([len(region_center) - 1, len(validDir)])

    for l in range(len(validDir)):
        pred_binary.append(
            ((misc.imread(os.path.join(validDir[l], imageFileName[i].format('_valid_pmap.tif')))) > 0.5).astype(int))
Example #7
truthDir = os.path.expanduser(
    '~/Documents/data/igarssTrainingAndTestingData/imageFiles')

learning_rates = [1e-3]
weight_decays = [0.0005]
batch_sizes = [20]
for weight_decay in weight_decays:
    for batch_size in batch_sizes:
        for learning_rate in learning_rates:
            if (batch_size != 20) | (learning_rate != 1e-5) | (weight_decay !=
                                                               0.0005):
                validDir = os.path.join(
                    os.path.realpath('./snapshots'),
                    'with_pretrained_model/sc_batchsize{}_learningRate_{:.0e}_weight_decay_{}/'
                    .format(batch_size, learning_rate, weight_decay), 'images')
                seg_metric = SegMetric(1)
                for i in range(0, num_val):
                    image = misc.imread(
                        os.path.join(truthDir,
                                     imageFileName[i].format('.png')))
                    truth = (misc.imread(
                        os.path.join(truthDir,
                                     imageFileName[i].format('_truth.png'))) /
                             255).astype(np.uint8)
                    valid_pmap = misc.imread(
                        os.path.join(
                            validDir,
                            imageFileName[i].format('_valid_float.tif')))
                    pred_binary = (valid_pmap > 0.5).astype(np.uint8)
                    seg_metric.add_image_pair(pred_binary, truth)
                print('sc_batchsize{}_learningRate_{:.0e}_weight_decay_{}'.
Example #8
from sklearn.metrics import f1_score
from eval_segm import *
from seg_metric import SegMetric

imageFileName = [
    '11ska595800{}', '11ska460755{}', '11ska580860{}', '11ska565845{}'
]
num_val = len(imageFileName)
truthDir = os.path.expanduser(
    '~/Documents/data/igarssTrainingAndTestingData/imageFiles')

batch_size = 128
validDir = os.path.join(os.path.realpath('../FCN/logs'),
                        'logs_batch%d' % batch_size, 'images')

seg_metric = SegMetric(1)
# plt.figure("PR Curve")
for i in range(0, num_val):
    image = misc.imread(os.path.join(truthDir,
                                     imageFileName[i].format('.png')))
    truth = (misc.imread(
        os.path.join(truthDir, imageFileName[i].format('_truth.png'))) /
             255).astype(np.uint8)
    valid_pmap = misc.imread(
        os.path.join(validDir, imageFileName[i].format('_valid_pmap.tif')))
    pred_binary = (valid_pmap > 0.5).astype(np.uint8)

    # mean IoU
    seg_metric.add_image_pair(pred_binary, truth)
    print("Image {}: {:.4f}".format(imageFileName[i][0:-2],
                                    mean_IU(pred_binary, truth)))
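
Example #8 reports a per-image mean_IU and an aggregate score via SegMetric. For binary masks, mean IU is simply the average of the per-class IoU over the classes present; the sketch below shows that computation, where binary_mean_iu is a hypothetical helper and not the eval_segm/SegMetric implementation used above.

import numpy as np

def binary_mean_iu(pred, truth):
    # Average IoU over background (0) and foreground (1), skipping a class
    # that is absent from both masks; illustration only, not eval_segm.mean_IU.
    ious = []
    for cls in (0, 1):
        intersection = np.logical_and(pred == cls, truth == cls).sum()
        union = np.logical_or(pred == cls, truth == cls).sum()
        if union > 0:
            ious.append(intersection / union)
    return float(np.mean(ious))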
Example #9
def main():
    # get arguments
    args = get_arguments()

    IMG_MEAN = np.zeros(3)
    valid_list=[]

    # parameters of building data set
    citylist = ['Norfolk', 'Arlington', 'Atlanta', 'Austin', 'Seekonk', 'NewHaven']
    image_mean_list = {'Norfolk': [127.07435926, 129.40160709, 128.28713284],
                       'Arlington': [88.30304996, 94.97338776, 93.21268212],
                       'Atlanta': [101.997014375, 108.42171833, 110.044871],
                       'Austin': [97.0896012682, 102.94697026, 100.7540157],
                       'Seekonk': [86.67800904, 93.31221168, 92.1328146],
                       'NewHaven': [106.7092798, 111.4314, 110.74903832]} # BGR mean for the training data for each city
    num_samples = {'Norfolk': 3,
                      'Arlington': 3,
                      'Atlanta': 3,
                      'Austin': 3,
                      'Seekonk': 3,
                      'NewHaven': 2} # number of samples for each city
    # set evaluation data
    if args.evaluation_data == 'SP':
        IMG_MEAN = np.array((121.68045527, 132.14961763, 129.30317439),
                        dtype=np.float32)  # mean of solar panel data in BGR order
        IMG_MEAN = [IMG_MEAN[2], IMG_MEAN[1], IMG_MEAN[0]] # convert to RGB order

        # valid_list = [ '11ska505665{}', '11ska580710{}', '11ska475635{}', '11ska475875{}', '11ska565905{}', '11ska490860{}', '11ska325740{}', '11ska460725{}', '11ska490605{}', '11ska430815{}', '11ska400740{}', '11ska580875{}', '11ska655725{}', '11ska595860{}', '11ska460890{}', '11ska655695{}', '11ska640605{}', '11ska580605{}', '11ska595665{}', '11ska505755{}', '11ska475650{}', '11ska595755{}', '11ska625755{}', '11ska490740{}', '11ska565755{}', '11ska520725{}', '11ska595785{}', '11ska580755{}', '11ska445785{}', '11ska625710{}', '11ska520830{}', '11ska640800{}', '11ska535785{}', '11ska430905{}', '11ska505695{}', '11ska565770{}']
        # valid_list = ['11ska580860{}', '11ska565845{}']
        valid_list = ['11ska625680{}', '11ska610860{}', '11ska445890{}', '11ska520695{}', '11ska355800{}', '11ska370755{}',
                  '11ska385710{}', '11ska550770{}', '11ska505740{}', '11ska385800{}', '11ska655770{}', '11ska385770{}',
                  '11ska610740{}', '11ska550830{}', '11ska625830{}', '11ska535740{}', '11ska520815{}', '11ska595650{}',
                  '11ska475665{}', '11ska520845{}']

    elif args.training_data in citylist:
        IMG_MEAN = image_mean_list[args.training_data]
        IMG_MEAN = [IMG_MEAN[2], IMG_MEAN[1], IMG_MEAN[0]] # convert to RGB order
        valid_list = ["{}_{:0>2}{{}}".format(args.evaluation_data, i) for i in range(1,num_samples[args.evaluation_data]+1)]

    else:
        print("Wrong data option: {}".format(args.training_data))

    # set image mean

    # setup used GPU
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = args.GPU

    """Create the model and start the evaluation process."""

    # data reader.

    # input image
    input_img = tf.placeholder(tf.float32, shape=[None, args.image_size, args.image_size, 3], name="input_image")
    # img = tf.image.decode_jpeg(tf.read_file(args.img_path), channels=3)
    # Convert RGB to BGR.
    img_r, img_g, img_b = tf.split(axis=3, num_or_size_splits=3, value=input_img)
    img = tf.cast(tf.concat(axis=3, values=[img_b, img_g, img_r]), dtype=tf.float32)
    # Extract mean.

    # Create network.
    net = DeepLabResNetModel({'data': img}, is_training=False, num_classes=args.num_classes)

    # Which variables to load.
    restore_var = tf.global_variables()

    # Predictions.
    res5c_relu = net.layers['res5c_relu']
    fc1_voc12_c0 = net.layers['fc1_voc12_c0']
    fc1_voc12_c1 = net.layers['fc1_voc12_c1']
    fc1_voc12_c2 = net.layers['fc1_voc12_c2']
    fc1_voc12_c3 = net.layers['fc1_voc12_c3']

    raw_output = net.layers['fc1_voc12']

    raw_output_up = tf.image.resize_bilinear(raw_output, tf.shape(img)[1:3, ])
    # raw_output_up_argmax = tf.argmax(raw_output_up, dimension=3)
    # pred = tf.expand_dims(raw_output_up_argmax, dim=3)
    pmap = tf.nn.softmax(raw_output_up, name="probability_map")

    # Set up TF session and initialize variables. 
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    init = tf.global_variables_initializer()

    sess.run(init)

    # Load weights.
    loader = tf.train.Saver(var_list=restore_var)

    if os.path.isdir(args.restore_from):
        # search checkpoint at given path
        ckpt = tf.train.get_checkpoint_state(args.restore_from)
        if ckpt and ckpt.model_checkpoint_path:
            # load checkpoint file
            load(loader, sess, ckpt.model_checkpoint_path)
            file = open(os.path.join(args.restore_from, 'test.csv'), 'a')
            file.write("\nTest Model: {}\ntransfer_mean:{}\n".format(ckpt.model_checkpoint_path, args.is_mean_transfer))
        else:
            print("No model found at{}".format(args.restore_from))
            sys.exit()
    elif os.path.isfile(args.restore_from):
        # load checkpoint file
        load(loader, sess, args.restore_from)
        file = open(os.path.join(args.restore_from, 'test.csv'), 'a')
        file.write("\nTest Model: {}\ntransfer_mean:{}\n".format(args.restore_from, args.is_mean_transfer))
    else:
        print("No model found at{}".format(args.restore_from))
        sys.exit()

    '''Perform evaluation on large images.'''
    # preds, scoremap, pmap, cnn_out, fc0, fc1, fc2, fc3 = sess.run([pred, raw_output, raw_output_up, res5c_relu, fc1_voc12_c0, fc1_voc12_c1, fc1_voc12_c2, fc1_voc12_c3], feed_dict={input_img})

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    # gaussian weight kernel
    gfilter = gauss2D(shape=[args.image_size, args.image_size], sigma=(args.image_size - 1) / 4)

    seg_metric = SegMetric(1)

    valid_stride = int(args.image_size/2)


    for valid_file in valid_list:
        print("evaluate image {}".format(valid_file[0:-2]))
        if args.evaluation_data == 'SP':
            valid_image = misc.imread(os.path.join(args.img_path, valid_file.format('.png')))
        else:
            valid_image = misc.imread(os.path.join(args.img_path, valid_file.format('_RGB.png')))
        valid_truth = (misc.imread(os.path.join(args.img_path, valid_file.format('_truth.png')))/255).astype(np.uint8)

        valid_image = misc.imresize(valid_image, args.resolution_ratio, interp='bilinear')
        valid_truth = misc.imresize(valid_truth, args.resolution_ratio, interp='nearest')

        if args.is_mean_transfer:
            IMG_MEAN = np.mean(valid_image, axis=(0,1)) # Image mean of testing data

        valid_image = valid_image - IMG_MEAN  # subtract mean from image

        image_shape = valid_truth.shape

        valid_patches = patchify(valid_image, args.image_size, valid_stride)
        """divided patches into smaller batch for evaluation"""
        pred_pmap = valid_in_batch(valid_patches, sess, pmap, input_img, step=args.batch_size)

        # pred_pmap = np.ones(valid_patches.shape[0:-1])

        print("Stiching patches")
        pred_pmap_weighted = pred_pmap * gfilter[None, :, :]
        pred_pmap_weighted_large = unpatchify(pred_pmap_weighted, image_shape, valid_stride)
        gauss_mask_large = unpatchify(np.ones(pred_pmap.shape) * gfilter[None, :, :], image_shape, valid_stride)
        pred_pmap_weighted_large_normalized = np.nan_to_num(pred_pmap_weighted_large / gauss_mask_large)
        pred_binary = (pred_pmap_weighted_large_normalized > 0.5).astype(np.uint8)
        
        print("Save evaluation prediction")

        misc.imsave(os.path.join(args.save_dir, '{}_valid_pred.png'.format(valid_file[0:-2])), pred_binary)
        misc.imsave(os.path.join(args.save_dir, '{}_valid_pred_255.png'.format(valid_file[0:-2])), pred_binary*255)
        misc.toimage(pred_pmap_weighted_large_normalized.astype(np.float32), high=1.0, low=0.0, cmin=0.0, cmax=1.0, mode='F').save(
            os.path.join(args.save_dir, '{}_valid_pmap.tif'.format(valid_file[0:-2])))

        # mean IoU
        seg_metric.add_image_pair(pred_binary, valid_truth)
        message_temp = "{}, {:.4f}".format(valid_file[0:-2], mean_IU(pred_binary, valid_truth))
        print(message_temp)
        file.write(message_temp+'\n')
        # # Plot PR curve
        # precision, recall, thresholds = precision_recall_curve(valid_truth.flatten(), pred_pmap_weighted_large_normalized.flatten(), 1)
        # plt.figure()
        # plt.plot(recall, precision, lw=2, color='navy',
        #          label='Precision-Recall curve')
        # plt.xlabel('Recall')
        # plt.ylabel('Precision')
        # plt.ylim([0.0, 1.05])
        # plt.xlim([0.0, 1.0])
        # plt.title('Precision-Recall')
        # # plt.legend(loc="lower left")
        # plt.savefig(os.path.join(args.save_dir, '{}_PR_curve.png'.format(valid_file[0:-2])))

    # msk = decode_labels(preds, num_classes=args.num_classes)
    # im = Image.fromarray(msk[0])

    # im.save(args.save_dir + 'pred.png')
    message_overall = "Overall, {:.4f}".format(seg_metric.mean_IU())
    print(message_overall)
    file.write(message_overall + '\n')
    file.close()
    print('The output file has been saved to {}'.format(args.save_dir))

    sess.close()