def train_kfold(idx, kfold, datapath, labelpath, epochs, batch_size, lr, finetune):
    """Train the CTC model on a single fold of a k-fold split.

    Args:
        idx: Fold index. Selects ``kfold[idx]`` and names the checkpoint
            file ``model/best_<idx>.h5``.
        kfold: Indexable of ``(train_idx, valid_idx)`` pairs, one per fold.
        datapath: Data location forwarded to ``TextImageGenerator``.
        labelpath: Label location forwarded to ``TextImageGenerator``.
        epochs: Maximum number of epochs; early stopping may end sooner.
        batch_size: Samples per batch for both generators.
        lr: Adam learning rate.
        finetune: Forwarded to ``get_model``. When true, this fold's
            checkpoint is loaded before training starts.
    """
    sess = tf.Session()
    K.set_session(sess)

    model, y_func = get_model((*SIZE, 3), training=True, finetune=finetune)
    ada = Adam(lr=lr)
    # The 'ctc' output already carries the loss value, so the Keras loss
    # function just passes y_pred through and ignores y_true.
    model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer=ada)

    ## load data
    train_idx, valid_idx = kfold[idx]
    # NOTE(review): the meaning of the positional 32 and True/False arguments
    # is defined by TextImageGenerator's signature — confirm there.
    train_generator = TextImageGenerator(datapath, labelpath, *SIZE, batch_size, 32,
                                         train_idx, True, MAX_LEN)
    train_generator.build_data()
    valid_generator = TextImageGenerator(datapath, labelpath, *SIZE, batch_size, 32,
                                         valid_idx, False, MAX_LEN)
    valid_generator.build_data()

    ## callbacks
    weight_path = 'model/best_%d.h5' % idx
    # Keep only the weights with the lowest validation loss seen so far.
    ckp = ModelCheckpoint(weight_path, monitor='val_loss', verbose=1,
                          save_best_only=True, save_weights_only=True)
    vis = VizCallback(sess, y_func, valid_generator, len(valid_idx))
    earlystop = keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=0, patience=10,
                                              verbose=0, mode='min')

    if finetune:
        print('load pretrain model')
        model.load_weights(weight_path)

    # Floor division as before, but never 0. A split smaller than one batch
    # would otherwise give zero steps: no training per epoch, and Keras
    # rejects zero validation_steps for a plain generator.
    steps_per_epoch = max(1, len(train_idx) // batch_size)
    validation_steps = max(1, len(valid_idx) // batch_size)

    model.fit_generator(generator=train_generator.next_batch(),
                        steps_per_epoch=steps_per_epoch,
                        epochs=epochs,
                        callbacks=[ckp, vis, earlystop],
                        validation_data=valid_generator.next_batch(),
                        validation_steps=validation_steps)
def train(datapath, labelpath, epochs, batch_size, lr, finetune):
    """Train the CTC model on a random 95/5 train/validation split.

    Args:
        datapath: Directory whose entry count determines the number of
            samples; also forwarded to ``TextImageGenerator``.
        labelpath: Label location forwarded to ``TextImageGenerator``.
        epochs: Maximum number of epochs; early stopping may end sooner.
        batch_size: Samples per batch for both generators.
        lr: Adam learning rate (gradients are clipped to +/-5 by value).
        finetune: Forwarded to ``get_model``. When true, weights are loaded
            from ``model/best_weight.h5`` before training starts.
    """
    sess = tf.Session()
    K.set_session(sess)

    model, y_func = get_model((*SIZE, 3), training=True, finetune=finetune)
    ada = Adam(lr=lr, clipvalue=5)
    # The 'ctc' output already carries the loss value, so the Keras loss
    # function just passes y_pred through and ignores y_true.
    model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer=ada)

    ## load data
    # One sample index per entry in datapath; fixed seed keeps the split reproducible.
    sample_ids = np.arange(len(os.listdir(datapath)))
    train_idx, valid_idx = train_test_split(sample_ids, test_size=0.05, random_state=42)
    # NOTE(review): unlike train_kfold, build_data() is not called on these
    # generators — presumably intentional; confirm against TextImageGenerator.
    train_generator = TextImageGenerator(datapath, labelpath, *SIZE, batch_size, 8,
                                         train_idx, True, MAX_LEN)
    valid_generator = TextImageGenerator(datapath, labelpath, *SIZE, batch_size, 8,
                                         valid_idx, False, MAX_LEN)

    ## callbacks
    weight_path = 'model/best_weight.h5'
    # Keep only the weights with the lowest validation loss seen so far.
    ckp = ModelCheckpoint(weight_path, monitor='val_loss', verbose=1,
                          save_best_only=True, save_weights_only=True)
    vis = VizCallback(sess, y_func, valid_generator, len(valid_idx))
    earlystop = keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=0, patience=100,
                                              verbose=0, mode='min')

    if finetune:
        print('load pretrain model')
        model.load_weights(weight_path)

    # Floor division as before, but never 0. With a 5% validation split this
    # easily falls below one batch. Zero steps means no training per epoch,
    # and Keras rejects zero validation_steps for a plain generator.
    steps_per_epoch = max(1, len(train_idx) // batch_size)
    validation_steps = max(1, len(valid_idx) // batch_size)

    model.fit_generator(generator=train_generator.next_batch(),
                        steps_per_epoch=steps_per_epoch,
                        epochs=epochs,
                        callbacks=[ckp, vis, earlystop],
                        validation_data=valid_generator.next_batch(),
                        validation_steps=validation_steps)