Beispiel #1
0
def train_model():
    """Train the CRNN on multiple GPUs, saving and evaluating weights per epoch.

    Steps performed:

    1. Set ``params.num_classes`` to the number of characters in the file at
       ``params.char_path`` plus one.
    2. Build the template model and, when ``params.load_weights_path`` is not
       ``None``, try to load those weights into it (best effort).
    3. Wrap the template in ``multi_gpu_model`` and compile it with Adadelta.
    4. For every epoch: train one epoch, save the template model's weights
       and run both ``epoch_eval`` evaluations on the saved file.

    The function takes no arguments and returns ``None``; all configuration
    is read from the ``params`` module.
    """
    import os  # function-scope import: used only to create the output directory

    # calc num_classes; the context manager closes the file even if read() fails.
    # The "+ 1" is presumably the extra CTC blank class -- TODO confirm.
    with open(params.char_path, 'r', encoding='utf-8') as key_f:
        chars = key_f.read()
    params.num_classes = len(chars) + 1
    print('params.num_classes: ', params.num_classes)

    template_model = crnn_model.get_Model(training=True)

    try:
        latest_weights = params.load_weights_path
        print("find latest_weights exists.", latest_weights)
        if latest_weights is not None:
            template_model.load_weights(latest_weights)
            print("...load exist weights: ", latest_weights)
        else:
            print("history weights file not exist, train a new one.")
    except Exception as e:
        # Deliberate best effort: unusable historical weights must not abort
        # the run, training simply starts from freshly initialised weights.
        print('warn: ', str(e))
        print("historical weights data can not be used, train a new one...")

    model = multi_gpu_model(template_model, gpus=params.gpu_nums_in_multi_model)

    # Copy the template's weights into the second-to-last layer of the wrapper.
    # NOTE(review): presumably that layer is the embedded template model --
    # confirm against the multi_gpu_model layer layout.
    model.layers[-2].set_weights(template_model.get_weights())

    ada = Adadelta()

    # the loss calc occurs elsewhere, so use a dummy lambda func for the loss
    model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer=ada)

    # NOTE(review): the three callbacks below are constructed but never passed
    # to fit_generator (callbacks=[] in the loop). They are kept so that any
    # construction-time behaviour stays unchanged.
    early_stop = EarlyStopping(monitor='val_loss',
                               min_delta=0.001,
                               patience=6,
                               mode='min',
                               verbose=1)

    my_checkpoint = MyCheckPoint(template_model)

    tensor_board = TensorBoard(log_dir='/data/output')

    train_data_gen, val_data_gen, train_sample_num, val_sample_num = load_train_and_val_data()

    batch_size = params.batch_size
    epoch_num = params.epoch_num
    val_batch_size = params.val_batch_size

    # Make sure the output directory exists before any training time is spent;
    # otherwise the first save_weights call would fail after a whole epoch.
    weights_path_template = "/data/output/crnn_weights_20190106/crnn_weights_v1.13_ep_%d.h5"
    os.makedirs(os.path.dirname(weights_path_template), exist_ok=True)

    for ep_i in range(epoch_num):
        print("epoch: ", ep_i+1)
        # initial_epoch=ep_i with epochs=ep_i + 1 trains exactly one epoch per
        # call, so weights can be saved and evaluated between epochs.
        model.fit_generator(generator=train_data_gen,
                            steps_per_epoch=train_sample_num // batch_size,
                            epochs=ep_i + 1,
                            callbacks=[],
                            verbose=2,
                            initial_epoch=ep_i,
                            validation_data=val_data_gen,
                            validation_steps=val_sample_num // val_batch_size)

        # Weights are saved from the template model, not the multi-GPU wrapper.
        curr_weights_path = weights_path_template % (ep_i + 1)
        template_model.save_weights(curr_weights_path)
        train_data_acc = epoch_eval.eval_on_generating_data(curr_weights_path)
        print(" -- train_data_acc: ", train_data_acc)
        real_data_acc, detail_info = epoch_eval.eval_on_real_data(curr_weights_path)
        print(" -- real_data_acc: ", real_data_acc, detail_info)
Beispiel #2
0
def train_model():
    """Train the CRNN model, saving and evaluating weights after every epoch.

    Steps performed:

    1. Set ``params.num_classes`` to the number of characters in the file at
       ``params.char_path`` plus one.
    2. Build the model and, when ``params.load_weights_path`` is not ``None``,
       try to load those weights into it (best effort).
    3. Compile the model with Adadelta and a pass-through loss.
    4. For every epoch: train one epoch, save the weights and run both
       ``epoch_eval`` evaluations on the saved file.

    The function takes no arguments and returns ``None``; all configuration
    is read from the ``params`` module.
    """
    import os  # function-scope import: used only to create the output directory

    # calc num_classes; the context manager closes the file even if read() fails.
    # The "+ 1" is presumably the extra CTC blank class -- TODO confirm.
    with open(params.char_path, 'r', encoding='utf-8') as key_f:
        chars = key_f.read()
    params.num_classes = len(chars) + 1
    print('params.num_classes: ', params.num_classes)

    model = crnn_model.get_Model(training=True)
    try:
        latest_weights = params.load_weights_path
        print("find latest_weights exists.", latest_weights)
        if latest_weights is not None:
            model.load_weights(latest_weights)
            print("...load exist weights: ", latest_weights)
        else:
            print("history weights file not exist, train a new one.")
    except Exception as e:
        # Deliberate best effort: unusable historical weights must not abort
        # the run, training simply starts from freshly initialised weights.
        print('warn: ', str(e))
        print("historical weights data can not be used, train a new one...")

    train_data_gen, val_data_gen, train_sample_num, val_sample_num = load_train_and_val_data(
    )

    ada = Adadelta()

    # NOTE(review): the three callbacks below are constructed but never passed
    # to fit_generator (callbacks=[] in the loop). They are kept so that any
    # construction-time behaviour stays unchanged.
    early_stop = EarlyStopping(monitor='val_loss',
                               min_delta=0.001,
                               patience=8,
                               mode='min',
                               verbose=1)
    checkpoint = ModelCheckpoint(
        filepath='/data/output/CRNN--{epoch:02d}--{val_loss:.3f}.h5',
        monitor='val_loss',
        save_best_only=False,
        save_weights_only=True,
        verbose=1,
        mode='min',
        period=1)
    tensor_board = TensorBoard(log_dir='/data/output')
    # the loss calc occurs elsewhere, so use a dummy lambda func for the loss
    model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer=ada)

    batch_size = params.batch_size
    epoch_num = params.epoch_num
    val_batch_size = params.val_batch_size

    # Make sure the output directory exists before any training time is spent;
    # otherwise the first save_weights call would fail after a whole epoch.
    weights_path_template = "/data/output/crnn_ticket_20190327/crnn_weights_d10w_ticket_id_date_20190327_ep_%d.h5"
    os.makedirs(os.path.dirname(weights_path_template), exist_ok=True)

    for ep_i in range(epoch_num):
        print("epoch: ", ep_i + 1)
        # initial_epoch=ep_i with epochs=ep_i + 1 trains exactly one epoch per
        # call, so weights can be saved and evaluated between epochs.
        model.fit_generator(generator=train_data_gen,
                            steps_per_epoch=train_sample_num // batch_size,
                            epochs=ep_i + 1,
                            callbacks=[],
                            verbose=1,
                            initial_epoch=ep_i,
                            validation_data=val_data_gen,
                            validation_steps=val_sample_num // val_batch_size)

        curr_weights_path = weights_path_template % (ep_i + 1)
        model.save_weights(curr_weights_path)
        train_data_acc = epoch_eval.eval_on_generating_data(curr_weights_path)
        print(" -- train_data_acc: ", train_data_acc)
        real_data_acc, detail_info = epoch_eval.eval_on_real_data(
            curr_weights_path)
        print(" -- real_data_acc: ", real_data_acc, detail_info)
Beispiel #3
0
from batch_test import epoch_eval
import os

if __name__ == '__main__':
    # Batch test: evaluate every weights file found in ``weights_dir`` on the
    # real data set and print the result for each one.
    print('start doing batch_testing...')
    weights_dir = '/data/output/crnn_ticket_20190327'

    # Only regular ``.h5`` files are evaluated; sub-directories and stray
    # files (logs, temp files, ...) in the directory are skipped instead of
    # being handed to the evaluator as if they were weights.
    # NOTE: the order is plain lexical, so "..._ep_10.h5" sorts before
    # "..._ep_2.h5".
    weights_fn_list = sorted(
        fn for fn in os.listdir(weights_dir)
        if fn.endswith('.h5') and os.path.isfile(os.path.join(weights_dir, fn))
    )
    for idx, weights_fn in enumerate(weights_fn_list):
        print('--------------------------------------')
        print(idx, weights_fn)
        weights_path = os.path.join(weights_dir, weights_fn)
        # The first returned value (accuracy) is not printed by this script.
        _acc, test_res = epoch_eval.eval_on_real_data(weights_path)
        print('test_res: ', test_res)