def gfm_test_exhaustive(Data_file_name):
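    """Score every SAR test patch against every optical test patch.

    Loads test patches from ``<Data_file_name>.mat`` and, for each optical
    patch, runs the matching network against all SAR patches; returns an
    array with one row of match scores per optical patch.
    """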
    
    Data_file = h5py.File(Data_file_name + '.mat', 'r')
    patch_1_test = np.array(Data_file['test_sar_patch'])  # sar
    patch_2_test = np.array(Data_file['test_opt_patch'])  # opt
    # Trim both sets to a whole number of batches.
    num1 = (patch_1_test.shape[0] // batch_size) * batch_size
    num2 = (patch_2_test.shape[0] // batch_size) * batch_size

    graph = tf.Graph()
    with graph.as_default():
        inputs_sar = tf.placeholder(tf.float32, [batch_size, image_height, image_width, 1], name='inputs_sar')
        inputs_opt = tf.placeholder(tf.float32, [batch_size, image_height, image_width, 1], name='inputs_opt')
        inputs_lab = tf.placeholder(tf.float32, [batch_size, 1], name='inputs_lab')
        
        match_loss, m_output = model.gfm_modelb(inputs_sar, inputs_opt, inputs_lab)
        out = tf.round(m_output)
        correct, ram = model.evaluation(out, inputs_lab)
        
        saver = tf.train.Saver()
        with tf.Session() as sess:  
            saver.restore(sess, tf.train.latest_checkpoint('ckpt_fm3br_6'))
            all_mout = np.array([])

            for i in range(num2):  # loop over optical patches
                opt1 = patch_2_test[i, :, :, :]
                opt = np.tile(opt1, (num1, 1, 1, 1))  # repeat this optical patch num1 times
                sar = patch_1_test[:num1, :, :, :]    # all SAR patches, trimmed to full batches
                print(i)
                y_test = np.zeros((num1, 1))  # dummy labels; only the scores are kept

                shuffle_test = gfm_shuffle(1, batch_size, sar, opt, y_test)
                for step1, (x_batch, y_batch, l_batch) in enumerate(shuffle_test):
                    feed_dict = {inputs_sar: x_batch, inputs_opt: y_batch, inputs_lab: l_batch}
                    result, p_m = sess.run([correct, m_output], feed_dict=feed_dict)
                    if step1 == 0:
                        all_mout = p_m
                    else:
                        all_mout = np.concatenate((all_mout, p_m), axis=0)
                    if step1 % 100 == 0:
                        print('Step %d run_test: batch_precision = %.2f'
                              % (step1, result / batch_size))
                all_mout = all_mout.T  # shape (1, num1): this optical patch vs. every SAR patch
                if i == 0:
                    results = all_mout
                else:
                    results = np.concatenate((results, all_mout), axis=0)
            return results
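
# ``gfm_shuffle`` is defined elsewhere in this module. As a reference, here is
# a minimal sketch of the batching generator the functions in this file appear
# to assume: it yields (sar_batch, opt_batch, label_batch) tuples of
# ``batch_size`` items for ``num_epoch`` passes over the data. The name
# ``_gfm_shuffle_sketch``, the ``shuffle`` flag, and the drop-last-partial-batch
# behaviour are assumptions, not the repository's actual implementation.
def _gfm_shuffle_sketch(num_epoch, batch_size, patch_1, patch_2, labels,
                        shuffle=True):
    num = patch_1.shape[0]
    for _ in range(num_epoch):
        # Optionally re-order the samples each epoch.
        index = np.random.permutation(num) if shuffle else np.arange(num)
        # Yield only full batches; a trailing partial batch is dropped.
        for start in range(0, num - batch_size + 1, batch_size):
            batch_idx = index[start:start + batch_size]
            yield patch_1[batch_idx], patch_2[batch_idx], labels[batch_idx]

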
def gfm_test():
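    """Evaluate the trained sia_map matching network on held-out patch pairs."""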

    data1 = np.load('6_up_sift_harris_transform_train_test_data.npz')
    patch_test = data1['arr_1']
    patch_1_test = patch_test[:67000, :, :32, :]  # sar
    patch_2_test = patch_test[:67000, :, 32:, :]  # opt
    y_test = data1['arr_3'][:67000, :]

    graph = tf.Graph()
    with graph.as_default():
        inputs_sar = tf.placeholder(tf.float32,
                                    [batch_size, image_height, image_width, 1],
                                    name='inputs_sar')
        inputs_opt = tf.placeholder(tf.float32,
                                    [batch_size, image_height, image_width, 1],
                                    name='inputs_opt')
        inputs_lab = tf.placeholder(tf.float32, [batch_size, 1],
                                    name='inputs_lab')

        match_loss, m_output = model.gfm_sia_map(inputs_sar, inputs_opt,
                                                 inputs_lab)
        out = tf.round(m_output)
        correct, ram = model.evaluation(out, inputs_lab)

        saver = tf.train.Saver()
        with tf.Session() as sess:
            saver.restore(sess, tf.train.latest_checkpoint('ckpt_map+g_6'))
            true_count = 0  # Counts the number of correct predictions.
            all_mout = np.array([])
            all_lab = np.array([])
            num = np.size(y_test)
            shuffle_test = gfm_shuffle(1, batch_size, patch_1_test,
                                       patch_2_test, y_test)
            for step1, (x_batch, y_batch, l_batch) in enumerate(shuffle_test):
                feed_dict = {
                    inputs_sar: x_batch,
                    inputs_opt: y_batch,
                    inputs_lab: l_batch
                }
                result, p_out, p_ram, p_m = sess.run(
                    [correct, out, ram, m_output], feed_dict=feed_dict)
                if step1 == 0:
                    all_mout = p_m
                    all_lab = l_batch
                else:
                    all_mout = np.concatenate((all_mout, p_m), axis=0)
                    all_lab = np.concatenate((all_lab, l_batch), axis=0)

                true_count = true_count + result
                if step1 % 10 == 0:
                    print('Step %d run_test: batch_precision = %.2f ' %
                          (step1, result / batch_size))
            # gfm_shuffle appears to yield only full batches, so a trailing
            # partial batch is skipped and this denominator slightly
            # overstates the number of patches actually evaluated.
            precision = float(true_count) / num
            print('  Num examples: %d  Num correct: %d  Precision : %0.04f' %
                  (num, true_count, precision))
def gfm_train_sia_map():
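    """Train the sia_map matching network with GAN-generated pseudo-pairs.

    Two pretrained generators (SAR->optical and optical->SAR) synthesise
    cross-modal fakes that are mixed with the real patches to form extra
    positive and negative pairs for the matcher.
    """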

    current_time = datetime.now().strftime('%Y%m%d-%H%M')
    checkpoints_dir = 'checkpoints/{}'.format(current_time)
    checkpoint_file = os.path.join(checkpoints_dir, 'model.ckpt')
    try:
        os.makedirs(checkpoints_dir)
    except OSError:
        pass

    data1 = np.load('6_up_sift_harris_transform_train_test_data.npz')
    patch_train = data1['arr_0']
    patch_1_train = patch_train[:200000, :, :32, :]  # sar
    patch_2_train = patch_train[:200000, :, 32:, :]  # opt
    y_train = data1['arr_2'][:200000, :]

    patch_test = data1['arr_1']
    patch_1_test = patch_test[:3000, :, :32, :]  # sar
    patch_2_test = patch_test[:3000, :, 32:, :]  # opt
    y_test = data1['arr_3'][:3000, :]

    graph = tf.Graph()
    with graph.as_default():
        inputs_sar = tf.placeholder(tf.float32,
                                    [batch_size, image_height, image_width, 1],
                                    name='inputs_sar')
        inputs_opt = tf.placeholder(tf.float32,
                                    [batch_size, image_height, image_width, 1],
                                    name='inputs_opt')
        inputs_lab = tf.placeholder(tf.float32, [batch_size, 1],
                                    name='inputs_lab')
        # Train the matching network M
        fake_opt = model.create_generator_1(inputs_sar, 1)
        gen_1 = [
            var for var in tf.trainable_variables()
            if var.name.startswith("generator_1")
        ]
        fake_sar = model.create_generator_2(inputs_opt, 1)
        gen_2 = [
            var for var in tf.trainable_variables()
            if var.name.startswith("generator_2")
        ]

        match_loss, m_output = model.gfm_sia_map(inputs_sar, inputs_opt,
                                                 inputs_lab)
        out = tf.round(m_output)
        correct, ram = model.evaluation(out, inputs_lab)
        m_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(
            match_loss)

        tf.summary.scalar('matching_loss', match_loss)
        summary = tf.summary.merge_all()
        saver_g_1 = tf.train.Saver(var_list=gen_1)
        saver_g_2 = tf.train.Saver(var_list=gen_2)
        saver = tf.train.Saver(max_to_keep=10)
        init = tf.global_variables_initializer()
        with tf.Session() as sess:
            summary_writer = tf.summary.FileWriter(train_dir, sess.graph)
            sess.run(init)
            saver_g_1.restore(sess, tf.train.latest_checkpoint('ckpt_g6_s2o'))
            saver_g_2.restore(sess, tf.train.latest_checkpoint('ckpt_g6_o2s'))
            try:
                shuffle1 = gfm_shuffle(epoch, batch_size, patch_1_train,
                                       patch_2_train, y_train)
                for step, (x_batch, y_batch, l_batch) in enumerate(shuffle1, 1):
                    start_time = time.time()

                    feed_dict = {
                        inputs_sar: x_batch,
                        inputs_opt: y_batch,
                        inputs_lab: l_batch
                    }
                    _, loss, m_output_ = sess.run(
                        [m_train_opt, match_loss, m_output],
                        feed_dict=feed_dict)

                    # Translate the current batch across modalities.
                    fake_opt_ = sess.run(fake_opt, feed_dict={inputs_sar: x_batch})
                    fake_sar_ = sess.run(fake_sar, feed_dict={inputs_opt: y_batch})
                    fake_opt_ = np.array(fake_opt_, np.float64)
                    fake_sar_ = np.array(fake_sar_, np.float64)
                    shuffle_index = np.random.permutation(batch_size)
                    fake_opt0 = fake_opt_[shuffle_index]  # permuted, so no longer index-aligned
                    fake_sar0 = fake_sar_[shuffle_index]
                    X1 = np.concatenate(
                        (x_batch, x_batch, fake_sar_, fake_sar0), axis=0)
                    Y1 = np.concatenate(
                        (fake_opt_, fake_opt0, y_batch, y_batch), axis=0)
                    L1 = [1] * batch_size + [0] * batch_size \
                        + [1] * batch_size + [0] * batch_size
                    L1 = np.array(L1, np.float64)[:, np.newaxis]
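                    # Second pass trains on the synthetic pairs: index-aligned
                    # (sar, fake_opt) and (fake_sar, opt) pairs are labelled 1
                    # (match); pairing against a permuted fake breaks the
                    # correspondence and is labelled 0 (non-match).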

                    shuffle0 = gfm_shuffle(1, batch_size, X1, Y1, L1)
                    for step1, (x_batch0, y_batch0,
                                l_batch0) in enumerate(shuffle0):
                        feed_dict0 = {
                            inputs_sar: x_batch0,
                            inputs_opt: y_batch0,
                            inputs_lab: l_batch0
                        }
                        _, G_loss, m_output_ = sess.run(
                            [m_train_opt, match_loss, m_output],
                            feed_dict=feed_dict0)

                    duration = time.time() - start_time
                    summary_str = sess.run(summary, feed_dict=feed_dict)
                    summary_writer.add_summary(summary_str, step)
                    summary_writer.flush()

                    if step % 100 == 0:
                        logging.info(
                            '>> Step %d run_train: loss = %.2f G_loss = %.2f (%.3f sec)'
                            % (step, loss, G_loss, duration))

                    if step % 3000 == 0:
                        logging.info('>> %s Saving in %s' %
                                     (datetime.now(), checkpoints_dir))
                        saver.save(sess, checkpoint_file, global_step=step)


                    if step % 500 == 0:
                        # test
                        true_count = 0  # Counts the number of correct predictions.
                        num = np.size(y_test)
                        shuffle_test = gfm_shuffle(1, batch_size, patch_1_test,
                                                   patch_2_test, y_test)
                        for step_test, (x_batch, y_batch,
                                        l_batch) in enumerate(shuffle_test):
                            feed_dict = {
                                inputs_sar: x_batch,
                                inputs_opt: y_batch,
                                inputs_lab: l_batch
                            }
                            result, p_out, p_r = sess.run([correct, out, ram],
                                                          feed_dict=feed_dict)

                            true_count = true_count + result
                        precision = float(true_count) / num
                        logging.info(
                            'Num examples: %d  Num correct: %d  Precision : %0.04f'
                            % (num, true_count, precision))

            except KeyboardInterrupt:
                print('INTERRUPTED')

            finally:
                saver.save(sess, checkpoint_file, global_step=step)
                print('Model saved in file: %s' % checkpoints_dir)
def gfm_train_sia_tensor():
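    """Train the sia_tensor matching network on real SAR/optical patch pairs."""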
    
    current_time = datetime.now().strftime('%Y%m%d-%H%M')
    checkpoints_dir = 'checkpoints/{}'.format(current_time)
    checkpoint_file = os.path.join(checkpoints_dir, 'model.ckpt')
    try:
        os.makedirs(checkpoints_dir)
    except OSError:
        pass
        
    data1 = np.load('6_up_sift_harris_transform_train_test_data.npz')
    patch_train = data1['arr_0']
    patch_1_train = patch_train[:200000, :, :32, :]  # sar
    patch_2_train = patch_train[:200000, :, 32:, :]  # opt
    y_train = data1['arr_2'][:200000, :]
    patch_test = data1['arr_1']
    patch_1_test = patch_test[:3000, :, :32, :]  # sar
    patch_2_test = patch_test[:3000, :, 32:, :]  # opt
    y_test = data1['arr_3'][:3000, :]

    graph = tf.Graph()
    with graph.as_default():
        inputs_sar = tf.placeholder(tf.float32, [batch_size, image_height, image_width, 1], name='inputs_sar')
        inputs_opt = tf.placeholder(tf.float32, [batch_size, image_height, image_width, 1], name='inputs_opt')
        inputs_lab = tf.placeholder(tf.float32, [batch_size, 1], name='inputs_lab')
        # Train the matching network M
        match_loss, m_output = model.gfm_sia_tensor(inputs_sar, inputs_opt, inputs_lab)
        out = tf.round(m_output)
        correct, ram = model.evaluation(out, inputs_lab)
        m_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(match_loss)
        
        tf.summary.scalar('matching_loss', match_loss)
        summary = tf.summary.merge_all()
        saver = tf.train.Saver(max_to_keep=10)
        init = tf.global_variables_initializer()
        with tf.Session() as sess:
            summary_writer = tf.summary.FileWriter(train_dir, sess.graph)
            sess.run(init)
            try:
                shuffle1 = gfm_shuffle(epoch, batch_size, patch_1_train,
                                       patch_2_train, y_train)
                for step, (x_batch, y_batch, l_batch) in enumerate(shuffle1, 1):
                    start_time = time.time()

                    feed_dict = {inputs_sar: x_batch, inputs_opt: y_batch, inputs_lab: l_batch}
                    _, m_loss, m_output_, m_out_ = sess.run(
                        [m_train_opt, match_loss, m_output, out],
                        feed_dict=feed_dict)
                    duration = time.time() - start_time
                    summary_str = sess.run(summary, feed_dict=feed_dict)
                    summary_writer.add_summary(summary_str, step)
                    summary_writer.flush()

                    if step % 100 == 0:
                        logging.info('>> Step %d run_train: matching_loss = %.2f (%.3f sec)'
                                     % (step, m_loss, duration))

                    if step % 3000 == 0:
                        logging.info('>> %s Saving in %s' % (datetime.now(), checkpoints_dir))
                        saver.save(sess, checkpoint_file, global_step=step)

                    if step % 500 == 0:
                        # Evaluate on the held-out test pairs.
                        true_count = 0  # counts correct predictions
                        num = np.size(y_test)
                        shuffle_test = gfm_shuffle(1, batch_size, patch_1_test,
                                                   patch_2_test, y_test)
                        for step_test, (x_batch, y_batch, l_batch) in enumerate(shuffle_test):
                            feed_dict = {inputs_sar: x_batch, inputs_opt: y_batch, inputs_lab: l_batch}
                            result, p_out, p_r = sess.run([correct, out, ram], feed_dict=feed_dict)
                            true_count = true_count + result
                        precision = float(true_count) / num
                        logging.info('Num examples: %d  Num correct: %d  Precision : %0.04f' %
                                     (num, true_count, precision))
                        
            except KeyboardInterrupt:
                print('INTERRUPTED')

            finally:
                saver.save(sess, checkpoint_file, global_step=step)
                print('Model saved in file: %s' % checkpoints_dir)
def gfm_train():
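    """Train the sia_tensor matcher on generator-translated SAR patches.

    Each SAR patch is first passed through a pretrained SAR->optical
    generator; the matcher then compares the translated patch with the real
    optical patch. Every 500 steps a fixed mapping set is rendered to disk
    for visual inspection.
    """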

    current_time = datetime.now().strftime('%Y%m%d-%H%M')
    checkpoints_dir = 'checkpoints/{}'.format(current_time)
    checkpoint_file = os.path.join(checkpoints_dir, 'model.ckpt')
    try:
        os.makedirs(checkpoints_dir)
    except OSError:
        pass

    data1 = np.load('6_up_sift_harris_transform_train_test_data.npz')
    patch_train = data1['arr_0']
    patch_1_train = patch_train[:200000, :, :32, :]  # sar
    patch_2_train = patch_train[:200000, :, 32:, :]  # opt
    y_train = data1['arr_2'][:200000, :]
    patch_test = data1['arr_1']
    patch_1_test = patch_test[:3000, :, :32, :]  # sar
    patch_2_test = patch_test[:3000, :, 32:, :]  # opt
    y_test = data1['arr_3'][:3000, :]

    data2 = np.load('6_up_sift_harris_mapping_data.npy')
    X_test = data2[30000:30100, :, :32, :]
    Y_test = data2[30000:30100, :, 32:, :]

    graph = tf.Graph()
    with graph.as_default():
        inputs_sar = tf.placeholder(tf.float32,
                                    [batch_size, image_height, image_width, 1],
                                    name='inputs_sar')
        inputs_opt = tf.placeholder(tf.float32,
                                    [batch_size, image_height, image_width, 1],
                                    name='inputs_opt')
        inputs_lab = tf.placeholder(tf.float32, [batch_size, 1],
                                    name='inputs_lab')

        g_outputs = model.create_generator(inputs_sar, 1)
        gen_tvars = [
            var for var in tf.trainable_variables()
            if var.name.startswith("generator")
        ]
        # Train the matching network M on the generator's output
        match_loss, m_output = model.gfm_sia_tensor(g_outputs, inputs_opt,
                                                    inputs_lab)
        out = tf.round(m_output)
        correct, ram = model.evaluation(out, inputs_lab)
        m_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(
            match_loss)

        # Visualisation branch, built once here rather than inside the
        # training loop: re-creating the placeholder and generator every
        # 500 steps would keep adding nodes to the graph.
        inputs_sar_test = tf.placeholder(tf.float32,
                                         [100, image_height, image_width, 1],
                                         name='inputs_sar_test')
        g_out_test = model.create_generator(inputs_sar_test, 1, reuse=True)

        tf.summary.scalar('matching_loss', match_loss)
        summary = tf.summary.merge_all()

        saver_g = tf.train.Saver(var_list=gen_tvars)
        saver = tf.train.Saver()
        init = tf.global_variables_initializer()
        with tf.Session() as sess:
            summary_writer = tf.summary.FileWriter(train_dir, sess.graph)
            sess.run(init)
            saver_g.restore(sess, tf.train.latest_checkpoint(checkpoint_dir_g))

            try:
                shuffle1 = gfm_shuffle(epoch, batch_size, patch_1_train,
                                       patch_2_train, y_train)
                for step, (x_batch, y_batch, l_batch) in enumerate(shuffle1, 1):
                    start_time = time.time()

                    feed_dict = {
                        inputs_sar: x_batch,
                        inputs_opt: y_batch,
                        inputs_lab: l_batch
                    }
                    _, m_loss, m_output_, m_out_ = sess.run(
                        [m_train_opt, match_loss, m_output, out],
                        feed_dict=feed_dict)
                    duration = time.time() - start_time

                    summary_str = sess.run(summary, feed_dict=feed_dict)
                    summary_writer.add_summary(summary_str, step)
                    summary_writer.flush()
                    if step % 100 == 0:
                        logging.info(
                            '>> Step %d run_train: matching_loss = %.2f (%.3f sec)'
                            % (step, m_loss, duration))

                    if step % 1000 == 0:
                        logging.info('>> %s Saving in %s' %
                                     (datetime.now(), checkpoints_dir))
                        saver.save(sess, checkpoint_file, global_step=step)

                    if step % 500 == 0:
                        # test
                        true_count = 0  # Counts the number of correct predictions.
                        num = np.size(y_test)
                        shuffle_test = gfm_shuffle(1, batch_size, patch_1_test,
                                                   patch_2_test, y_test)
                        for step_test, (x_batch, y_batch,
                                        l_batch) in enumerate(shuffle_test):
                            feed_dict = {
                                inputs_sar: x_batch,
                                inputs_opt: y_batch,
                                inputs_lab: l_batch
                            }
                            result, p_out, p_r = sess.run([correct, out, ram],
                                                          feed_dict=feed_dict)

                            true_count = true_count + result
                        precision = float(true_count) / num
                        logging.info(
                            '  Num examples: %d  Num correct: %d  Precision : %0.04f'
                            % (num, true_count, precision))

                        feed_dict = {inputs_sar_test: X_test}
                        g_out_test_result = sess.run(g_out_test,
                                                     feed_dict=feed_dict)

                        show_images = np.concatenate(
                            (X_test, g_out_test_result, Y_test), axis=1)

                        image_grid = combine_images(show_images)
                        image_grid = image_grid * 255
                        misc.imsave(
                            'out6/_sia_tensorg_{}.png'.format(
                                str(epoch) + "_" + str(step)), image_grid)

            except KeyboardInterrupt:
                print('INTERRUPTED')

            finally:
                saver.save(sess, checkpoint_file, global_step=step)
                print('Model saved in file: %s' % checkpoints_dir)
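

# ``combine_images`` is defined elsewhere in this repository. For reference,
# a minimal sketch of the grid-tiling behaviour the visualisation step above
# appears to assume; the name ``_combine_images_sketch`` and the square-grid
# layout are assumptions, not the repository's actual implementation.
def _combine_images_sketch(images):
    """Tile a batch of (n, h, w, 1) images into a single 2-D grid."""
    num = images.shape[0]
    cols = int(np.ceil(np.sqrt(num)))
    rows = int(np.ceil(num / float(cols)))
    h, w = images.shape[1], images.shape[2]
    grid = np.zeros((rows * h, cols * w), dtype=images.dtype)
    for idx in range(num):
        r, c = divmod(idx, cols)
        grid[r * h:(r + 1) * h, c * w:(c + 1) * w] = images[idx, :, :, 0]
    return grid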