Example #1
def model_train(img_size,
                batch_size,
                epochs,
                optimizer,
                learning_rate,
                train_list,
                validation_list,
                style=2):

    print('Style {}.'.format(style))

    if style == 1:
        input_img = Input(shape=img_size)

        #model = Sequential()

        # VDSR-style stack of 20 convolutions written as a loop: one input
        # convolution, 18 hidden 64-filter convolutions, and a final
        # single-filter convolution that predicts the residual image.
        model = Conv2D(64, (3, 3),
                       padding='same',
                       kernel_initializer='he_normal')(input_img)
        model = Activation('relu')(model)

        for _ in range(18):
            model = Conv2D(64, (3, 3),
                           padding='same',
                           kernel_initializer='he_normal')(model)
            model = Activation('relu')(model)

        model = Conv2D(1, (3, 3),
                       padding='same',
                       kernel_initializer='he_normal')(model)
        res_img = model

        output_img = merge.Add()([res_img, input_img])

        model = Model(input_img, output_img)

        #model.load_weights('vdsr_model_edges.h5')

        adam = Adam(lr=0.000005)
        #sgd = SGD(lr=1e-3, momentum=0.9, decay=1e-4, nesterov=False)
        sgd = SGD(lr=0.01, momentum=0.9, decay=0.001, nesterov=False)
        #model.compile(sgd, loss='mse', metrics=[PSNR, "accuracy"])
        model.compile(adam,
                      loss='mse',
                      metrics=[ssim, ssim_metric, PSNR, "accuracy"])

        model.summary()

    else:

        input_img = Input(shape=img_size)

        model = Conv2D(64, (3, 3),
                       padding='valid',
                       kernel_initializer='he_normal')(input_img)
        model_0 = Activation('relu')(model)

        total_conv = 22  # total number of convolutions, including the first and last layers
        total_conv -= 2  # subtract first and last
        residual_block_num = 5  # number of residual blocks; the remaining convolutions are split across them

        for _ in range(residual_block_num):  # residual block
            model = Conv2D(64, (3, 3),
                           padding='same',
                           kernel_initializer='he_normal')(model_0)
            model = Activation('relu')(model)
            for _ in range(int(total_conv / residual_block_num) - 1):
                model = Conv2D(64, (3, 3),
                               padding='same',
                               kernel_initializer='he_normal')(model)
                model = Activation('relu')(model)
                model_0 = add([model, model_0])

        model = Conv2D(1, (3, 3),
                       padding='valid',
                       kernel_initializer='he_normal')(model)
        res_img = model

        input_img1 = crop(1, 2, -2)(input_img)
        input_img1 = crop(2, 2, -2)(input_img1)

        print(input_img.shape)
        print(input_img1.shape)
        output_img = merge.Add()([res_img, input_img1])
        # output_img = res_img
        model = Model(input_img, output_img)

        # model.load_weights('./vdsr_model_edges.h5')
        # adam = Adam(lr=learning_rate)
        adam = Adadelta()
        # sgd = SGD(lr=1e-7, momentum=0.9, decay=1e-2, nesterov=False)
        sgd = SGD(lr=learning_rate,
                  momentum=0.9,
                  decay=1e-4,
                  nesterov=False,
                  clipnorm=1)
        if optimizer == 0:
            model.compile(adam, loss='mse', metrics=[ssim, ssim_metric, PSNR])
        else:
            model.compile(sgd, loss='mse', metrics=[ssim, ssim_metric, PSNR])

        model.summary()

    mycallback = MyCallback(model)
    timestamp = time.strftime("%m%d-%H%M", time.localtime(time.time()))
    csv_logger = callbacks.CSVLogger(
        'data/callbacks/training_{}.log'.format(timestamp))
    filepath = "./checkpoints/weights-improvement-{epoch:03d}-{PSNR:.2f}.hdf5"
    # ModelCheckpoint expects the metric's name as a string, not the function itself.
    checkpoint = ModelCheckpoint(filepath, monitor='PSNR', verbose=1, mode='max')
    callbacks_list = [mycallback, checkpoint, csv_logger]

    # print('Loading training data.')
    # x = load_images(DATA_PATH)
    # print('Loading data label.')
    # y = load_images(LABEL_PATH)
    # print('Loading validation data.')
    # val = load_images(VAL_PATH)
    # print('Loading validation label.')
    # val_label = load_images(VAL_LABEL_PATH)

    # print(x.shape)
    # print(y.shape)
    # print(val.shape)
    # print(val_label.shape)

    with open('./model/vdsr_architecture.json', 'w') as f:
        f.write(model.to_json())

    # datagen = ImageDataGenerator(rotation_range=45,
    #                              zoom_range=0.15,
    #                              horizontal_flip=True,
    #                              vertical_flip=True)

    # history = model.fit_generator(datagen.flow(x, y, batch_size=batch_size),
    #                     steps_per_epoch=len(x) // batch_size,
    #                     validation_data=(val, val_label),
    #                     validation_steps=len(val) // batch_size,
    #                     epochs=epochs,
    #                     callbacks=callbacks_list,
    #                     verbose=1,
    #                     shuffle=True,
    #                     workers=256,
    #                     use_multiprocessing=True)

    history = model.fit_generator(
        image_gen(train_list, batch_size=batch_size),
        steps_per_epoch=384400 * (len(train_list)) // batch_size,
        # steps_per_epoch=4612800//batch_size,
        validation_data=image_gen(validation_list, batch_size=batch_size),
        validation_steps=384400 * (len(validation_list)) // batch_size,
        epochs=epochs,
        workers=1024,
        callbacks=callbacks_list,
        verbose=1)

    print("Done training!!!")

    print("Saving the final model ...")

    model.save('vdsr_model.h5')  # creates an HDF5 file
    del model  # deletes the existing model

    # plt.plot(history.history['accuracy'])
    # plt.plot(history.history['val_accuracy'])
    # plt.title('Model accuracy')
    # plt.ylabel('Accuracy')
    # plt.xlabel('Epoch')
    # plt.legend(['Train', 'validation'], loc='upper left')
    # # plt.show()
    # plt.savefig('accuracy.png')

    # Plot training & validation loss values
    plt.plot(history.history['loss'])
    plt.plot(history.history['val_loss'])
    plt.title('Model loss')
    plt.ylabel('Loss')
    plt.xlabel('Epoch')
    plt.legend(['Train', 'validation'], loc='upper left')
    # plt.show()
    plt.savefig('loss.png')

    plt.plot(history.history['PSNR'])
    plt.plot(history.history['val_PSNR'])
    plt.title('Model PSNR')
    plt.ylabel('PSNR')
    plt.xlabel('Epoch')
    plt.legend(['Train', 'validation'], loc='upper left')
    # plt.show()
    plt.savefig('PSNR.png')
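Both branches above depend on project-specific helpers that are not shown on this page: the PSNR, ssim, and ssim_metric functions passed to model.compile, and the crop layer used to trim the borders lost to the 'valid' convolutions in the style-2 branch (MyCallback and image_gen are likewise project code; a sketch of image_gen follows Example #2). As a rough orientation only, minimal sketches compatible with these call sites might look like the following; the range assumption (pixel values in [0, 1]) and the exact formulas may differ from the original project.

import tensorflow as tf
from keras import backend as K
from keras.layers import Lambda


def PSNR(y_true, y_pred):
    # Peak signal-to-noise ratio; assumes pixel values scaled to [0, 1].
    return -10.0 * K.log(K.mean(K.square(y_pred - y_true))) / K.log(10.0)


def ssim(y_true, y_pred):
    # Structural similarity via TensorFlow, same [0, 1] range assumption.
    return tf.image.ssim(y_true, y_pred, max_val=1.0)


def crop(dimension, start, end):
    # Returns a Lambda layer that slices `dimension` of its input to [start:end].
    # crop(1, 2, -2) followed by crop(2, 2, -2) trims 2 pixels from each border,
    # matching the spatial shrinkage of the two 'valid' 3x3 convolutions above.
    def func(x):
        if dimension == 1:
            return x[:, start:end]
        if dimension == 2:
            return x[:, :, start:end]
        if dimension == 3:
            return x[:, :, :, start:end]
        return x
    return Lambda(func)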
Example #2
def model_train(img_size, batch_size, epochs, optimizer, learning_rate, train_list, validation_list, style=2):

    print('Style {}.'.format(style))

    if style == 1:
        input_img = Input(shape=img_size)

        #model = Sequential()

        # VDSR-style stack of 20 convolutions written as a loop: one input
        # convolution, 18 hidden 64-filter convolutions, and a final
        # single-filter convolution that predicts the residual image.
        model = Conv2D(64, (3, 3), padding='same', kernel_initializer='he_normal')(input_img)
        model = Activation('relu')(model)

        for _ in range(18):
            model = Conv2D(64, (3, 3), padding='same', kernel_initializer='he_normal')(model)
            model = Activation('relu')(model)

        model = Conv2D(1, (3, 3), padding='same', kernel_initializer='he_normal')(model)
        res_img = model

        output_img = merge.Add()([res_img, input_img])

        model = Model(input_img, output_img)

        #model.load_weights('vdsr_model_edges.h5')

        adam = Adam(lr=0.000005)
        sgd = SGD(lr=0.01, momentum=0.9, decay=0.001, nesterov=False)
        model.compile(adam, loss='mse', metrics=[ssim, ssim_metric, PSNR, "accuracy"])

        model.summary()

    else:

        input_img = Input(shape=img_size)

        model = Conv2D(64, (3, 3), padding='same', kernel_initializer='he_normal', use_bias=False)(input_img)
        model = BatchNormalization()(model)
        model_0 = Activation('relu')(model)

        total_conv = 22  # total number of convolutions, including the first and last layers
        total_conv -= 2  # subtract first and last
        residual_block_num = 5  # number of residual blocks; the remaining convolutions are split across them

        for _ in range(residual_block_num):  # residual block
            model = Conv2D(64, (3, 3), padding='same', kernel_initializer='he_normal', use_bias=False)(model_0)
            model = BatchNormalization()(model)
            model = Activation('relu')(model)
            print(_)
            for _ in range(int(total_conv/residual_block_num)-1):
                model = Conv2D(64, (3, 3), padding='same', kernel_initializer='he_normal', use_bias=False)(model)
                model = BatchNormalization()(model)
                model = Activation('relu')(model)
                model_0 = add([model, model_0])
                print(_)

        model = Conv2DTranspose(64, (5, 5), padding='valid', kernel_initializer='he_normal', use_bias=False)(model)
        model = BatchNormalization()(model)
        model = LeakyReLU()(model)
        model = Conv2D(1, (5, 5), padding='valid', kernel_initializer='he_normal')(model)
        
        res_img = model

        #input_img1 = crop(1,22,-22)(input_img)
        #input_img1 = crop(2,22,-22)(input_img1)

        print(input_img.shape)
        output_img = merge.Add()([res_img, input_img])
        # output_img = res_img
        model = Model(input_img, output_img)

        # model.load_weights('./vdsr_model_edges.h5')
        # adam = Adam(lr=learning_rate)
        adam = Adadelta()
        sgd = SGD(lr=learning_rate, momentum=0.9, decay=1e-4, nesterov=False, clipnorm=1)
        if optimizer == 0:
            model.compile(adam, loss='mse', metrics=[ssim, ssim_metric, PSNR])
        else:
            model.compile(sgd, loss='mse', metrics=[ssim, ssim_metric, PSNR])


        model.summary()

    mycallback = MyCallback(model)
    timestamp = time.strftime("%m%d-%H%M", time.localtime(time.time()))
    csv_logger = callbacks.CSVLogger('data/callbacks/deconv/training_{}.log'.format(timestamp))
    filepath="./checkpoints/deconv/weights-improvement-{epoch:03d}-{PSNR:.2f}-{ssim:.2f}.hdf5"
    checkpoint = ModelCheckpoint(filepath, monitor=PSNR, verbose=1, mode='max')
    callbacks_list = [mycallback, checkpoint, csv_logger]

    with open('./model/deconv/vdsr_architecture.json', 'w') as f:
        f.write(model.to_json())

    history = model.fit_generator(image_gen(train_list, batch_size=batch_size), 
                        steps_per_epoch=(409600//8)*len(train_list) // batch_size,
                        validation_data=image_gen(validation_list,batch_size=batch_size),
                        validation_steps=(409600//8)*len(validation_list) // batch_size,
                        epochs=epochs,
                        workers=1024,
                        callbacks=callbacks_list,
                        verbose=1)

    print("Done training!!!")

    print("Saving the final model ...")

    model.save('vdsr_model.h5')  # creates an HDF5 file
    del model  # deletes the existing model


    # Plot training & validation loss values
    plt.plot(history.history['loss'])
    plt.plot(history.history['val_loss'])
    plt.title('Model loss')
    plt.ylabel('Loss')
    plt.xlabel('Epoch')
    plt.legend(['Train', 'validation'], loc='upper left')
    # plt.show()
    plt.savefig('loss.png')

    plt.plot(history.history['PSNR'])
    plt.plot(history.history['val_PSNR'])
    plt.title('Model PSNR')
    plt.ylabel('PSNR')
    plt.xlabel('Epoch')
    plt.legend(['Train', 'validation'], loc='upper left')
    # plt.show()
    plt.savefig('PSNR.png')
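The fit_generator calls in both examples also assume an image_gen generator that yields (input, label) batches indefinitely; the hard-coded steps_per_epoch values suggest that every file in train_list contributes a fixed number of pre-extracted patches. The data layout itself is not shown, so the sketch below is purely hypothetical; its .npz format and the 'x'/'y' keys are placeholders, not the project's actual layout.

import numpy as np


def image_gen(file_list, batch_size=64):
    # Hypothetical generator: each entry in file_list is assumed to be an .npz
    # archive holding pre-extracted patches 'x' (degraded input) and 'y'
    # (ground-truth label), each shaped (N, H, W, 1).
    while True:
        for path in file_list:
            data = np.load(path)
            x, y = data['x'], data['y']
            for i in range(0, len(x) - batch_size + 1, batch_size):
                yield x[i:i + batch_size], y[i:i + batch_size]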
Example #3
model = Activation('relu')(model)
model = Conv2D(1, (3, 3), padding='same',
               kernel_initializer='he_normal')(model)
res_img = model

output_img = add([res_img, input_img])

model = Model(input_img, output_img)

# model.load_weights('./checkpoints/weights-improvement-20-26.93.hdf5')

adam = Adam(lr=0.00001)
sgd = SGD(lr=1e-5, momentum=0.9, decay=1e-4, nesterov=False)
model.compile(adam, loss='mse', metrics=[PSNR, "accuracy"])

model.summary()

filepath = "./checkpoints/weights-improvement-{epoch:02d}-{PSNR:.2f}.hdf5"
checkpoint = ModelCheckpoint(filepath, monitor='PSNR', verbose=1, mode='max')
callbacks_list = [checkpoint]

model.fit_generator(image_gen(train_list),
                    steps_per_epoch=len(train_list) // BATCH_SIZE,
                    validation_data=image_gen(test_list),
                    validation_steps=len(test_list) // BATCH_SIZE,
                    epochs=EPOCHS, workers=8, callbacks=callbacks_list)

print("Done training!!!")

print("Saving the final model ...")

model.save('vdsr_model.h5')  # creates an HDF5 file
del model  # deletes the existing model
Example #4
model = Conv2D(64, (3, 3), padding='same', kernel_initializer='he_normal')(model)
model = Activation('relu')(model)
model = Conv2D(1, (3, 3), padding='same', kernel_initializer='he_normal')(model)
res_img = model

output_img = add([res_img, input_img])

model = Model(input_img, output_img)

# model.load_weights('./checkpoints/weights-improvement-20-26.93.hdf5')

adam = Adam(lr=0.00001)
sgd = SGD(lr=1e-5, momentum=0.9, decay=1e-4, nesterov=False)
model.compile(adam, loss='mse', metrics=[PSNR, "accuracy"])

model.summary()

filepath="./checkpoints/weights-improvement-{epoch:02d}-{PSNR:.2f}.hdf5"
checkpoint = ModelCheckpoint(filepath, monitor=PSNR, verbose=1, mode='max')
callbacks_list = [checkpoint]

model.fit_generator(image_gen(train_list),
                    steps_per_epoch=len(train_list) // BATCH_SIZE,
                    validation_data=image_gen(test_list),
                    validation_steps=len(test_list) // BATCH_SIZE,
                    epochs=EPOCHS, workers=8, callbacks=callbacks_list)

print("Done training!!!")

print("Saving the final model ...")

model.save('vdsr_model.h5')  # creates an HDF5 file
del model  # deletes the existing model
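Since every example compiles the model with custom metrics, reloading the saved vdsr_model.h5 later requires handing those functions back to Keras through custom_objects, for instance:

from keras.models import load_model

# Pass back whichever custom metrics the model was compiled with
# (PSNR alone for the last two examples, plus ssim and ssim_metric for the first two).
model = load_model('vdsr_model.h5',
                   custom_objects={'PSNR': PSNR, 'ssim': ssim, 'ssim_metric': ssim_metric})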