Example #1
def train_kfold(idx, kfold, datapath, labelpath, epochs, batch_size, lr, finetune):
    # Train one model on cross-validation fold `idx`; kfold[idx] holds the (train, valid) index arrays.
    sess = tf.Session()
    K.set_session(sess)

    model, y_func = get_model((*SIZE, 3), training=True, finetune=finetune)
    ada = Adam(lr=lr)
    # The CTC loss is already computed inside the model's 'ctc' output, so the compiled loss simply returns y_pred.
    model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer=ada)

    ## load data
    train_idx, valid_idx = kfold[idx]
    train_generator = TextImageGenerator(datapath, labelpath, *SIZE, batch_size, 32, train_idx, True, MAX_LEN)
    train_generator.build_data()
    valid_generator = TextImageGenerator(datapath, labelpath, *SIZE, batch_size, 32, valid_idx, False, MAX_LEN)
    valid_generator.build_data()

    ## callbacks
    weight_path = 'model/best_%d.h5' % idx
    ckp = ModelCheckpoint(weight_path, monitor='val_loss', verbose=1, save_best_only=True, save_weights_only=True)
    vis = VizCallback(sess, y_func, valid_generator, len(valid_idx))
    earlystop = keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=0, patience=10, verbose=0, mode='min')

    if finetune:
        print('load pretrained model')
        model.load_weights(weight_path)

    model.fit_generator(generator=train_generator.next_batch(),
                        steps_per_epoch=int(len(train_idx) / batch_size),
                        epochs=epochs,
                        callbacks=[ckp, vis, earlystop],
                        validation_data=valid_generator.next_batch(),
                        validation_steps=int(len(valid_idx) / batch_size))
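
A hedged usage sketch (not part of the original source): the fold index pairs are assumed to come from scikit-learn's KFold, and the paths, fold count, and hyperparameters below are illustrative placeholders.

import os
import numpy as np
from sklearn.model_selection import KFold

ids = np.arange(len(os.listdir('data/images')))  # one index per training image (assumed layout)
folds = list(KFold(n_splits=5, shuffle=True, random_state=42).split(ids))
for fold_idx in range(len(folds)):
    train_kfold(fold_idx, folds, 'data/images', 'data/labels.csv',
                epochs=100, batch_size=64, lr=1e-3, finetune=False)
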
Example #2
def train(datapath, labelpath, epochs, batch_size, lr, finetune):
    # Train a single model on a random 95/5 train/validation split of the samples in datapath.
    sess = tf.Session()
    K.set_session(sess)

    model, y_func = get_model((*SIZE, 3), training=True, finetune=finetune)
    ada = Adam(lr=lr, clipvalue=5)  # clipvalue=5 clips each gradient element to [-5, 5]
    model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer=ada)

    ## load data
    ids = np.arange(len(os.listdir(datapath)))
    train_idx, valid_idx = train_test_split(ids, test_size=0.05, random_state=42)
    train_generator = TextImageGenerator(datapath, labelpath, *SIZE, batch_size, 8, train_idx, True, MAX_LEN)
    #train_generator.build_data()
    valid_generator = TextImageGenerator(datapath, labelpath, *SIZE, batch_size, 8, valid_idx, False, MAX_LEN)
    #valid_generator.build_data()

    ## callbacks
    weight_path = 'model/best_weight.h5'
    ckp = ModelCheckpoint(weight_path, monitor='val_loss', verbose=1, save_best_only=True, save_weights_only=True)
    vis = VizCallback(sess, y_func, valid_generator, len(valid_idx))
    earlystop = keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=0, patience=100, verbose=0, mode='min')

    if finetune:
        print('load pretrained model')
        model.load_weights(weight_path)
    
    # Optionally freeze the sub-model named 'cnn' so only the remaining layers are trained:
    # for layer in model.layers:
    #     if layer.name == 'cnn':
    #         layer.trainable = False
    #         model.summary()
    model.fit_generator(generator=train_generator.next_batch(),
                        steps_per_epoch=int(len(train_idx) / batch_size),
                        epochs=epochs,
                        callbacks=[ckp, vis, earlystop],
                        validation_data=valid_generator.next_batch(),
                        validation_steps=int(len(valid_idx) / batch_size))
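
A hedged usage sketch for this variant (assumption, not in the source): run it once from scratch, then call it again with finetune=True so training resumes from the best checkpoint saved at model/best_weight.h5; the paths and hyperparameters are placeholders.

if __name__ == '__main__':
    # initial training run
    train('data/images', 'data/labels.csv', epochs=200, batch_size=64, lr=1e-3, finetune=False)
    # fine-tune from the saved best weights with a smaller learning rate
    train('data/images', 'data/labels.csv', epochs=50, batch_size=64, lr=1e-4, finetune=True)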