예제 #1
0
def mc_test(file_name, number):
    """Evaluate the matching network on one dataset and record the results.

    The example count is read from ``arr_1`` of ``<file_name>.npz``, the
    network weights are restored from the latest checkpoint found in
    ``ckpt_map32_diff`` and the image pairs are streamed once from the
    matching ``.tfrecord`` file.  The accumulated accuracy and the false
    positive rate at the first ROC point whose true positive rate reaches
    0.95 are appended to ``test_map32_diff.txt``; the ROC arrays are stored
    under ``plot_curve/roc/nature_map32_diff/epoch_<number>/``.

    Args:
        file_name: Base name (no extension) of the dataset to evaluate.
        number: Epoch number, used to label the log entry and the ROC
            output directory.
    """
    # The context manager closes the archive's file handle once the label
    # array has been read.
    with np.load('/home/ws/文档/wrj/data_all/' + file_name + '.npz') as data:
        num = data['arr_1'].shape[0]  # total number of examples in the set

    graph = tf.Graph()
    with graph.as_default():
        inputs_p1 = tf.placeholder(
            tf.float32, [BATCH_SIZE_matching, image_height, image_width, 1],
            name='inputs_p1')
        inputs_p2 = tf.placeholder(
            tf.float32, [BATCH_SIZE_matching, image_height, image_width, 1],
            name='inputs_p2')
        label_m = tf.placeholder(tf.float32, [BATCH_SIZE_matching, 1],
                                 name='label_m')

        match_loss, match_output, all_features = model.match_network(
            inputs_p1, inputs_p2, label_m)
        match_out = tf.round(match_output)
        m_correct, m_numbers = model.evaluation(match_out, label_m)

        filename = '/home/ws/文档/wrj/data_all/country/' + file_name + '.tfrecord'
        filename_queue = tf.train.string_input_producer([filename],
                                                        num_epochs=1,
                                                        shuffle=False)
        img_batch, label_batch = read_data.batch_inputs(
            filename_queue, train=False, batch_size=BATCH_SIZE_matching)

        saver = tf.train.Saver()

        gpu_options = tf.GPUOptions(allow_growth=True)
        sess_config = tf.ConfigProto(gpu_options=gpu_options)

        with tf.Session(config=sess_config) as sess:
            sess.run(tf.local_variables_initializer())
            saver.restore(sess, tf.train.latest_checkpoint('ckpt_map32_diff'))

            m_count = 0  # running number of correct predictions
            # Per-batch scores and labels, joined once after the loop.
            batch_outputs = []
            batch_labels = []

            # Bound before the try so the finally clause can always use them,
            # even when starting the queue runners fails.
            coord = tf.train.Coordinator()
            threads = []
            try:
                threads = tf.train.start_queue_runners(sess=sess, coord=coord)
                step_test = 0
                try:
                    while not coord.should_stop():
                        step_test += 1
                        batch, l_batch = sess.run([img_batch, label_batch])
                        # Each record holds both patches side by side; split
                        # at column 64 and add the channel axis.
                        x_batch = batch[:, :, :64, np.newaxis]
                        y_batch = batch[:, :, 64:, np.newaxis]
                        feed_dict = {
                            inputs_p1: x_batch,
                            inputs_p2: y_batch,
                            label_m: l_batch
                        }
                        m_correct_, m_output_ = sess.run(
                            [m_correct, match_output], feed_dict=feed_dict)

                        # The first batch is always kept; later ones only
                        # when they are full-sized.
                        if (step_test == 1
                                or l_batch.size == BATCH_SIZE_matching):
                            batch_outputs.append(m_output_)
                            batch_labels.append(l_batch)

                        m_count = m_count + m_correct_.astype(int)
                        if step_test % 100 == 0:
                            # One step consumes BATCH_SIZE_matching examples,
                            # so that is the divisor for the step total.
                            print(
                                'Step %d/%d  run_test: batch_precision = %.2f  '
                                % (step_test, num // BATCH_SIZE_matching,
                                   m_correct_ / BATCH_SIZE_matching))
                except Exception as e:
                    # Exhausting the single-epoch queue ends the loop through
                    # an exception.  Anything else is handed to the
                    # coordinator as well, so the partial results below are
                    # still written before coord.join() re-raises it.
                    coord.request_stop(e)

                matching_out = (np.concatenate(batch_outputs)
                                if batch_outputs else np.array([]))
                matching_label = (np.concatenate(batch_labels)
                                  if batch_labels else np.array([]))
                m_precision = float(m_count) / num

                print(
                    'Num examples: %d  Num correct: %d  match Precision : %0.04f  '
                    % (num, m_count, m_precision))

                # `with` guarantees the log is flushed and closed even when
                # the ROC computation below raises.
                with open('test_map32_diff.txt', 'a') as save_file:
                    save_file.write(file_name + '    epoch: ' + str(number) +
                                    '\n' + 'num correct: ' + str(m_count) +
                                    '/' + str(num) +
                                    '     match precision : ' +
                                    str(m_precision))
                    save_file.write('    ')

                    # ROC curve: false and true positive rates.
                    fpr, tpr, _ = roc_curve(matching_label, matching_out)

                    # FPR at the first point whose TPR reaches 0.95.
                    fpr95 = fpr[np.where(0.95 <= tpr)[0][0]]

                    save_file.write('match fpr95 : ' + str(fpr95 * 100))
                    save_file.write('\n')
                    save_file.write('\n')

                roc_dir = 'plot_curve/roc/nature_map32_diff/epoch_' + str(
                    number)
                try:
                    os.makedirs(roc_dir)
                except OSError:
                    # An already existing directory is fine; any other
                    # failure (e.g. permissions) must surface.
                    if not os.path.isdir(roc_dir):
                        raise
                np.savez(roc_dir + '/match_map32_diff_' + file_name, fpr, tpr)
            except KeyboardInterrupt:
                print('INTERRUPTED')
                coord.request_stop()
            finally:
                # When done, ask the threads to stop.
                coord.request_stop()
                coord.join(threads)
예제 #2
0
def mc_train():
    """Train the matching network on the ``country`` dataset.

    Builds the graph, streams shuffled batches from ``country.tfrecord`` for
    ``epoch`` passes and minimises the matching loss with Adam.  The loss is
    sampled every 10 steps, a summary is written every 100 steps, and on
    every third epoch after epoch 25 a checkpoint is saved and
    ``mc_test_all`` is run.  On exit (normal, interrupted or failed) a final
    checkpoint and the sampled loss history are saved before the input
    threads are joined.
    """
    initLogging()
    try:
        os.makedirs(checkpoint_dir)
    except OSError:
        # An already existing directory is fine; any other failure
        # (e.g. permissions) must surface.
        if not os.path.isdir(checkpoint_dir):
            raise

    # The context manager closes the archive's file handle once the label
    # array has been read.
    with np.load('/home/ws/文档/wrj/data_all/country.npz') as data1:
        numbers_train = data1['arr_1'].shape[0]  # size of the training set
    # Steps per epoch.  The builtin int replaces np.int, which was removed
    # from NumPy in 1.24.
    epoch_steps = int(numbers_train / BATCH_SIZE_matching) + 1

    graph = tf.Graph()
    with graph.as_default():
        inputs_p1 = tf.placeholder(
            tf.float32, [BATCH_SIZE_matching, image_height, image_width, 1],
            name='inputs_p1')
        inputs_p2 = tf.placeholder(
            tf.float32, [BATCH_SIZE_matching, image_height, image_width, 1],
            name='inputs_p2')
        label_m = tf.placeholder(tf.float32, [BATCH_SIZE_matching, 1],
                                 name='label_m')

        # Matching network (M) and its training op.
        match_loss, match_output, all_features = model.match_network(
            inputs_p1, inputs_p2, label_m)
        match_out = tf.round(match_output)
        m_correct, m_numbers = model.evaluation(match_out, label_m)

        m_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(
            match_loss)

        filename = '/home/ws/文档/wrj/data_all/country/country.tfrecord'
        filename_queue = tf.train.string_input_producer([filename],
                                                        num_epochs=epoch,
                                                        shuffle=True)
        img_batch, label_batch = read_data.batch_inputs(
            filename_queue, train=True, batch_size=BATCH_SIZE_matching)
        # NOTE: the tag spelling is left untouched; renaming it would start
        # a new series in the summary logs.
        tf.summary.scalar('mathing_loss', match_loss)

        summary = tf.summary.merge_all()
        saver = tf.train.Saver(max_to_keep=20)
        init = tf.global_variables_initializer()

        gpu_options = tf.GPUOptions(allow_growth=True)
        sess_config = tf.ConfigProto(gpu_options=gpu_options)

        with tf.Session(config=sess_config) as sess:
            summary_writer = tf.summary.FileWriter(train_dir, sess.graph)
            sess.run(tf.local_variables_initializer())
            sess.run(init)

            # Bound before the try so the finally clause can always use them,
            # even when starting the queue runners fails.
            coord = tf.train.Coordinator()
            threads = []
            step = 0
            loss_rows = []  # [step, loss] sampled every 10 steps

            try:
                threads = tf.train.start_queue_runners(sess=sess, coord=coord)

                while not coord.should_stop():
                    start_time = time.time()
                    step += 1
                    batch, l_batch = sess.run([img_batch, label_batch])
                    # Each record holds both patches side by side; split at
                    # column 64 and add the channel axis.
                    x_batch = batch[:, :, :64, np.newaxis]
                    y_batch = batch[:, :, 64:, np.newaxis]
                    feed_dict = {
                        inputs_p1: x_batch,
                        inputs_p2: y_batch,
                        label_m: l_batch
                    }
                    _, m_loss = sess.run([m_train_opt, match_loss],
                                         feed_dict=feed_dict)

                    if step % 10 == 0:
                        loss_rows.append([step, m_loss])

                    if step % 100 == 0:
                        duration = time.time() - start_time
                        summary_str = sess.run(summary, feed_dict=feed_dict)
                        summary_writer.add_summary(summary_str, step)
                        summary_writer.flush()

                        logging.info(
                            '>> Step %d run_train: matching_loss = %.3f  (%.3f sec)',
                            step, m_loss, duration)

                    # At an epoch boundary: checkpoint and evaluate on every
                    # third epoch once past epoch 25.
                    current_epoch = step // epoch_steps
                    if (step % epoch_steps == 0 and current_epoch % 3 == 0
                            and current_epoch > 25):
                        logging.info('>> %s Saving in %s', datetime.now(),
                                     checkpoint_dir)
                        saver.save(sess,
                                   checkpoint_file,
                                   global_step=current_epoch)
                        mc_test_all(current_epoch)

            except KeyboardInterrupt:
                print('INTERRUPTED')
                coord.request_stop()
            except Exception as e:
                # Includes the end-of-input signal from the exhausted queue;
                # coord.join() below re-raises anything the coordinator does
                # not treat as a clean stop.
                coord.request_stop(e)

            finally:
                try:
                    saver.save(sess, checkpoint_file, global_step=step)
                    np.save(os.path.join(checkpoint_dir, 'ckpt_map32_diff'),
                            np.array(loss_rows))
                    print('Model saved in file :%s' % checkpoint_dir)
                finally:
                    # Always release the writer and stop the input threads,
                    # even if the final save failed.
                    summary_writer.close()
                    coord.request_stop()
                    coord.join(threads)
def mc_test(file_name, number):
    """Evaluate generator + matching network on one test set and log results.

    The first input is passed through the generator (``model_g``) and its
    de-processed output is matched against the second input.  Matching
    weights are restored from ``ckpt_map_gd/model.ckpt-27``; the generator
    variables are then restored from the latest checkpoint in
    ``checkpoint_dir_g + '_all'``.  Accuracy, confusion counts and the false
    positive rate at the first ROC point whose true positive rate reaches
    0.95 are appended to ``test_map_gd.txt``; the ROC arrays are stored
    under ``plot_curve/epoch_<number>/``.

    Args:
        file_name: Base name of the test set (files are looked up with a
            ``test_`` prefix).
        number: Epoch number, used to label the log entry and the ROC
            output directory.
    """
    # The context manager closes the archive's file handle once the label
    # array has been read.
    with np.load('/home/ws/文档/wrj/data_all_test/test_' + file_name +
                 '.npz') as data:
        num = data['arr_1'].shape[0]  # total number of examples in the set

    graph = tf.Graph()
    with graph.as_default():
        inputs_p1 = tf.placeholder(
            tf.float32, [BATCH_SIZE_matching, image_height, image_width, 1],
            name='inputs_gray')
        inputs_p2 = tf.placeholder(
            tf.float32, [BATCH_SIZE_matching, image_height, image_width, 1],
            name='inputs_nir')
        label_m = tf.placeholder(tf.float32, [BATCH_SIZE_matching, 1],
                                 name='label_m')

        # Generator applied to the first input.
        inputs_p1_ = model_g.preprocess(inputs_p1)
        g_outputs = model_g.create_generator(inputs_p1_, 1)
        g_outputs_ = model_g.deprocess(g_outputs)
        # Matching network (M) on the generated image and the second input.
        match_loss, match_output, all_features = model.match_network(
            g_outputs_, inputs_p2, label_m)
        match_out = tf.round(match_output)
        m_correct, m_numbers = model.evaluation(match_out, label_m)

        filename = '/home/ws/文档/wrj/data_all_test/test_data/test_' + file_name + '.tfrecord'
        filename_queue = tf.train.string_input_producer([filename],
                                                        num_epochs=1,
                                                        shuffle=False)
        img_batch, label_batch = read.batch_inputs(
            filename_queue, train=False, batch_size=BATCH_SIZE_matching)

        gen_tvars = [
            var for var in tf.trainable_variables()
            if var.name.startswith("generator")
        ]
        saver_g = tf.train.Saver(var_list=gen_tvars)
        saver = tf.train.Saver()
        with tf.Session() as sess:
            sess.run(tf.local_variables_initializer())
            saver.restore(sess, 'ckpt_map_gd/model.ckpt-27')
            # Restored second, so the generator variables end up with the
            # values from the generator checkpoint.
            saver_g.restore(
                sess, tf.train.latest_checkpoint(checkpoint_dir_g + '_all'))

            m_count = 0  # running number of correct predictions
            # Per-batch scores and labels, joined once after the loop.
            batch_outputs = []
            batch_labels = []

            # Bound before the try so the finally clause can always use them,
            # even when starting the queue runners fails.
            coord = tf.train.Coordinator()
            threads = []
            try:
                threads = tf.train.start_queue_runners(sess=sess, coord=coord)
                step_test = 0
                try:
                    while not coord.should_stop():
                        step_test += 1
                        batch, l_batch = sess.run([img_batch, label_batch])
                        # Each record holds both patches side by side; split
                        # at column 64 and add the channel axis.
                        x_batch = batch[:, :, :64, np.newaxis]
                        y_batch = batch[:, :, 64:, np.newaxis]
                        feed_dict = {
                            inputs_p1: x_batch,
                            inputs_p2: y_batch,
                            label_m: l_batch
                        }
                        m_correct_, m_output_ = sess.run(
                            [m_correct, match_output], feed_dict=feed_dict)

                        # The first batch is always kept; later ones only
                        # when they are full-sized.
                        if (step_test == 1
                                or l_batch.size == BATCH_SIZE_matching):
                            batch_outputs.append(m_output_)
                            batch_labels.append(l_batch)

                        m_count = m_count + m_correct_.astype(int)
                        if step_test % 100 == 0:
                            # One step consumes BATCH_SIZE_matching examples,
                            # so that is the divisor for the step total.
                            print(
                                'Step %d/%d  run_test: batch_precision = %.2f  '
                                % (step_test, num // BATCH_SIZE_matching,
                                   m_correct_ / BATCH_SIZE_matching))
                except Exception as e:
                    # Exhausting the single-epoch queue ends the loop through
                    # an exception.  Anything else is handed to the
                    # coordinator as well, so the partial results below are
                    # still written before coord.join() re-raises it.
                    coord.request_stop(e)

                matching_out = (np.concatenate(batch_outputs)
                                if batch_outputs else np.array([]))
                matching_label = (np.concatenate(batch_labels)
                                  if batch_labels else np.array([]))
                m_precision = float(m_count) / num

                print(
                    'Num examples: %d  Num correct: %d  match Precision : %0.04f  '
                    % (num, m_count, m_precision))

                # `with` guarantees the log is flushed and closed even when
                # the ROC computation below raises.
                with open('test_map_gd.txt', 'a') as save_file:
                    save_file.write(file_name + '    epoch: ' + str(number) +
                                    '\n' + 'num correct: ' + str(m_count) +
                                    '/' + str(num) +
                                    '     match precision : ' +
                                    str(m_precision))
                    save_file.write('    ')

                    # Confusion counts.  m_count is the accumulated number of
                    # correct predictions, so tn is derived as m_count - tp.
                    # NOTE(review): m_count covers every evaluated batch,
                    # while the collected outputs skip non-full batches after
                    # the first -- confirm the two always line up.
                    pre_out = np.round(matching_out).astype(int)
                    lab_out = matching_label.astype(int)
                    tp = np.sum(pre_out & lab_out)
                    tn = m_count - tp
                    fp = np.sum(pre_out) - tp
                    fn = np.sum(lab_out) - tp
                    save_file.write(file_name + '    tp: ' + str(tp) +
                                    '    tn: ' + str(tn) + '    fp: ' +
                                    str(fp) + '    fn: ' + str(fn))

                    # ROC curve: false and true positive rates.
                    fpr, tpr, _ = roc_curve(matching_label, matching_out)

                    # FPR at the first point whose TPR reaches 0.95.
                    fpr95 = fpr[np.where(0.95 <= tpr)[0][0]]

                    # The leading spaces separate this field from the fn
                    # count written just before it.
                    save_file.write('    match fpr95 : ' + str(fpr95 * 100))
                    save_file.write('\n')
                    save_file.write('\n')

                roc_dir = 'plot_curve/epoch_' + str(number)
                try:
                    os.makedirs(roc_dir)
                except OSError:
                    # An already existing directory is fine; any other
                    # failure (e.g. permissions) must surface.
                    if not os.path.isdir(roc_dir):
                        raise
                np.savez(roc_dir + '/match_map_gd_' + file_name, fpr, tpr)
            except KeyboardInterrupt:
                print('INTERRUPTED')
                coord.request_stop()
            finally:
                # When done, ask the threads to stop.
                coord.request_stop()
                coord.join(threads)