Example #1
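mc_test restores the gray-to-NIR generator from the ckpt_gd_g2r_pro_all checkpoint at the given epoch, runs one 40-image batch from data_all/country/<file_name>.tfrecord through it, and writes the gray input, the real NIR image, and the generated NIR image as PNGs under out_g2r_pro_all/epoch_<number>.

Both examples rely on the module-level context of the original script; a rough sketch of what that context presumably looks like (model and read are assumed to be local modules of the repository, and the remaining names are constants defined elsewhere in the file):

    import os
    import time
    import logging
    from datetime import datetime

    import cv2
    import numpy as np
    import tensorflow as tf

    import model  # local module: generator/discriminator definitions
    import read   # local module: read.batch_inputs for tfrecord batching

    # image_height, image_width, BATCH_SIZE_matching, learning_rate, EPOCH and
    # initLogging are also assumed to be defined at module level.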
def mc_test(file_name, number):

    #    data = np.load('/home/ws/文档/wrj/mapping_data_all/map_rgb_'+file_name+'.npz')
    #    train_matching_y = data['arr_0'][:]
    #    numbers_train = train_matching_y.shape[0]  # total number of training samples
    BATCH_SIZE_map = 40

    graph = tf.Graph()
    with graph.as_default():
        inputs_p1 = tf.placeholder(
            tf.float32, [BATCH_SIZE_map, image_height, image_width, 1],
            name='inputs_gray')
        inputs_p2 = tf.placeholder(
            tf.float32, [BATCH_SIZE_map, image_height, image_width, 1],
            name='inputs_nir')

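        # gd_model_g2r is built here so that the generator/discriminator
        # variables exist in this graph; its losses are never evaluated in this
        # function, and the generator is reused below for inference.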
        gen_loss, dis_loss, _, _ = model.gd_model_g2r(inputs_p1, inputs_p2)
        inputs_p1_ = model.preprocess(inputs_p1)
        gen = model.create_generator(inputs_p1_, 1, reuse=True)
        gen = model.deprocess(gen)

        #        filename = '/home/ws/文档/wrj/mapping_data_gray/map_data/map_'+file_name+'.tfrecord'
        filename = '/home/ws/文档/wrj/data_all/country/' + file_name + '.tfrecord'
        filename_queue = tf.train.string_input_producer([filename],
                                                        num_epochs=1,
                                                        shuffle=False)
        img_batch, label_batch = read.batch_inputs(filename_queue,
                                                   train=False,
                                                   batch_size=BATCH_SIZE_map)

        saver = tf.train.Saver()
        with tf.Session() as sess:
            sess.run(tf.local_variables_initializer())
            #            saver.restore(sess, tf.train.latest_checkpoint('ckpt_gd_g2r_pro_all'))
            saver.restore(sess,
                          'ckpt_gd_g2r_pro_all/model.ckpt-' + str(number))
            #            num = 0

            try:
                coord = tf.train.Coordinator()
                threads = tf.train.start_queue_runners(sess=sess, coord=coord)
                step_test = 0
                while not coord.should_stop():
                    if step_test < 1:

                        step_test = step_test + 1
                        batch, l_batch = sess.run([img_batch, label_batch])
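                        # Each record is split at column 64: the left half is
                        # the gray image, the right half the real NIR image.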
                        x_batch = batch[:, :, :64, np.newaxis]
                        y_batch = batch[:, :, 64:, np.newaxis]
                        feed_dict = {inputs_p1: x_batch}
                        gen_out = sess.run(gen, feed_dict=feed_dict)

                        gen_out_dir = 'out_g2r_pro_all/epoch_' + str(number)
                        try:
                            os.makedirs(gen_out_dir)
                        except os.error:
                            pass
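                        # Stack the real NIR target above the generated NIR
                        # along the height axis and rescale to [0, 255].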
                        show_images = np.concatenate((y_batch, gen_out),
                                                     axis=1)
                        show_images = show_images * 255
                        for i in range(BATCH_SIZE_map):
                            cv2.imwrite(
                                gen_out_dir +
                                '/{}.png'.format(file_name + '_' + str(i + 1) +
                                                 "_gray"),
                                np.squeeze(x_batch[i, :, :, :]) * 255)
                            cv2.imwrite(
                                gen_out_dir +
                                '/{}.png'.format(file_name + '_' + str(i + 1) +
                                                 "_nir"),
                                np.squeeze(show_images[i, :64, :, :]))
                            cv2.imwrite(
                                gen_out_dir +
                                '/{}.png'.format(file_name + '_' + str(i + 1) +
                                                 "_fnir"),
                                np.squeeze(show_images[i, 64:, :, :]))
                    else:
                        break
            except KeyboardInterrupt:
                print('INTERRUPTED')
                coord.request_stop()
            except Exception as e:
                coord.request_stop(e)
            finally:
                coord.request_stop()
                coord.join(threads)
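A minimal call sketch for the function above (the tfrecord base name and epoch index are hypothetical; they have to match an existing data_all/country/<name>.tfrecord and a saved ckpt_gd_g2r_pro_all checkpoint):

    # hypothetical values, shown only to illustrate the call
    mc_test('country_test', 50)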
Example #2
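gd_train trains the generator/discriminator pair that maps gray images to NIR images on map_<f_name>.tfrecord, each network with its own Adam optimizer, after restoring a pre-trained nir_encode sub-network; it records the losses every 10 steps, logs them every 100 steps, and, once past epoch 40, checkpoints and runs mc_test every 10 epochs.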
def gd_train(f_name):
    checkpoint_dir = 'ckpt_gd_g2r_pro_' + f_name
    checkpoint_dir_g = 'ckpt_g_g2r_pro_' + f_name
    output_dir = 'out_g2r_pro_' + f_name
    checkpoint_dir_en = 'ckpt_en_r2r_' + f_name

    checkpoint_file = os.path.join(checkpoint_dir, 'model.ckpt')
    checkpoint_file_g = os.path.join(checkpoint_dir_g, 'model.ckpt')
    initLogging('record_gd_g2r_pro_' + f_name + '.log')
    #    if FLAGS.load_model is not None:
    #        checkpoints_dir = 'checkpoints/' + FLAGS.load_model
    #    else:
    current_time = datetime.now().strftime('%Y%m%d-%H%M')
    try:
        os.makedirs(checkpoint_dir_g)
    except os.error:
        pass
    try:
        os.makedirs(checkpoint_dir)
    except os.error:
        pass
    try:
        os.makedirs(output_dir)
    except os.error:
        pass
#
    data1 = np.load('/home/ws/文档/wrj/mapping_data_gray/map_' + f_name + '.npz')
    train_matching_y = data1['arr_0'][:, np.newaxis]
    numbers_train = train_matching_y.shape[0]  #训练集总数
    epoch_steps = np.int(
        numbers_train / BATCH_SIZE_matching) + 1  # 一个epoch有多少个steps

    all_loss = np.array([])
    graph = tf.Graph()
    with graph.as_default():
        inputs_p1 = tf.placeholder(
            tf.float32, [BATCH_SIZE_matching, image_height, image_width, 1],
            name='inputs_gray')
        inputs_p2 = tf.placeholder(
            tf.float32, [BATCH_SIZE_matching, image_height, image_width, 1],
            name='inputs_nir')
        #        test_opt = tf.placeholder(tf.float32, [None, image_height, image_width, 1], name='test_opt')

        # Train G
        #    Map the input images from (0, 1) to (-1, 1)
        inputs_p1_ = model.preprocess(inputs_p1)
        inputs_p2_ = model.preprocess(inputs_p2)
        gen_loss, dis_loss, _, _ = model.gd_model_g2r(inputs_p1_, inputs_p2_)
        discrim_tvars = [
            var for var in tf.trainable_variables()
            if var.name.startswith("discriminator")
        ]
        d_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(
            dis_loss, var_list=discrim_tvars)
        gen_tvars = [
            var for var in tf.trainable_variables()
            if var.name.startswith("generator")
        ]
        g_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(
            gen_loss, var_list=gen_tvars)
        nirencode_tvars = [
            var for var in tf.trainable_variables()
            if var.name.startswith("nir_encode")
        ]
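        # D and G get separate Adam optimizers restricted to their own variable
        # lists; the pre-trained nir_encode variables are restored below but
        # never updated by either optimizer.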

        #        with tf.control_dependencies([g_train_opt, d_train_opt]):
        #             gd_train_opt = tf.no_op(name='optimizers')

        #
        filename = '/home/ws/文档/wrj/mapping_data_gray/map_data/map_' + f_name + '.tfrecord'
        filename_queue = tf.train.string_input_producer([filename],
                                                        num_epochs=EPOCH,
                                                        shuffle=True)
        img_batch, label_batch = read.batch_inputs(
            filename_queue, train=True, batch_size=BATCH_SIZE_matching)

        saver = tf.train.Saver(max_to_keep=20)
        saver_g = tf.train.Saver(var_list=gen_tvars, max_to_keep=20)
        saver_en = tf.train.Saver(var_list=nirencode_tvars)
        init = tf.global_variables_initializer()
        #
        gpu_options = tf.GPUOptions(allow_growth=True)
        sess_config = tf.ConfigProto(gpu_options=gpu_options)

        with tf.Session(config=sess_config) as sess:
            sess.run(tf.local_variables_initializer())
            sess.run(init)
            saver_en.restore(sess,
                             tf.train.latest_checkpoint(checkpoint_dir_en))
            try:
                coord = tf.train.Coordinator()
                threads = tf.train.start_queue_runners(sess=sess, coord=coord)
                step = 0
                while not coord.should_stop():
                    start_time = time.time()
                    step = step + 1
                    batch, l_batch = sess.run([img_batch, label_batch])
                    x_batch = batch[:, :, :64, np.newaxis]
                    y_batch = batch[:, :, 64:, np.newaxis]
                    feed_dict = {inputs_p1: x_batch, inputs_p2: y_batch}
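                    # Update D and G on the same batch in a single run call and
                    # fetch both losses.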
                    _, _, g_loss, d_loss = sess.run(
                        [d_train_opt, g_train_opt, gen_loss, dis_loss],
                        feed_dict=feed_dict)

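                    # Append (step, g_loss, d_loss) to the loss history every
                    # 10 steps; the first entry replaces the empty placeholder.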
                    if step % 10 == 0:
                        loss_write = np.array([[step, g_loss, d_loss]])
                        if step == 10:
                            all_loss = loss_write
                        else:
                            all_loss = np.concatenate((all_loss, loss_write))

                    if step % 100 == 0:
                        duration = time.time() - start_time
                        logging.info(
                            '>> Step %d run_train: g_loss = %.2f  d_loss = %.2f (%.3f sec)'
                            % (step, g_loss, d_loss, duration))

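                    # Every 10 epochs, once past epoch 40, checkpoint the full
                    # model and the generator alone, then run mc_test on the
                    # freshly saved epoch.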
                    if (step % epoch_steps == 0) and ((step / epoch_steps) % 10
                                                      == 0):
                        current_epoch = int(step / epoch_steps)

                        if current_epoch >= 40:
                            logging.info('>> %s Saving in %s' %
                                         (datetime.now(), checkpoint_dir))
                            saver.save(sess,
                                       checkpoint_file,
                                       global_step=current_epoch)
                            saver_g.save(sess,
                                         checkpoint_file_g,
                                         global_step=current_epoch)
                            # mc_test_all(current_epoch)
                            mc_test(f_name, current_epoch)

            except KeyboardInterrupt:
                print('INTERRUPTED')
                coord.request_stop()
            except Exception as e:
                coord.request_stop(e)

            finally:
                saver.save(sess, checkpoint_file, global_step=step)
                saver_g.save(sess, checkpoint_file_g, global_step=step)
                np.save(
                    os.path.join(checkpoint_dir, 'ckpt_map_gd_pro_' + f_name),
                    all_loss)
                np.save(
                    os.path.join(checkpoint_dir_g,
                                 'ckpt_map_gd_pro_' + f_name), all_loss)
                print('Model saved in file :%s' % checkpoint_dir)
                # When done, ask the threads to stop.
                coord.request_stop()
                coord.join(threads)
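# mc_draw restores the latest ckpt_gd_r2r_all checkpoint, feeds one 40-image
# batch of real NIR images through the generator, saves the real and
# reconstructed NIR images, and tiles 128 feature maps from one intermediate
# generator layer into a grid PNG per image under feature_map/.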
def mc_draw(file_name, number):

    #    data = np.load('/media/ws/98E62262E622413C/zzzz_wrj/wrj/mapping_data_all/map_rgb_'+file_name+'.npz')
    #    train_matching_y = data['arr_0'][:]
    #    numbers_train = train_matching_y.shape[0]  # total number of training samples
    BATCH_SIZE_map = 40

    graph = tf.Graph()
    with graph.as_default():
        inputs_p1 = tf.placeholder(
            tf.float32, [BATCH_SIZE_map, image_height, image_width, 1],
            name='inputs_gray')
        inputs_p2 = tf.placeholder(
            tf.float32, [BATCH_SIZE_map, image_height, image_width, 1],
            name='inputs_nir')

        inputs_p2_ = model.preprocess(inputs_p2)
        #        gen_loss, all_layers = model.gd_model_r2r(inputs_p2_)
        gen, all_layers = model.create_generator(inputs_p2_, 1, reuse=True)
        gen = model.deprocess(gen)

        filename = '/home/ws/文档/wrj/data_all/country/' + file_name + '.tfrecord'
        filename_queue = tf.train.string_input_producer([filename],
                                                        num_epochs=1,
                                                        shuffle=False)
        img_batch, label_batch = read.batch_inputs(filename_queue,
                                                   train=False,
                                                   batch_size=BATCH_SIZE_map)

        saver = tf.train.Saver()
        with tf.Session() as sess:
            sess.run(tf.local_variables_initializer())
            saver.restore(sess, tf.train.latest_checkpoint('ckpt_gd_r2r_all'))
            #            saver.restore(sess, 'ckpt_tensor/model.ckpt-4000')
            #            num = 0

            try:
                coord = tf.train.Coordinator()
                threads = tf.train.start_queue_runners(sess=sess, coord=coord)
                step_test = 0
                while not coord.should_stop():
                    if step_test < 1:
                        step_test = step_test + 1
                        batch, l_batch = sess.run([img_batch, label_batch])
                        y_batch = batch[:, :, 64:, np.newaxis]
                        feed_dict = {inputs_p2: y_batch}
                        gen_out, layers_out = sess.run([gen, all_layers],
                                                       feed_dict=feed_dict)

                        gen_out_dir = 'out_r2r_country/epoch_' + str(number)
                        try:
                            os.makedirs('feature_map')
                        except os.error:
                            pass
                        try:
                            os.makedirs(gen_out_dir)
                        except os.error:
                            pass
                        show_images = np.concatenate((y_batch, gen_out),
                                                     axis=1)
                        show_images = show_images * 255
                        for i in range(BATCH_SIZE_map):
                            #                            cv2.imwrite(gen_out_dir+'/{}.png'.format(file_name+'_'+str(i+1)+"_opt"), np.squeeze(x_batch[i,:,:,:]*255))
                            cv2.imwrite(
                                gen_out_dir +
                                '/{}.png'.format(file_name + '_' + str(i + 1) +
                                                 "_nir"),
                                np.squeeze(show_images[i, :64, :, :]))
                            cv2.imwrite(
                                gen_out_dir +
                                '/{}.png'.format(file_name + '_' + str(i + 1) +
                                                 "_fnir"),
                                np.squeeze(show_images[i, 64:, :, :]))

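                            # Dump the feature maps of layers_out[2] for this
                            # image: min-max normalize them to [0, 1] and tile
                            # the 128 maps (each assumed to be 8x8, i.e. s x s)
                            # into an 11x12 grid, leaving the 4 unused cells
                            # white.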
                            layers = np.squeeze(layers_out[2][i, :, :, :])
                            m = 0
                            n = 0
                            s = 8
                            max_ = np.max(layers)
                            min_ = np.min(layers)
                            layers_ = (layers - min_) / (max_ - min_)
                            features = np.ones(
                                (11 * s, 12 * s), np.uint8) * 255

                            for j in range(128):
                                feature = np.squeeze(
                                    np.uint8(layers_[:, :, j] * 255))
                                features[m * s:m * s + s,
                                         n * s:n * s + s] = feature
                                n += 1
                                if (n >= 12):
                                    n = 0
                                    m += 1

                            cv2.imwrite(
                                'feature_map/{}.png'.format(file_name + '_' +
                                                            str(i + 1)),
                                features)
                    else:
                        break
            except KeyboardInterrupt:
                print('INTERRUPTED')
                coord.request_stop()
            except Exception as e:
                coord.request_stop(e)
            finally:
                coord.request_stop()
                coord.join(threads)