Пример #1
0
输出预测图像并存起来
'''
from model import RDN
import skimage.io
import glob
import os
import numpy as np
from train import modelsavedir
def checkpoint_val_loss(path):
    """Return the validation loss encoded in a checkpoint file name.

    Expects basenames of the form ``'{epoch:04d}-RDN-{val_loss:.2f}-weights.hdf5'``
    (e.g. ``0012-RDN-23.45-weights.hdf5``); the loss is the third
    ``'-'``-separated field.  The whole field is parsed, so losses of any
    width (``5.67``, ``23.45``, ``123.45``) compare correctly.

    Args:
        path: Path to a checkpoint file; only its basename is inspected.

    Returns:
        The validation loss as a float, or ``inf`` when the name does not
        follow the expected pattern, so such files are never chosen as "best".
    """
    fields = os.path.basename(path).split('-')
    try:
        return float(fields[2])
    except (IndexError, ValueError):
        return float('inf')


def to_uint8_image(pred):
    """Convert a float network output to a uint8 image array.

    Values are clipped to [0, 255] and rounded to the nearest integer; a bare
    ``astype(np.uint8)`` would truncate and bias every pixel down by ~0.5.

    Args:
        pred: Array-like of predicted pixel values.

    Returns:
        ``np.ndarray`` of dtype ``uint8`` with the same shape as ``pred``.
    """
    return np.rint(np.clip(pred, 0, 255)).astype(np.uint8)


if __name__ == '__main__':
    # Low-resolution test images to upscale (alternative: '../Data/Set5_test_LR').
    xdir = '../Data/Urban_test_LR'
    # Sorted so images are processed in a deterministic order.
    xlist = sorted(glob.glob(os.path.join(xdir, '*.png')))
    myRDN = RDN()
    # Directory for the predicted images (alternative: './result').
    save_dir = './result2'
    # imsave does not create missing directories.
    os.makedirs(save_dir, exist_ok=True)

    modellist = glob.glob(modelsavedir + '*RDN*.hdf5')
    if modellist:
        # Load the checkpoint with the lowest validation loss.
        best = min(modellist, key=checkpoint_val_loss)
        model = myRDN.load_weight(best)
        print('载入', best)
    else:
        # No checkpoint found: call load_weight() without a path.
        model = myRDN.load_weight()

    # Predict each image and save it under the same file name.
    for imgname in xlist:
        print(imgname)
        img = skimage.io.imread(imgname)
        # Wrap the single image into a batch of one and take the first output;
        # the second positional argument is presumably Keras' batch_size.
        Y = model.predict(np.array([img]), 1)[0]
        skimage.io.imsave(os.path.join(save_dir, os.path.basename(imgname)),
                          to_uint8_image(Y))
Пример #2
0
    modelcp = ModelCheckpoint(modelsavedir +
                              '{epoch:04d}-RDN-{val_loss:.2f}-weights.hdf5',
                              verbose=1,
                              period=1,
                              save_weights_only=True,
                              save_best_only=True)
    gen_tx, gen_ty, gen_vx, gen_vy = get_train_data(Batch_size)
    train_gen = gen(gen_tx, gen_ty)
    valid_gen = gen(gen_vx, gen_vy)
    myrdn.setting_train()

    # 载入之前的模型
    modellist = glob.glob(modelsavedir + '*RDN*.hdf5')
    modellist.sort(
        key=lambda x: float(x[len(modelsavedir) + 0:len(modelsavedir) + 4]))
    myrdn.load_weight(modellist[-1])
    print('载入', modellist[-1])
    init_epoch = int(os.path.basename(modellist[-1])[:4])
    target_epoch = 170
    step = 400
    print('目标', target_epoch, '还有', target_epoch - init_epoch, '\n时间(分)',
          3 / (500 * 24) * step * Batch_size * (target_epoch - init_epoch))

    try:
        myrdn.model.fit_generator(train_gen,
                                  step,
                                  epochs=target_epoch,
                                  verbose=1,
                                  validation_data=valid_gen,
                                  validation_steps=6,
                                  callbacks=[lr_decay, modelcp],