Esempio n. 1
0
def run_eval_mend():
    """Evaluate the mend model once on a local test image and cloud mask.

    Reads ``road-car.png`` and ``road-cloud0.png``, reflect-pads both by 32
    pixels on each spatial side, keeps the mask's first channel and
    binarises it at ``threshold`` (values become 0 or 255), then writes the
    evaluation grid to ``temp/eval-00000000.jpg``. The same image/mask pair
    is fed to both the ground/mask inputs and the buffer inputs.
    """
    img = cv2.imread('road-car.png')[np.newaxis, :, :, :]
    img = np.pad(img, ((0, 0), (32, 32), (32, 32), (0, 0)), 'reflect')
    # Alternative mask source: cv2.imread('road-label.png')
    mask = cv2.imread('road-cloud0.png')[np.newaxis, :, :, :]
    mask = np.pad(mask, ((0, 0), (32, 32), (32, 32), (0, 0)), 'reflect')[:, :, :, 0:1]

    # Binarise: everything below the threshold is background (0), the rest 255.
    threshold = 244
    mask[mask < threshold] = 0
    mask[mask >= threshold] = 255

    eval_list = [img, mask, img, mask]

    # Model variant selection. The 'mod_mend_dila' variant used to be imported
    # here and then immediately shadowed by the 'nres' one; swap these two
    # lines for ``from mod_mend_dila import init_train`` /
    # ``Config('mod_mend_dila')`` to evaluate it instead.
    from mod_mend_nres import init_train
    C = Config('mod_mend_nres')
    inp_ground, inp_mask01, inp_grdbuf, inp_mskbuf, _fetch, eval_fetch = init_train()

    C.size = img.shape[1]
    sess = mod_util.get_sess(C)
    _, logger, _ = mod_util.get_saver_logger(C, sess)
    try:
        print("||Training Check")
        eval_feed_dict = {inp_ground: eval_list[0],
                          inp_mask01: eval_list[1],
                          inp_grdbuf: eval_list[2],
                          inp_mskbuf: eval_list[3], }
        img_util.get_eval_img(mat_list=sess.run(eval_fetch, eval_feed_dict), channel=3,
                              img_path="%s/eval-%08d.jpg" % ('temp', 0))
    finally:
        # Previously neither the log file handle nor the session was released.
        logger.close()
        sess.close()
Esempio n. 2
0
def cloud_detect(aerials):
    """Run the cloud-detection U-Net over a stack of aerial tiles.

    Builds the ``unet`` graph from ``mod_cloud_detect`` for tiles of side
    ``aerials.shape[1]``, sets up the session and saver through ``mod_util``
    and evaluates the tiles one at a time.

    Args:
        aerials: stack of tiles indexed as ``aerials[i]``, each fed to a
            ``tf.uint8`` placeholder of shape ``(size, size, 3)``. The tiles
            are assumed square, because ``aerials.shape[1]`` is used for both
            spatial dims. TODO: confirm with callers.

    Returns:
        ``(grounds, mask01s)``: uint8 arrays holding the network's output
        channels 0:3 and 3:4, multiplied by 255, clipped to [0, 255] and
        concatenated along axis 0.
    """
    import tensorflow as tf
    from configure import Config
    from utils import mod_util
    from mod_cloud_detect import unet
    C = Config('mod_cloud_detect')

    size = aerials.shape[1]

    unet_name, unet_dim = 'unet', 24
    inp_aerial = tf.placeholder(tf.uint8, [None, size, size, 3])
    ten_aerial = tf.to_float(inp_aerial) / 255
    eva_grdcld = unet(ten_aerial, unet_dim, 4, unet_name, reuse=False, training=False)
    eva_ground = eva_grdcld[:, :, :, 0:3]
    eva_mask01 = eva_grdcld[:, :, :, 3:4]
    # The fetch list is identical for every tile, so build it once.
    eval_fetch = [eva_ground, eva_mask01]

    sess = mod_util.get_sess(C)
    _, logger, _ = mod_util.get_saver_logger(C, sess)
    try:
        print("||Training Check")

        grounds = list()
        mask01s = list()
        for i, aerial in enumerate(aerials):
            eval_feed_dict = {inp_aerial: aerial[np.newaxis, :, :, :]}
            mat_list = sess.run(eval_fetch, eval_feed_dict)

            grounds.append(np.clip(mat_list[0] * 255, 0, 255).astype(np.uint8))
            mask01s.append(np.clip(mat_list[1] * 255, 0, 255).astype(np.uint8))

            if rd.rand() < 0.01:  # sparse, randomly sampled progress report
                print('Eval:', i)
    finally:
        # Previously neither the log file handle nor the session was released.
        logger.close()
        sess.close()

    grounds = np.concatenate(grounds, axis=0)
    mask01s = np.concatenate(mask01s, axis=0)
    return grounds, mask01s
Esempio n. 3
0
def run_eval_haze():
    """Evaluate the haze U-Net once on ``road-thin.png``.

    The image is reflect-padded by 32 pixels on each spatial side and paired
    with an all-zero single-channel mask of the same spatial size. The
    evaluation grid is written to ``temp/eval-00000000.jpg``.
    """
    img = cv2.imread('road-thin.png')[np.newaxis, :, :, :]
    img = np.pad(img, ((0, 0), (32, 32), (32, 32), (0, 0)), 'reflect')
    eval_list = [img, np.zeros_like(img[:, :, :, 0:1])]
    from mod_haze_unet import init_train
    inp_ground, inp_mask01, _train_fetch, eval_fetch = init_train()

    C = Config('mod_haze_unet')
    C.size = img.shape[1]
    sess = mod_util.get_sess(C)
    _, logger, _ = mod_util.get_saver_logger(C, sess)
    try:
        print("||Training Check")
        eval_feed_dict = {inp_ground: eval_list[0],
                          inp_mask01: eval_list[1], }
        img_util.get_eval_img(mat_list=sess.run(eval_fetch, eval_feed_dict), channel=3,
                              img_path="%s/eval-%08d.jpg" % ('temp', 0))
    finally:
        # Previously neither the log file handle nor the session was released.
        logger.close()
        sess.close()
Esempio n. 4
0
def cloud_removal(aerials, label1s):
    """Fill masked regions of aerial tiles using the ``auto_encoder`` model.

    For each ``(aerial, label1)`` pair, channel 0 of ``label1`` is scaled from
    uint8 to [0, 1] and used as the mask. The masked pixels of the aerial are
    blanked (``ragged``), and the auto-encoder predicts a 3-channel patch from
    ``ragged - mask``. The output keeps the unmasked pixels and fills the
    masked ones from the patch.

    Args:
        aerials: stack of ``(size, size, 3)`` tiles fed to a ``tf.uint8``
            placeholder. ``size`` is taken from ``aerials.shape[1]``, so the
            tiles are assumed square. TODO: confirm.
        label1s: matching stack of masks; only channel 0 is used.

    Returns:
        ``(grounds, patch3s)``: uint8 arrays of the composited output and the
        raw patch, each multiplied by 255, clipped to [0, 255] and
        concatenated along axis 0.
    """
    import tensorflow as tf
    from configure import Config
    from utils import mod_util
    from mod_cloud_remove_rec import auto_encoder
    C = Config('mod_cloud_remove_rec')

    size = aerials.shape[1]

    gene_name, gene_dim = 'gene', 32
    inp_ground = tf.placeholder(tf.uint8, [None, size, size, 3])
    ten_ground = tf.to_float(inp_ground) / 255
    inp_mask01 = tf.placeholder(tf.uint8, [None, size, size, 1])
    ten_mask01 = tf.to_float(inp_mask01) / 255

    ten_mask10 = (1.0 - ten_mask01)
    ten_ragged = ten_ground * ten_mask10  # masked pixels zeroed out

    ten_patch3 = auto_encoder(ten_ragged - ten_mask01,
                              gene_dim, 3, gene_name,
                              reuse=False, training=False)
    # Keep original pixels outside the mask, take the patch inside it.
    out_ground = ten_ragged + ten_patch3 * ten_mask01
    # The fetch list is identical for every tile, so build it once.
    eval_fetch = [ten_patch3, out_ground]

    sess = mod_util.get_sess(C)
    _, logger, _ = mod_util.get_saver_logger(C, sess)
    try:
        print("||Training Check")

        patch3s = list()
        grounds = list()

        for i, (aerial, label1) in enumerate(zip(aerials, label1s)):
            aerial = aerial[np.newaxis, :, :, :]
            label1 = label1[np.newaxis, :, :, 0:1]

            eval_feed_dict = {inp_ground: aerial,
                              inp_mask01: label1, }
            mat_list = sess.run(eval_fetch, eval_feed_dict)

            patch3s.append(np.clip(mat_list[0] * 255, 0, 255).astype(np.uint8))
            grounds.append(np.clip(mat_list[1] * 255, 0, 255).astype(np.uint8))
            if i % 64 == 0:
                print('Eval:', i)
    finally:
        # Previously neither the log file handle nor the session was released.
        logger.close()
        sess.close()

    grounds = np.concatenate(grounds, axis=0)
    patch3s = np.concatenate(patch3s, axis=0)
    return grounds, patch3s
def _pop_save_mark(model_dir):
    """Delete a save-request marker in ``model_dir``; return True if one existed.

    Checks ``SAVE.MARK`` (the name the other trainer uses, next to
    ``TRAINING.MARK``) and the legacy ``Save.mark``. The names are checked one
    after another. On a case-insensitive filesystem both resolve to the same
    file, and the second ``os.path.exists`` then sees it already removed.
    """
    found = False
    for mark_name in ('SAVE.MARK', 'Save.mark'):
        mark_path = os.path.join(model_dir, mark_name)
        if os.path.exists(mark_path):
            os.remove(mark_path)
            found = True
    return found


def process_train(feed_queue, buff_queue):
    """Train the generator/discriminator model built by ``init_train``.

    ``feed_queue`` first yields an evaluation list (entries 0-3 feed the
    ground, mask, ground-buffer and mask-buffer placeholders). Every later
    item has the form ``(idx, ground, mask01, grdbuf, mskbuf)``. After each
    step ``(idx, buf_ground, buf_mask01)`` from the session run is put on
    ``buff_queue``.

    Each epoch runs ``C.batch_size`` steps and logs the mean
    ``(loss_gene, loss_disc)`` of those steps. Training stops after
    ``C.train_epoch`` epochs, on ``KeyboardInterrupt``, or when a save marker
    file appears in ``C.model_dir`` (see ``_pop_save_mark``). After that the
    model is saved to ``C.model_path``, a final eval image is written, the
    logger and session are closed and the ``TRAINING.MARK`` directory is
    removed.
    """
    print("||Training Initialize")
    inp_ground, inp_mask01, inp_grdbuf, inp_mskbuf, fetch, eval_fetch = init_train()
    # fetch[3] initially holds (generator optimizer, discriminator optimizer);
    # it is overwritten below to choose which optimizers run on each step.
    optz_gene = fetch[3][0]
    optz_disc = fetch[3][1]
    loss_gene = loss_disc = 0

    sess = mod_util.get_sess(C)
    saver, logger, pre_epoch = mod_util.get_saver_logger(C, sess)
    print("||Training Check")
    eval_list = feed_queue.get()
    eval_feed_dict = {
        inp_ground: eval_list[0],
        inp_mask01: eval_list[1],
        inp_grdbuf: eval_list[2],
        inp_mskbuf: eval_list[3],
    }
    sess.run(eval_fetch, eval_feed_dict)

    print("||Training Start")
    start_time = show_time = eval_time = time.time()
    try:
        for epoch in range(C.train_epoch):
            batch_losses = list()
            for _ in range(C.batch_size):
                batch_data = feed_queue.get()

                idx = batch_data[0]
                batch_dict = {
                    inp_ground: batch_data[1],
                    inp_mask01: batch_data[2],
                    inp_grdbuf: batch_data[3],
                    inp_mskbuf: batch_data[4],
                }

                # Update only the generator while the discriminator loss is
                # more than 8x smaller than the generator loss; otherwise
                # update both.
                if loss_disc * 8 < loss_gene:
                    fetch[3] = optz_gene
                else:
                    fetch[3] = (optz_gene, optz_disc)

                buf_ground, buf_mask01, (loss_gene, loss_disc), _optz = sess.run(
                    fetch, batch_dict)
                batch_losses.append((loss_gene, loss_disc))
                buff_queue.put((idx, buf_ground, buf_mask01))

            loss_average = np.mean(batch_losses, axis=0)
            logger.write('%e %e\n' % (loss_average[0], loss_average[1]))

            if time.time() - show_time > C.show_gap:
                show_time = time.time()
                remain_epoch = C.train_epoch - epoch
                remain_time = (show_time -
                               start_time) * remain_epoch / (epoch + 1)
                print(end="\n|  %3d s |%3d epoch | Loss: %9.3e %9.3e" %
                      (remain_time, remain_epoch, loss_average[0],
                       loss_average[1]))
            if time.time() - eval_time > C.eval_gap:
                eval_time = time.time()
                logger.close()  # flush the log to disk, then reopen for append
                logger = open(C.model_log, 'a')

                # Rotate the eval mask 90 degrees over axes 1 and 2
                # (presumably H and W) before each periodic evaluation.
                eval_feed_dict[inp_mask01] = np.rot90(
                    eval_feed_dict[inp_mask01], axes=(1, 2))
                img_util.get_eval_img(mat_list=sess.run(
                    eval_fetch, eval_feed_dict),
                                      channel=3,
                                      img_path="%s/eval-%08d.jpg" %
                                      (C.model_dir, pre_epoch + epoch))
                print(end="  EVAL %d" % (pre_epoch + epoch))

            # Previously only 'Save.mark' was checked, so the 'SAVE.MARK'
            # convention used elsewhere never stopped this trainer on a
            # case-sensitive filesystem.
            if _pop_save_mark(C.model_dir):
                print("\n||Break Training and save:", process_train.__name__)
                break

    except KeyboardInterrupt:
        print("\n||Break Training and save:", process_train.__name__)
    print('\n  TimeUsed:    %d' % int(time.time() - start_time))
    saver.save(sess, C.model_path, write_meta_graph=False)
    print("  SAVE: %s" % C.model_path)
    img_util.get_eval_img(mat_list=sess.run(eval_fetch, eval_feed_dict),
                          channel=3,
                          img_path="%s/eval-%08d.jpg" % (C.model_dir, 0))

    logger.close()
    sess.close()

    os.rmdir(os.path.join(C.model_dir, 'TRAINING.MARK'))
Esempio n. 6
0
def process_train(feed_queue):
    """Train the model built by ``init_train`` on samples from ``feed_queue``.

    The first item taken from ``feed_queue`` is an evaluation sample; its
    entries 0 and 1 feed the ground and mask placeholders. Every later item is
    a training sample indexed the same way. Each epoch runs ``C.batch_size``
    steps and logs the two loss values of the last step.

    Training stops after ``C.train_epoch`` epochs, on ``KeyboardInterrupt``, or
    once a ``SAVE.MARK`` file appears in ``C.model_dir``; the marker is
    deleted. Afterwards the model is saved to ``C.model_path``, a final eval
    image is written, the logger and session are closed and the
    ``TRAINING.MARK`` directory is removed.
    """
    print("||Training Initialize")
    inp_ground, inp_mask01, train_fetch, eval_fetch = init_train()

    sess = mod_util.get_sess(C)
    saver, logger, pre_epoch = mod_util.get_saver_logger(C, sess)

    def eval_image_path(step):
        # Path of the evaluation grid written for a given step number.
        return "%s/eval-%08d.jpg" % (C.model_dir, step)

    print("||Training Check")
    first_sample = feed_queue.get()
    eval_feed_dict = {inp_ground: first_sample[0], inp_mask01: first_sample[1]}
    sess.run(eval_fetch, eval_feed_dict)  # smoke-test the eval graph

    print("||Training Start")
    start_time = show_time = eval_time = time.time()
    loss = (0, 0)
    try:
        for epoch in range(C.train_epoch):
            for _step in range(C.batch_size):
                sample = feed_queue.get()
                loss, _ = sess.run(train_fetch,
                                   {inp_ground: sample[0], inp_mask01: sample[1]})
            logger.write('%e %e\n' % (loss[0], loss[1]))

            if time.time() - show_time > C.show_gap:
                show_time = time.time()
                epochs_left = C.train_epoch - epoch
                eta = (show_time - start_time) * epochs_left / (epoch + 1)
                print(end="\n%3d s |%3d epoch | Loss: %9.3e %9.3e" %
                      (eta, epochs_left, loss[0], loss[1]))

            if time.time() - eval_time > C.eval_gap:
                eval_time = time.time()
                logger.close()  # flush the log to disk, then reopen for append
                logger = open(C.model_log, 'a')

                # Rotate the eval mask 90 degrees over axes 1 and 2
                # (presumably H and W) before each periodic evaluation.
                eval_feed_dict[inp_mask01] = np.rot90(eval_feed_dict[inp_mask01],
                                                      axes=(1, 2))
                eval_mats = sess.run(eval_fetch, eval_feed_dict)
                img_util.get_eval_img(mat_list=eval_mats, channel=3,
                                      img_path=eval_image_path(pre_epoch + epoch))
                print(end="  EVAL %d" % (pre_epoch + epoch))

            save_mark = os.path.join(C.model_dir, 'SAVE.MARK')
            if os.path.exists(save_mark):
                os.remove(save_mark)
                print("\n||Break Training and save:", process_train.__name__)
                break

    except KeyboardInterrupt:
        print("\n||Break Training and save:", process_train.__name__)
    print('\n  TimeUsed:    %d' % int(time.time() - start_time))
    saver.save(sess, C.model_path, write_meta_graph=False)
    print("  SAVE: %s" % C.model_path)
    final_mats = sess.run(eval_fetch, eval_feed_dict)
    img_util.get_eval_img(mat_list=final_mats, channel=3,
                          img_path=eval_image_path(0))

    logger.close()
    sess.close()

    os.rmdir(os.path.join(C.model_dir, 'TRAINING.MARK'))
Esempio n. 7
0
def run_eval():
    """Run the buffered GAN mend generator on saved evaluation images.

    Reads ``eval_replace_ground.jpg`` (input image),
    ``eval_replace_out_aer.jpg`` (read but currently unused) and
    ``eval_replace_out_cld.jpg`` (its first channel is the cloud mask). It
    builds the ``mod_mend_buff`` generator and sets up the session and saver
    via ``mod_util``; this presumably restores the checkpoint in
    ``C.model_dir``, which should be confirmed. Each output (``out_ground``,
    ``ten_patch3``) is shown in an OpenCV window, with a wait of up to 4321 ms
    for a key press, and written to ``eval_gan_<name>.jpg``.

    Raises:
        FileNotFoundError: if one of the input images cannot be read.
    """
    mat_list = list()
    for name in ['ground', 'out_aer', 'out_cld']:
        path = 'eval_replace_%s.jpg' % name
        img = cv2.imread(path)
        if img is None:  # cv2.imread signals failure by returning None
            raise FileNotFoundError(path)
        mat_list.append(img[np.newaxis, :, :, :])

    import tensorflow as tf
    import mod_mend_buff as mod
    import os
    class Config(object):
        train_epoch = 2 ** 14
        train_size = int(2 ** 17 * 1.9)
        eval_size = 2 ** 3
        batch_size = 2 ** 4
        batch_epoch = train_size // batch_size

        size = int(2 ** 9)  # size = int(2 ** 7)
        replace_num = int(0.25 * batch_size)
        learning_rate = 1e-5  # 1e-4

        show_gap = 2 ** 5  # time
        eval_gap = 2 ** 9  # time
        gpu_limit = 0.48  # 0.0 ~ 1.0
        gpu_id = 1

        data_dir = '/mnt/sdb1/data_sets'
        aerial_dir = os.path.join(data_dir, 'AerialImageDataset/train')
        cloud_dir = os.path.join(data_dir, 'ftp.nnvl.noaa.gov_color_IR_2018')
        grey_dir = os.path.join(data_dir, 'CloudGreyDataset')

        def __init__(self, model_dir='mod'):
            self.model_dir = model_dir
            self.model_name = 'mod'
            self.model_path = os.path.join(self.model_dir, self.model_name)
            self.model_npz = os.path.join(self.model_dir, self.model_name + '.npz')
            self.model_log = os.path.join(self.model_dir, 'training_npy.txt')

    C = Config('mod_mend_GAN_buff')

    gene_name = 'gene'
    inp_ground = tf.placeholder(tf.uint8, [None, C.size, C.size, 3])
    inp_cloud1 = tf.placeholder(tf.uint8, [None, C.size, C.size, 1])

    flt_ground = tf.to_float(inp_ground) / 255.0
    flt_cloud1 = tf.to_float(inp_cloud1) / 255.0
    ten_repeat = tf.ones([1, 1, 1, 3])

    ten_ground = flt_ground[:C.batch_size]
    ten_cloud1 = flt_cloud1[:C.batch_size]

    ten_cloud3 = ten_cloud1 * ten_repeat  # broadcast the mask to 3 channels
    ten_mask10 = (1.0 - ten_cloud3)
    ten_ragged = ten_ground * ten_mask10

    ten_patch3 = mod.generator(tf.concat((ten_ragged, ten_cloud3), axis=3),
                               32, 3, gene_name, reuse=False)
    out_ground = ten_ragged + ten_patch3 * ten_cloud3

    from utils import mod_util
    sess = mod_util.get_sess(C)
    _, logger, _ = mod_util.get_saver_logger(C, sess)
    try:
        print("||Training Check")
        eval_fetch = [out_ground, ten_patch3]
        eval_feed_dict = {inp_ground: mat_list[0],
                          inp_cloud1: mat_list[2][:, :, :, 0:1]}

        mat_list = sess.run(eval_fetch, eval_feed_dict)
        for img, name in zip(mat_list, ['out_ground', 'ten_patch3']):
            # Clip before the uint8 cast: values outside [0, 1] would
            # otherwise wrap around modulo 256.
            img = np.clip(img[0] * 255, 0, 255).astype(np.uint8)
            cv2.imshow('beta', img)
            cv2.waitKey(4321)
            print(img.shape, np.max(img))
            cv2.imwrite('eval_gan_%s.jpg' % name, img)
    finally:
        # Previously neither the log file handle nor the session was released.
        logger.close()
        sess.close()

    print(end="  EVAL")