Exemplo n.º 1
0
def generateVerticationSet(train_path, verti_path, num):
    '''
    Randomly move 20% of the training patches into a verification set.

    Training patches are files named ``<train_path><k>.mat``. The selected
    patches are copied to ``<verti_path><1..verti_num>.mat`` and then deleted
    from the training directory. Afterwards the remaining training files are
    renumbered with two ``reRankfile`` passes.

    :param train_path: prefix of the training patch files (file names are
        appended directly, so it presumably ends with a path separator)
    :param verti_path: prefix for the verification patch files
    :param num: the number of the patches currently under ``train_path``
    :return: ``(num - verti_num, verti_num)``, the nominal training and
        verification counts. They are not adjusted if an individual move fails.
    '''
    ratio = 0.2
    verti_num = roundNum(num * ratio)
    num_list = []
    checkFile(verti_path)
    # presumably fills num_list in place with verti_num random patch indices - confirm
    generateRandomList(num_list, num, verti_num)
    print(num_list)
    for ind, val in enumerate(num_list):
        src = train_path + '%d.mat' % val
        dst = verti_path + '%d.mat' % (ind + 1)
        try:
            shutil.copy(src, dst)
            # Only delete the training copy once the copy has succeeded.
            os.remove(src)
        except OSError as err:
            # Report which file failed and why instead of hiding the cause
            # (the old bare ``except`` also swallowed KeyboardInterrupt).
            print('raise error')
            print('  %s -> %s: %s' % (src, dst, err))
        else:
            print('%d has created' % (ind + 1))
    print('veticatication set created')
    print('do rerank train set')
    # rename the left train pieces so their indices are contiguous again
    reRankfile(train_path, 'a')
    reRankfile(train_path, '')
    return num - verti_num, verti_num
Exemplo n.º 2
0
def readCAVEData(original_data_path, mat_save_path):
    '''
    Read the raw CAVE scenes and save each one as a ``.mat`` file.

    Every sub-directory ``<scene>`` of ``original_data_path`` is treated as one
    scene. Its 31 band images are read from
    ``<original_data_path>/<scene>/<scene>/<scene>_01.png`` ... ``_31.png``.
    Band images that load with a channel axis (RGB/RGBA) are reduced to one
    plane by averaging their first three channels. Non-directory entries are
    skipped.

    Scenes are processed in sorted name order. ``os.listdir`` order is
    arbitrary, so without sorting the scene-to-file-number mapping (and any
    train/test split based on it) would differ between machines.

    The original docstring notes the data is already standardized, so no
    normalization is applied here.

    :param original_data_path: root directory of the CAVE dataset
    :param mat_save_path: output prefix; scene k is written to
        ``<mat_save_path><k>.mat`` under key ``'X'`` as a 512x512x31 float32 array
    :return: None
    '''
    # One reusable buffer; all 31 bands are overwritten for every scene
    # before it is saved.
    hsi = np.zeros([512, 512, 31], dtype=np.float32)
    checkFile(mat_save_path)
    count = 0
    for scene in sorted(os.listdir(original_data_path)):
        scene_dir = original_data_path + '/' + scene
        if not os.path.isdir(scene_dir):
            continue
        concrete_path = scene_dir + '/' + scene
        for i in range(31):
            png_path = concrete_path + '/' + scene + '_%02d.png' % (i + 1)
            img = plt.imread(png_path)
            if img.ndim == 3:
                # Some band files are stored with colour channels; collapse them.
                img = np.mean(img[:, :, :3], axis=2)
            hsi[:, :, i] = img
        count += 1
        print('%d has finished' % count)
        sio.savemat(mat_save_path + str(count) + '.mat', {'X': hsi})
Exemplo n.º 3
0
def test():
    '''
    End-to-end evaluation: CNN prediction, post-optimization, quality report.

    1. Build a ``tonwmd`` network from ``FLAGS`` and run its ``test()``. This
       presumably writes the ``XCNN`` files that are read back in step 3.
    2. Select the blur kernel ``B`` and spectral response ``R``:
       - if ``FLAGS.isEstimate`` is set, load them from ``B_R/`` and apply
         ``np.fft.fft2`` to ``B``;
       - otherwise use ``get_kernal`` and ``getSpectralResponse``.
    3. For each scene index ``FLAGS.test_start..FLAGS.test_end``:
       - combine the CNN output with the scene's ``Y`` and ``Z`` via
         ``twice_optimization_with_estBR``;
       - store the result as ``'F'`` and save the scene under
         ``FLAGS.output_save_path``.
    4. Call ``quality_evaluate()``.

    Note: ``FLAGS.lamb`` and ``FLAGS.mu`` are overwritten with ``2e-3`` here,
    whatever their previous values were.
    '''
    # --- stage 1: CNN prediction -------------------------------------------
    print('predict the results with well_trained CNN-----')
    cnn = tonwmd(FLAGS.channel, FLAGS.ms_channel, FLAGS.train_data_path,
                 FLAGS.vertication_data_path, FLAGS.model_save_path,
                 FLAGS.mat_save_path, FLAGS.cnn_output_save_path,
                 FLAGS.training_num, FLAGS.vertication_num,
                 train_batch_size=FLAGS.train_batch_size,
                 valid_batch_size=FLAGS.valid_batch_size,
                 piece_size=FLAGS.piece_size,
                 ratio=FLAGS.ratio,
                 maxpower=FLAGS.max_power,
                 test_start=FLAGS.test_start,
                 test_end=FLAGS.test_end,
                 test_height=FLAGS.height,
                 test_width=FLAGS.width)
    cnn.test()

    # Fixed regularisation weights for the post-optimization step.
    FLAGS.lamb, FLAGS.mu = 2e-3, 2e-3

    # --- stage 2: degradation model (B, R) ---------------------------------
    if not FLAGS.isEstimate:
        B = get_kernal(FLAGS.kernal_size, FLAGS.sigma, FLAGS.height,
                       FLAGS.width)
        R = getSpectralResponse()  # spectral response of the Nikon camera
    else:
        # Pre-estimated B and R (per the original note, obtained with HySure).
        R = sio.loadmat('B_R/R.mat')['R']
        B = np.fft.fft2(sio.loadmat('B_R/B.mat')['B'])

    # --- stage 3: post-optimization of every test scene --------------------
    checkFile(FLAGS.output_save_path)
    print('Post-optimization to further improve the performance-------')
    for idx in range(FLAGS.test_start, FLAGS.test_end + 1):
        scene = sio.loadmat(FLAGS.mat_save_path + '%d.mat' % idx)
        cnn_result = sio.loadmat(FLAGS.cnn_output_save_path + '%d.mat' % idx)
        Y, Z = scene['Y'], scene['Z']
        x_cnn = cnn_result['XCNN']
        scene['F'] = twice_optimization_with_estBR(x_cnn, Y, Z, B, R,
                                                   k=FLAGS.channel,
                                                   ratio=FLAGS.ratio,
                                                   lamb=FLAGS.lamb,
                                                   mu=FLAGS.mu)
        sio.savemat(FLAGS.output_save_path + str(idx) + '.mat', scene)
        print('F %d has finished' % idx)

    # --- stage 4: metrics ----------------------------------------------------
    print('quality_evaluate-----')
    quality_evaluate()
Exemplo n.º 4
0
    def __init__(self,
                 channel,
                 mschannel,
                 train_pieces_path,
                 valid_pieces_path,
                 model_save_path,
                 test_data_label_path,
                 output_save_path,
                 total_num,
                 valid_num,
                 train_batch_size=64,
                 valid_batch_size=16,
                 piece_size=32,
                 ratio=8,
                 maxpower=20,
                 test_start=21,
                 test_end=32,
                 test_height=512,
                 test_width=512):
        '''
        Store the configuration and declare the graph's input placeholders.

        All arguments are forwarded unchanged, in order, to ``setDataAbout``.
        That call presumably sets ``self.channels``, ``self.model_save_path``
        and the other attributes used below - confirm.

        Placeholders created:
        - ``data`` (name ``'Xes'``) and ``label`` (name ``'Xrel'``): float32,
          rank 4, only the last dimension fixed to ``self.channels``
          (presumably batch/height/width/channel - confirm);
        - ``istraining`` (bool), ``learning_rate`` (float32) and
          ``running_size`` (int32): scalars.

        Finally ``checkFile`` is called on ``self.model_save_path``.
        '''
        self.setDataAbout(channel, mschannel, train_pieces_path,
                          valid_pieces_path, model_save_path,
                          test_data_label_path, output_save_path, total_num,
                          valid_num, train_batch_size, valid_batch_size,
                          piece_size, ratio, maxpower, test_start, test_end,
                          test_height, test_width)

        placeholder = tf.placeholder
        cube_shape = [None, None, None, self.channels]

        # Image-like inputs: network input and its reference label.
        self.data = placeholder(dtype=tf.float32, shape=cube_shape, name='Xes')
        self.label = placeholder(dtype=tf.float32, shape=cube_shape, name='Xrel')

        # Scalar control inputs fed at run time.
        self.istraining = placeholder(dtype=tf.bool, shape=[],
                                      name='istraining')
        self.learning_rate = placeholder(dtype=tf.float32, shape=[],
                                         name='learning_rate')
        self.running_size = placeholder(dtype=tf.int32, shape=[],
                                        name='running_size')

        # Bookkeeping containers filled by other methods.
        self.weightlist, self.resultlist, self.helplist = [], [], []
        self.losslist1, self.losslist2 = [], []

        checkFile(self.model_save_path)
Exemplo n.º 5
0
def cutCAVEPieces(mat_save_path,
                  piece_save_path,
                  piece_size=32,
                  stride=16,
                  num_end=20):
    '''
    Cut CAVE scenes 1..num_end (by default the first 20) into training pieces.

    Scene ``k`` is read from ``<mat_save_path><k>.mat``:
    - ``'XES'`` supplies the network input;
    - ``'X'`` supplies the label.

    Square windows of side ``piece_size`` are taken every ``stride`` pixels in
    both directions. Only windows lying entirely inside the image are produced,
    so every saved piece has the full ``piece_size`` extent. Each window pair
    is saved to ``<piece_save_path><n>.mat`` with keys ``'data'`` and
    ``'label'``, where ``n`` counts from 1 across all scenes.

    :param mat_save_path: prefix of the scene ``.mat`` files
    :param piece_save_path: output prefix for the pieces
    :param piece_size: side length of each square piece
    :param stride: step between neighbouring window origins
    :param num_end: index of the last scene to cut
    :return: total number of pieces written
    '''
    num_start = 1
    count = 0
    checkFile(piece_save_path)
    for i in range(num_start, num_end + 1):
        mat = sio.loadmat(mat_save_path + '%d.mat' % i)
        X = mat['X']
        Xin = mat['XES']
        rows, cols = X.shape[0], X.shape[1]
        # Stop at size - piece_size + 1 so the last origin still leaves room for
        # a full window. A stop of size - piece_size + stride produced truncated
        # edge pieces whenever stride did not divide (size - piece_size).
        for x in range(0, rows - piece_size + 1, stride):
            for y in range(0, cols - piece_size + 1, stride):
                data_piece = Xin[x:x + piece_size, y:y + piece_size, :]
                label_piece = X[x:x + piece_size, y:y + piece_size, :]
                count += 1
                sio.savemat(piece_save_path + '%d.mat' % count, {
                    'data': data_piece,
                    'label': label_piece
                })
                print('piece num %d has saved' % count)
        print('%d has finished' % i)
    print('done')
    return count
Exemplo n.º 6
0
    def test(self):
        '''
        Run the restored network over test scenes ``test_start..test_end``.

        Each scene ``<test_data_label_path><i>.mat`` supplies the input
        ``'XES'`` and the reference ``'X'``. Per scene:
        - the top-left ``test_height`` x ``test_width`` region of ``'XES'`` is
          edge-padded up to a multiple of the 32-pixel piece size and tiled;
        - the tiles go through ``self.output`` in batches of
          ``train_batch_size``;
        - the results are stitched back into that region and clipped to
          [0, 1];
        - the array is saved to ``<output_save_path><i>.mat`` under ``'XCNN'``.

        Finally the mean PSNR against ``'X'`` and the average time per scene
        are printed.

        :raises FileNotFoundError: if no checkpoint is found in
            ``model_save_path``
        '''
        start = time.perf_counter()
        self.initGpu()
        self.buildGraph()
        self.saver = tf.train.Saver()
        psnr = 0

        b = self.train_batch_size
        h = self.test_height
        w = self.test_width

        # The full image is too large to feed at once, so it is cut into pieces
        # of the same size as the training patches and regrouped afterwards.
        piece_size = 32
        # Pad up to the next multiple of piece_size so every tile is full-size.
        # Without this, sizes that are not multiples of 32 crashed on tiling.
        padded_h = h + (-h % piece_size)
        padded_w = w + (-w % piece_size)
        tiles = [(x, y)
                 for x in range(0, padded_h, piece_size)
                 for y in range(0, padded_w, piece_size)]
        piece_count = len(tiles)
        input_pieces = np.zeros(
            [piece_count, piece_size, piece_size, self.channels],
            dtype=np.float32)
        checkFile(self.output_save_path)
        test_start = self.test_start
        test_end = self.test_end

        with self.session as sess:
            latest_model = tf.train.get_checkpoint_state(self.model_save_path)
            if latest_model is None or not latest_model.model_checkpoint_path:
                raise FileNotFoundError('no checkpoint found in %s' %
                                        self.model_save_path)
            self.saver.restore(sess, latest_model.model_checkpoint_path)
            for i in range(test_start, test_end + 1):
                mat = sio.loadmat(self.test_data_label_path + '%d.mat' % i)
                data = mat['XES']
                X = mat['X']

                padded = np.pad(data[:h, :w, :],
                                ((0, padded_h - h), (0, padded_w - w), (0, 0)),
                                mode='edge')
                for n, (x, y) in enumerate(tiles):
                    input_pieces[n] = padded[x:x + piece_size,
                                             y:y + piece_size, :]

                # Batched inference; the last batch may be smaller than b.
                self.helplist.clear()
                for first in range(0, piece_count, b):
                    output = sess.run(
                        self.output,
                        feed_dict={
                            self.data: input_pieces[first:first + b],
                            self.istraining: False,
                            self.running_size: piece_size
                        })
                    self.helplist.append(output)
                restored = np.concatenate(self.helplist, axis=0)

                # Stitch the tiles back and drop the padding.
                for n, (x, y) in enumerate(tiles):
                    padded[x:x + piece_size, y:y + piece_size, :] = restored[n]
                data[:h, :w, :] = padded[:h, :w, :]

                output = data
                output[output < 0] = 0
                output[output > 1] = 1.0
                sio.savemat(self.output_save_path + '%d.mat' % i,
                            {'XCNN': output})
                psnr += compare_psnr(X, output)
                print('%d has finished' % i)
            print(psnr / (test_end - test_start + 1))
            print((test_end - test_start + 1))
            end = time.perf_counter()
            # Average elapsed time per scene, in seconds.
            print('用时%ss' % ((end - start) / (test_end - test_start + 1)))
Exemplo n.º 7
0
 def buildSaver(self):
     """Create a checkpoint saver and run ``checkFile`` on the model directory.

     The saver keeps at most 100 checkpoints. ``checkFile`` presumably creates
     ``self.model_save_path`` when it is missing - confirm.
     """
     checkpoints_to_keep = 100
     self.saver = tf.train.Saver(max_to_keep=checkpoints_to_keep)
     checkFile(self.model_save_path)