def pred(args):
    """Decode the text in one image with separately built encoder/decoder models.

    The encoder is run once over the image; the decoder is then stepped one
    character at a time, feeding back its own predicted id and state, until it
    predicts id 2 (treated as ETX) or ``conf.MAX_SEQUENCE`` steps have run.

    Args:
        args: parsed command-line arguments; ``args.model`` is the model path
            and ``args.image`` the path of the image to recognise.

    Returns:
        tuple: ``(result, attention_weights)`` -- ``result`` is the decoded
        string (ETX excluded) and ``attention_weights`` holds the decoder's
        attention output for every decoded character.

    Raises:
        FileNotFoundError: if ``cv2.imread`` cannot read ``args.image``.
    """
    charset = label_utils.get_charset(conf.CHARSET)
    CHARSET_SIZE = len(charset)

    # Build the models; the first returned model is not needed for inference.
    _, decoder_model, encoder_model = _model.model(conf, args)

    # Load the encoder and the decoder separately from the same model file.
    encoder_model.load_model(args.model)
    decoder_model.load_model(args.model)
    logger.info("加载了模型:%s", args.model)

    logger.info("开始预测图片:%s", args.image)
    image = cv2.imread(args.image)
    # cv2.imread signals failure by returning None rather than raising.
    if image is None:
        raise FileNotFoundError("Cannot read image: %s" % args.image)

    # Run the encoder once over the image.
    encoder_out_states, encoder_fwd_state, encoder_back_state = encoder_model.predict(image)

    # The decoder's initial state is the forward and backward encoder states
    # concatenated along the last axis.
    decoder_init_state = np.concatenate([encoder_fwd_state, encoder_back_state], axis=-1)

    attention_weights = []

    # Decoding starts from the STX token.
    from utils.label_utils import convert_to_id
    decoder_index = convert_to_id([conf.CHAR_STX], charset)
    decoder_state = decoder_init_state

    chars = []

    # Predict characters one by one.
    for _ in range(conf.MAX_SEQUENCE):
        # Despite the padding this is really a single character; padding and
        # one-hot encoding only make the input match the decoder's declared shape.
        # NOTE(review): after the first step decoder_index is a scalar id, not a
        # list of sequences -- confirm pad_sequences accepts that here.
        decoder_inputs = pad_sequences(decoder_index, maxlen=conf.MAX_SEQUENCE, padding="post", value=0)
        decoder_inputs = to_categorical(decoder_inputs, num_classes=CHARSET_SIZE)

        # Decoder: inputs=[decoder_inputs, encoder_out_states, decoder_init_state],
        # outputs=[decoder_pred, attn_states, decoder_state]; encoder_out_states
        # is consumed by the attention layer.
        decoder_out, attention, decoder_state = \
            decoder_model.predict([decoder_inputs, encoder_out_states, decoder_state])

        # The output at this step is a probability distribution over the charset
        # (3770 classes per the original note); argmax picks one id.
        decoder_index = np.argmax(decoder_out, axis=-1)[0, 0]

        # Id 2 is treated as ETX (the original marks this as "==> conf.CHAR_ETX").
        if decoder_index == 2:
            logger.info("预测字符为ETX,退出")
            break

        attention_weights.append(attention)

        pred_char = label_utils.id2str(decoder_index, charset=charset)

        logger.info("预测字符为:%s", pred_char)
        chars.append(pred_char)

    # Return the whole decoded string (previously only the last character was
    # returned, and nothing at all if ETX came first).
    return "".join(chars), attention_weights
Exemple #2
0
 def __init__(self, name,label_file, charset_file,conf,args,batch_size=32):
     """Create a labelled data sequence.

     Args:
         name: human-readable name of this sequence (callers pass e.g.
             "训练" / "验证" for training / validation).
         label_file: path of the label file, stored as ``self.label_file``.
         charset_file: charset file, loaded via ``label_utils.get_charset``.
         conf: configuration object, stored and forwarded to ``initialize``.
         args: parsed command-line arguments, forwarded to ``initialize``.
         batch_size: batch size, stored as ``self.batch_size`` (default 32).
     """
     # Plain settings are stored before initialize() runs, in case it reads them.
     self.conf, self.name = conf, name
     self.label_file, self.batch_size = label_file, batch_size
     # Character set read from charset_file.
     self.charsets = label_utils.get_charset(charset_file)
     self.initialize(conf, args)
     # Timestamp taken after initialization has completed.
     self.start_time = time.time()
Exemple #3
0
def pred(args):
    """Decode the text in one image with the inference models derived from a saved model.

    The saved training model is loaded and split into inference encoder and
    decoder models by ``_model.infer_model``. The encoder is run once; the
    decoder then predicts one character per step, feeding back its own
    prediction and state, until it emits ``conf.CHAR_ETX`` or
    ``conf.MAX_SEQUENCE`` steps have run.

    Args:
        args: parsed command-line arguments; ``args.model`` is the model path
            and ``args.image`` the path of the image to recognise.

    Returns:
        tuple: ``(result, attention_weights)`` -- ``result`` is the decoded
        string (ETX excluded) and ``attention_weights`` holds the decoder's
        attention output for every decoded character.

    Raises:
        FileNotFoundError: if ``cv2.imread`` cannot read ``args.image``.
    """
    charset = label_utils.get_charset(conf.CHARSET)
    CHARSET_SIZE = len(charset)

    # Load the model, registering the custom objects it was saved with.
    model = load_model(args.model, custom_objects={
        'words_accuracy': _model.words_accuracy,
        'squeeze_wrapper': Conv().squeeze_wrapper,
        'AttentionLayer': AttentionLayer})
    encoder_model, decoder_model = _model.infer_model(model, conf)
    logger.info("加载了模型:%s", args.model)

    logger.info("开始预测图片:%s", args.image)
    image = cv2.imread(args.image)
    # cv2.imread signals failure by returning None rather than raising.
    if image is None:
        raise FileNotFoundError("Cannot read image: %s" % args.image)

    # Run the encoder once; the image is wrapped into a batch of one.
    encoder_out_states, encoder_state = encoder_model.predict(np.array([image]))

    attention_weights = []

    # Decoding starts from the STX token; the encoder state seeds the decoder.
    from utils.label_utils import convert_to_id
    decoder_index = convert_to_id([conf.CHAR_STX], charset)
    decoder_state = encoder_state

    chars = []
    hit_etx = False

    # Decode one character per step; decoder_state is updated each time.
    for _ in range(conf.MAX_SEQUENCE):
        # Despite the wrapping this is a single character; it is one-hot
        # encoded to match the decoder's declared input shape.
        decoder_inputs = to_categorical(decoder_index, num_classes=CHARSET_SIZE)

        decoder_out, attention, decoder_state = decoder_model.predict(
            [[decoder_inputs], encoder_out_states, decoder_state])

        # decoder_out[1,1,3770] --argmax(axis=2)--> [[max_id]]; keep row 0 -> [max_id]
        decoder_index = decoder_out.argmax(axis=2)[0]
        pred_char = label_utils.id2str(decoder_index, charset)
        if pred_char == conf.CHAR_ETX:
            logger.info("预测字符为ETX,退出")
            hit_etx = True
            break

        attention_weights.append(attention)

        logger.info("预测字符ID[%d],对应字符[%s]", decoder_index[0], pred_char)
        chars.append(pred_char)
        # Re-wrap so the next step's input has the same nesting as the STX input.
        decoder_index = [decoder_index]

    result = "".join(chars)
    # Report why decoding stopped; an explicit flag avoids relying on every id
    # mapping to exactly one character.
    if hit_etx:
        logger.debug("预测字符为:%s,解码最后为ETX", result)
    else:
        logger.debug("预测字符为:%s,达到最大预测长度", result)

    # Return the decoded string; previously the last character (ETX on a
    # normal finish) was returned instead.
    return result, attention_weights
Exemple #4
0
        # decoder_out[1,1,3770] =argmax=> [[max_id]]
        decoder_index = decoder_out.argmax(axis=2)
        decoder_index = decoder_index[0]
        pred_char = label_utils.id2str(decoder_index, charset)
        if pred_char == conf.CHAR_ETX:
            logger.info("预测字符为ETX,退出")
            break #==>conf.CHAR_ETX: break

        attention_weights.append(attention)

        logger.info("预测字符ID[%d],对应字符[%s]",decoder_index[0],pred_char)
        result+= pred_char
        decoder_index = [decoder_index]

    if len(result)>=conf.MAX_SEQUENCE:
        logger.debug("预测字符为:%s,达到最大预测长度", result)
    else:
        logger.debug("预测字符为:%s,解码最后为ETX", result)

    return pred_char,attention_weights


if __name__ == "__main__":
    # Script entry point: initialise logging via log.init(), store the charset
    # size on conf, parse prediction arguments, run pred() and log the result.
    log.init()
    charset = label_utils.get_charset(conf.CHARSET)
    conf.CHARSET_SIZE = len(charset)
    args = conf.init_pred_args()
    result,attention_probs = pred(args)
    logger.info("预测字符串为:%s",result)
    # Disabled: log the attention probabilities as well.
    # logger.info("注意力概率为:%r", attention_probs)
Exemple #5
0
def train(args):
    """Train the model built by ``_model.train_model`` and save it.

    Builds training and validation ``SequenceData`` generators, optionally
    resumes from the checkpoint returned by ``util.get_checkpoint`` for
    ``conf.DIR_CHECKPOINT`` (unless ``args.retrain`` is set), fits with
    TensorBoard, checkpoint, early-stopping and attention-visualisation
    callbacks, and finally saves the model under ``conf.DIR_MODEL``.

    Args:
        args: parsed command-line arguments (label files, batch sizes, epochs,
            steps, workers, early-stop patience, retrain flag).
    """
    # TF debugging aid (uncomment to wrap the session in the CLI debugger):
    # from tensorflow.python import debug as tf_debug
    # from tensorflow.python.keras import backend as K
    # sess = K.get_session()
    # sess = tf_debug.LocalCLIDebugWrapperSession(sess)
    # K.set_session(sess)

    charset = label_utils.get_charset(conf.CHARSET)
    conf.CHARSET_SIZE = len(charset)

    model = _model.train_model(conf, args)

    train_sequence = SequenceData(name="训练",
                                  label_file=args.train_label_file,
                                  charset_file=conf.CHARSET,
                                  conf=conf,
                                  args=args,
                                  batch_size=args.batch)
    valid_sequence = SequenceData(name="验证",
                                  label_file=args.validate_label_file,
                                  charset_file=conf.CHARSET,
                                  conf=conf,
                                  args=args,
                                  batch_size=args.validation_batch)

    timestamp = util.timestamp_s()
    tb_log_name = os.path.join(conf.DIR_TBOARD, timestamp)
    # One file per epoch; ModelCheckpoint fills the placeholders from the
    # epoch number and the logged metrics.
    checkpoint_path = os.path.join(
        conf.DIR_MODEL,
        "model-" + timestamp + "-epoch{epoch:03d}-acc{words_accuracy:.4f}-val{val_words_accuracy:.4f}.hdf5")
    # NOTE(review): per-epoch checkpoints are written to conf.DIR_MODEL, but the
    # resume logic below searches conf.DIR_CHECKPOINT -- confirm this is intended.

    # Resume from an existing checkpoint unless a fresh run was requested.
    if args.retrain:
        logger.info("重新开始训练....")
    else:
        logger.info("基于之前的checkpoint训练...")
        _checkpoint_path = util.get_checkpoint(conf.DIR_CHECKPOINT)
        if _checkpoint_path is not None:
            # Register every custom object this model family is loaded with
            # elsewhere in the file; unused entries are ignored by load_model.
            model = load_model(_checkpoint_path,
                custom_objects={
                    'words_accuracy': _model.words_accuracy,
                    'Conv': Conv,
                    'squeeze_wrapper': Conv().squeeze_wrapper,
                    'AttentionLayer': AttentionLayer})
            logger.info("加载checkpoint模型[%s]", _checkpoint_path)
        else:
            logger.warning("找不到任何checkpoint,重新开始训练")

    logger.info("Begin train开始训练:")

    attention_visible = TBoardVisual('Attetnon Visibility', tb_log_name, charset, args)
    # write_grads=True is left off: writing gradients makes logging very slow.
    tboard = TensorBoard(log_dir=tb_log_name, histogram_freq=1, batch_size=2)
    early_stop = EarlyStopping(monitor='words_accuracy', patience=args.early_stop, verbose=1, mode='max')
    checkpoint = ModelCheckpoint(filepath=checkpoint_path, monitor='words_accuracy', verbose=1, mode='max')

    model.fit_generator(
        generator=train_sequence,
        # Ideally len(train_sequence), but that makes an epoch far too long, so a
        # smaller fixed number of batches (e.g. 1000) is used instead.
        steps_per_epoch=args.steps_per_epoch,
        epochs=args.epochs,
        workers=args.workers,   # number of worker processes loading batches
        callbacks=[tboard, checkpoint, early_stop, attention_visible],
        use_multiprocessing=True,
        validation_data=valid_sequence,
        validation_steps=args.validation_steps)
    # Keras docs: for a Sequence, validation_steps is optional; if unspecified,
    # len(generator) is used as the number of steps.

    logger.info("Train end训练结束!")

    model_path = os.path.join(conf.DIR_MODEL, "ocr-attention-{}.hdf5".format(util.timestamp_s()))
    model.save(model_path)
    logger.info("Save model保存训练后的模型到:%s", model_path)
def train(args):
    """Train the model returned by ``_model.model`` and save it.

    Builds training and validation ``SequenceData`` generators, optionally
    resumes from the checkpoint returned by ``util.get_checkpoint`` for
    ``conf.DIR_CHECKPOINT`` (unless ``args.retrain`` is set), fits with
    TensorBoard, best-only checkpoint and early-stopping callbacks, and
    finally saves the model under ``conf.DIR_MODEL``.

    Args:
        args: parsed command-line arguments (label files, batch sizes, epochs,
            steps, workers, early-stop patience, retrain flag).
    """
    # TF debugging aid (uncomment to wrap the session in the CLI debugger):
    # from tensorflow.python import debug as tf_debug
    # from tensorflow.python.keras import backend as K
    # sess = K.get_session()
    # sess = tf_debug.LocalCLIDebugWrapperSession(sess)
    # K.set_session(sess)

    charset = label_utils.get_charset(conf.CHARSET)
    conf.CHARSET_SIZE = len(charset)

    # Only the training model is needed here; the inference models are ignored.
    model, _, _ = _model.model(conf, args)
    # K.get_session().run(tf.global_variables_initializer())
    train_sequence = SequenceData(name="训练",
                                  label_file=args.train_label_file,
                                  charset_file=conf.CHARSET,
                                  conf=conf,
                                  args=args,
                                  batch_size=args.batch)
    valid_sequence = SequenceData(name="验证",
                                  label_file=args.validate_label_file,
                                  charset_file=conf.CHARSET,
                                  conf=conf,
                                  args=args,
                                  batch_size=args.validation_batch)

    timestamp = util.timestamp_s()
    tb_log_name = os.path.join(conf.DIR_TBOARD, timestamp)
    checkpoint_path = os.path.join(conf.DIR_CHECKPOINT, "checkpoint-{}.hdf5".format(timestamp))

    # Resume from an existing checkpoint unless a fresh run was requested.
    if args.retrain:
        logger.info("重新开始训练....")
    else:
        logger.info("基于之前的checkpoint训练...")
        _checkpoint_path = util.get_checkpoint(conf.DIR_CHECKPOINT)
        if _checkpoint_path is not None:
            # Same custom objects that pred() registers when loading this model,
            # plus Conv; unused entries are ignored by load_model.
            model = load_model(_checkpoint_path,
                custom_objects={
                    'words_accuracy': _model.words_accuracy,
                    'Conv': Conv,
                    'squeeze_wrapper': Conv().squeeze_wrapper,
                    'AttentionLayer': AttentionLayer})
            logger.info("加载checkpoint模型[%s]", _checkpoint_path)
        else:
            logger.warning("找不到任何checkpoint,重新开始训练")

    # Keep only the best model by training words_accuracy.
    checkpoint = ModelCheckpoint(
        filepath=checkpoint_path,
        monitor='words_accuracy',
        verbose=1,
        save_best_only=True,
        mode='max')

    early_stop = EarlyStopping(
        monitor='words_accuracy',
        patience=args.early_stop,
        verbose=1,
        mode='max')

    logger.info("Begin train开始训练:")

    # One epoch is args.steps_per_epoch batches (the original note gives 10000
    # as the default).
    model.fit_generator(
        generator=train_sequence,
        # Ideally len(train_sequence), but that makes an epoch far too long, so a
        # smaller fixed number of batches (e.g. 1000) is used instead.
        steps_per_epoch=args.steps_per_epoch,
        epochs=args.epochs,
        workers=args.workers,   # number of worker processes loading batches
        callbacks=[TensorBoard(log_dir=tb_log_name), checkpoint, early_stop],
        use_multiprocessing=True,
        validation_data=valid_sequence,
        validation_steps=args.validation_steps)

    logger.info("Train end训练结束!")

    model_path = os.path.join(conf.DIR_MODEL, "ocr-attention-{}.hdf5".format(util.timestamp_s()))
    model.save(model_path)
    logger.info("Save model保存训练后的模型到:%s", model_path)