Example #1
0
def crnn_net(is_training, feature, label, batch_size, l_size):
    """Build the CRNN (ShadowNet) graph for training or inference.

    :param is_training: if True build the 'Train' phase network together with
        the CTC loss, an L2 penalty on the LSTM weights and the gradients;
        otherwise build the 'Test' phase network only.
    :param feature: input image tensor; it is resized to (32, l_size * 4).
    :param label: sparse label tensor fed to ``tf.nn.ctc_loss`` (training only).
    :param batch_size: samples per batch, used to build the CTC
        sequence-length vector (training only).
    :param l_size: sequence length of the network; the resized image width
        is ``l_size * 4``.
    :return: tuple ``(cost, grads_and_vars, net_out, tensor_dict, seq_len)``;
        ``cost`` and ``grads_and_vars`` are None when not training, otherwise
        ``grads_and_vars`` is a list of ``(gradient, variable)`` pairs.
    """
    seq_len = l_size
    shadownet = crnn_model.ShadowNet(
        phase='Train' if is_training else 'Test',
        hidden_nums=256,
        layers_nums=2,
        seq_length=seq_len,
        num_classes=config.cfg.TRAIN.CLASSES_NUMS,
        rnn_cell_type='lstm')

    imgs = tf.image.resize_images(feature, (32, l_size * 4), method=0)
    input_imgs = tf.cast(x=imgs, dtype=tf.float32)

    with tf.variable_scope('shadow', reuse=False):
        net_out, tensor_dict = shadownet.build_shadownet(
            inputdata=input_imgs)

    if not is_training:
        return None, None, net_out, tensor_dict, seq_len

    cost = tf.reduce_mean(
        tf.nn.ctc_loss(labels=label,
                       inputs=net_out,
                       sequence_length=seq_len * np.ones(batch_size)))

    # L2 penalty on the LSTM weights. tf.trainable_variables(scope=...) matches
    # the scope with re.match from the START of the variable name, so
    # 'LSTMLayers' never matches variables created under the 'shadow/' scope
    # above; select them by substring instead.
    lstm_tv = [v for v in tf.trainable_variables() if 'LSTMLayers' in v.name]
    r_lambda = 0.001
    regularization_cost = r_lambda * tf.reduce_sum(
        [tf.nn.l2_loss(v) for v in lstm_tv])
    cost = cost + regularization_cost

    model_params = tf.trainable_variables()
    tower_grad = tf.gradients(cost, model_params)

    # Materialize the pairs: a bare zip() is a one-shot iterator in Python 3
    # and would be silently empty if the caller iterates it twice.
    grads_and_vars = list(zip(tower_grad, model_params))
    return cost, grads_and_vars, net_out, tensor_dict, seq_len
Example #2
0
def recognize(image_path, weights_path, is_vis=True):
    """Recognize the text in one image with a pretrained CRNN model.

    The image is resized to 100x32, passed through a 'Test' phase ShadowNet
    and decoded with CTC beam search; the prediction is logged and the input
    image is optionally displayed.

    :param image_path: path of the image file to recognize
    :param weights_path: checkpoint path passed to ``tf.train.Saver.restore``
    :param is_vis: if True, show the input image with matplotlib
    :return: None
    :raises ValueError: if the image cannot be read from ``image_path``
    """
    raw_image = cv2.imread(image_path, cv2.IMREAD_COLOR)
    if raw_image is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise ValueError('Could not read image {}'.format(image_path))
    image = cv2.resize(raw_image, (100, 32))
    image = np.expand_dims(image, axis=0).astype(np.float32)

    inputdata = tf.placeholder(dtype=tf.float32,
                               shape=[1, 32, 100, 3],
                               name='input')

    net = crnn_model.ShadowNet(phase='Test',
                               hidden_nums=256,
                               layers_nums=2,
                               seq_length=25,
                               num_classes=37)

    with tf.variable_scope('shadow'):
        net_out = net.build_shadownet(inputdata=inputdata)

    decodes, _ = tf.nn.ctc_beam_search_decoder(inputs=net_out,
                                               sequence_length=25 * np.ones(1),
                                               merge_repeated=False)

    decoder = data_utils.TextFeatureIO()

    # config tf session
    sess_config = tf.ConfigProto()
    sess_config.gpu_options.per_process_gpu_memory_fraction = config.cfg.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = config.cfg.TRAIN.TF_ALLOW_GROWTH

    # config tf saver
    saver = tf.train.Saver()

    # The session context manager makes it the default session and closes it
    # on exit, including when restore/run raises.
    with tf.Session(config=sess_config) as sess:
        saver.restore(sess=sess, save_path=weights_path)

        preds = sess.run(decodes, feed_dict={inputdata: image})

        preds = decoder.writer.sparse_tensor_to_str(preds[0])

        logger.info('Predict image {:s} label {:s}'.format(
            ops.split(image_path)[1], preds[0]))

        if is_vis:
            plt.figure('CRNN Model Demo')
            # OpenCV loads BGR; reorder channels to RGB for matplotlib.
            plt.imshow(raw_image[:, :, (2, 1, 0)])
            plt.show()
Example #3
0
    def load_model(self):
        """Build the CRNN inference graph and restore its newest checkpoint.

        The checkpoint is located with ``tf.train.latest_checkpoint`` in
        ``self.model_dir``. Afterwards ``self.output`` holds the session under
        ``'sess'``, the input placeholder (shape
        ``[1, 32, config.cfg.TRAIN.width, 3]``) under ``'x'`` and the CTC
        beam-search decode result under ``'y_'``.
        """
        session = tf.Session()
        input_ph = tf.placeholder(dtype=tf.float32,
                                  shape=[1, 32, config.cfg.TRAIN.width, 3],
                                  name='input')
        shadownet = crnn_model.ShadowNet(
            phase=tf.constant('test', tf.string),
            hidden_nums=256,
            layers_nums=2,
            seq_length=15,
            num_classes=config.cfg.TRAIN.CLASSES_NUMS,
            rnn_cell_type='lstm')
        with tf.variable_scope('shadow'):
            net_out, _ = shadownet.build_shadownet(inputdata=input_ph)

        # NOTE(review): the network above is built with seq_length=15 but the
        # decoder uses 20 time steps. ctc_beam_search_decoder requires
        # sequence_length <= the time dimension of net_out — confirm which
        # value matches config.cfg.TRAIN.width.
        decode_lengths = 20 * np.ones(1)
        decodes, _ = tf.nn.ctc_beam_search_decoder(
            inputs=net_out,
            sequence_length=decode_lengths,
            merge_repeated=False)

        restorer = tf.train.Saver()
        restorer.restore(sess=session,
                         save_path=tf.train.latest_checkpoint(self.model_dir))

        for key, value in (('sess', session), ('x', input_ph),
                           ('y_', decodes)):
            self.output[key] = value
Example #4
0
    def load_model(self):
        """Build the CRNN inference graph and restore the latest checkpoint.

        The checkpoint is located with ``tf.train.latest_checkpoint`` in
        ``self.model_dir``. Afterwards ``self.output`` holds the session under
        ``'sess'``, the input placeholder (shape [1, 32, 100, 3]) under ``'x'``
        and the CTC beam-search decode result under ``'y_'``.
        """
        x = tf.placeholder(dtype=tf.float32,
                           shape=[1, 32, 100, 3],
                           name='input')
        # define model
        net = crnn_model.ShadowNet(phase='Test',
                                   hidden_nums=256,
                                   layers_nums=2,
                                   seq_length=25,
                                   num_classes=37)
        with tf.variable_scope('shadow'):
            net_out = net.build_shadownet(inputdata=x)
        decodes, _ = tf.nn.ctc_beam_search_decoder(inputs=net_out,
                                                   sequence_length=25 *
                                                   np.ones(1),
                                                   merge_repeated=False)

        # Create exactly one session, configured with the GPU options; it is
        # handed to the caller through self.output and must stay open.
        sess_config = tf.ConfigProto()
        sess_config.gpu_options.per_process_gpu_memory_fraction = config.cfg.TRAIN.GPU_MEMORY_FRACTION
        sess_config.gpu_options.allow_growth = config.cfg.TRAIN.TF_ALLOW_GROWTH
        sess = tf.Session(config=sess_config)

        saver = tf.train.Saver()
        params_file = tf.train.latest_checkpoint(self.model_dir)
        saver.restore(sess=sess, save_path=params_file)

        self.output['sess'] = sess
        self.output['x'] = x
        self.output['y_'] = decodes
Example #5
0
def recognize(image_path, weights_path, is_vis=True):
    """Recognize the text in one image with a pretrained CRNN model.

    The image is resized to 100x32, passed through a 'Test' phase ShadowNet
    and decoded with CTC beam search; the prediction is printed and the input
    image is optionally displayed.

    :param image_path: path of the image file to recognize
    :param weights_path: checkpoint path passed to ``tf.train.Saver.restore``
    :param is_vis: if True, show the input image with matplotlib
    :return: None
    :raises ValueError: if the image cannot be read from ``image_path``
    """
    raw_image = cv2.imread(image_path, cv2.IMREAD_COLOR)  # read the image
    if raw_image is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise ValueError('Could not read image {}'.format(image_path))
    image = cv2.resize(raw_image, (100, 32))  # resize to the network input size
    image = np.expand_dims(image, axis=0).astype(np.float32)  # add batch axis, cast to float

    inputdata = tf.placeholder(dtype=tf.float32,
                               shape=[1, 32, 100, 3],
                               name='input')  # placeholder for the input image
    net = crnn_model.ShadowNet(phase='Test',
                               hidden_nums=256,
                               layers_nums=2,
                               seq_length=25,
                               num_classes=37)  # declare the network
    with tf.variable_scope('shadow'):
        net_out = net.build_shadownet(inputdata=inputdata)  # build the network on the input
    decodes, _ = tf.nn.ctc_beam_search_decoder(inputs=net_out,
                                               sequence_length=25 * np.ones(1),
                                               merge_repeated=False)  # CTC decoding

    # session configuration
    sess_config = tf.ConfigProto()
    sess_config.gpu_options.per_process_gpu_memory_fraction = config.cfg.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = config.cfg.TRAIN.TF_ALLOW_GROWTH

    # saver used to restore the weights
    saver = tf.train.Saver()

    # The session context manager makes it the default session and closes it
    # on exit, including when restore/run raises.
    with tf.Session(config=sess_config) as sess:
        saver.restore(sess=sess, save_path=weights_path)  # load trained weights
        preds = sess.run(decodes, feed_dict={inputdata: image})  # run the network
        preds = data_utils.sparse_tensor_to_str(preds[0])  # convert the result to strings
        print('预测的图像为 %s 结果为 %s' % (ops.split(image_path)[1], preds[0]))  # print the result

        if is_vis:  # display the image when is_vis=True
            plt.figure('CRNN 图片')
            # OpenCV loads BGR; reorder channels to RGB for matplotlib.
            plt.imshow(raw_image[:, :, (2, 1, 0)])
            plt.show()
Example #6
0
def _tower_fn(feature, l_size=10):
    """Build one training tower of the CRNN (ShadowNet) network.

    :param feature: input image tensor; it is resized to (32, l_size * 4)
        before being fed to the network.
    :param l_size: sequence length of the network. It must be compatible with
        the ``--l_size`` parameter of train_shadownet_multi.py. Defaults to 10.
    :return: tuple ``(net_out, tensor_dict, l_size)``.
    """
    shadownet = crnn_model.ShadowNet(phase='Train',
                                     hidden_nums=256,
                                     layers_nums=2,
                                     seq_length=l_size,
                                     num_classes=config.cfg.TRAIN.CLASSES_NUMS,
                                     rnn_cell_type='lstm')

    # The image width is tied to the sequence length (4 pixels per step).
    imgs = tf.image.resize_images(feature, (32, l_size * 4), method=0)
    input_imgs = tf.cast(x=imgs, dtype=tf.float32)

    with tf.variable_scope('shadow', reuse=False):
        net_out, tensor_dict = shadownet.build_shadownet(inputdata=input_imgs)

    return net_out, tensor_dict, l_size
def train_shadownet_multi_gpu(dataset_dir_train, dataset_dir_val, weights_path,
                              char_dict_path, ord_map_dict_path,
                              model_save_dir):
    """Train the CRNN (ShadowNet) model on multiple GPUs.

    One training tower and one validation tower are built per GPU
    (``CFG.TRAIN.GPU_NUM``); tower gradients are averaged and applied with a
    momentum optimizer, and moving averages of the variables are tracked.

    :param dataset_dir_train: dataset directory for the training feeder
    :param dataset_dir_val: dataset directory for the validation feeder
    :param weights_path: checkpoint directory to resume from; training starts
        from scratch when it is None, does not exist, or holds fewer than 5
        entries
    :param char_dict_path: path of the character dictionary
    :param ord_map_dict_path: path of the ord-map dictionary
    :param model_save_dir: directory for checkpoints, the graph .pb file and
        TensorBoard summaries
    :return: None
    :raises ValueError: if the training loss becomes NaN, or if
        ``weights_path`` holds no checkpoint that can be restored
    """
    # prepare dataset information
    num_classes = get_num_class(char_dict_path)
    print(" dataset_dir_train       ", dataset_dir_train)

    # FIXME: the following data-feeding code is known to cause problems
    train_dataset = read_tfrecord.CrnnDataFeeder(
        dataset_dir=dataset_dir_train,
        char_dict_path=char_dict_path,
        ord_map_dict_path=ord_map_dict_path,
        flags='train')

    # Validation samples come from the validation directory argument.
    val_dataset = read_tfrecord.CrnnDataFeeder(
        dataset_dir=dataset_dir_val,
        char_dict_path=char_dict_path,
        ord_map_dict_path=ord_map_dict_path,
        flags='valid')

    train_images, train_labels, _ = train_dataset.inputs(
        batch_size=CFG.TRAIN.BATCH_SIZE)
    val_images, val_labels, _ = val_dataset.inputs(
        batch_size=CFG.TRAIN.BATCH_SIZE)

    # set crnn net
    shadownet = crnn_model.ShadowNet(phase='train',
                                     hidden_nums=CFG.ARCH.HIDDEN_UNITS,
                                     layers_nums=CFG.ARCH.HIDDEN_LAYERS,
                                     num_classes=num_classes)
    shadownet_val = crnn_model.ShadowNet(phase='test',
                                         hidden_nums=CFG.ARCH.HIDDEN_UNITS,
                                         layers_nums=CFG.ARCH.HIDDEN_LAYERS,
                                         num_classes=num_classes)

    # set average container
    tower_grads = []
    train_tower_loss = []
    val_tower_loss = []
    batchnorm_updates = None
    train_summary_op_updates = None

    # set lr
    global_step = tf.Variable(0, name='global_step', trainable=False)
    learning_rate = tf.train.exponential_decay(
        learning_rate=CFG.TRAIN.LEARNING_RATE,
        global_step=global_step,
        decay_steps=CFG.TRAIN.LR_DECAY_STEPS,
        decay_rate=CFG.TRAIN.LR_DECAY_RATE,
        staircase=CFG.TRAIN.LR_STAIRCASE)

    # set up optimizer
    optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                           momentum=0.9)

    # set distributed train op
    with tf.variable_scope(tf.get_variable_scope()):
        is_network_initialized = False
        for i in range(CFG.TRAIN.GPU_NUM):
            with tf.device('/gpu:{:d}'.format(i)):
                with tf.name_scope('tower_{:d}'.format(i)):
                    train_loss, grads = compute_net_gradients(
                        train_images,
                        train_labels,
                        shadownet,
                        optimizer,
                        is_net_first_initialized=is_network_initialized)

                    # Later towers reuse the variables created by the first.
                    is_network_initialized = True

                    # Only use the mean and var in the first gpu tower to update the parameter
                    if i == 0:
                        batchnorm_updates = tf.get_collection(
                            tf.GraphKeys.UPDATE_OPS)
                        train_summary_op_updates = tf.get_collection(
                            tf.GraphKeys.SUMMARIES)

                    tower_grads.append(grads)
                    train_tower_loss.append(train_loss)
                with tf.name_scope('validation_{:d}'.format(i)):
                    val_loss, _ = compute_net_gradients(
                        val_images,
                        val_labels,
                        shadownet_val,
                        optimizer,
                        is_net_first_initialized=is_network_initialized)
                    val_tower_loss.append(val_loss)

    grads = average_gradients(tower_grads)
    avg_train_loss = tf.reduce_mean(train_tower_loss)
    avg_val_loss = tf.reduce_mean(val_tower_loss)

    # Track the moving averages of all trainable variables
    variable_averages = tf.train.ExponentialMovingAverage(
        CFG.TRAIN.MOVING_AVERAGE_DECAY, num_updates=global_step)
    variables_to_average = (tf.trainable_variables() +
                            tf.moving_average_variables())
    variables_averages_op = variable_averages.apply(variables_to_average)

    # Group all the op needed for training
    batchnorm_updates_op = tf.group(*batchnorm_updates)
    apply_gradient_op = optimizer.apply_gradients(grads,
                                                  global_step=global_step)
    train_op = tf.group(apply_gradient_op, variables_averages_op,
                        batchnorm_updates_op)

    # Checkpoints, the graph file and TensorBoard events all go here.
    os.makedirs(model_save_dir, exist_ok=True)

    # set tensorflow summary
    summary_writer = tf.summary.FileWriter(model_save_dir)

    avg_train_loss_scalar = tf.summary.scalar(name='average_train_loss',
                                              tensor=avg_train_loss)
    avg_val_loss_scalar = tf.summary.scalar(name='average_val_loss',
                                            tensor=avg_val_loss)
    learning_rate_scalar = tf.summary.scalar(name='learning_rate_scalar',
                                             tensor=learning_rate)
    train_merge_summary_op = tf.summary.merge(
        [avg_train_loss_scalar, learning_rate_scalar] +
        train_summary_op_updates)
    val_merge_summary_op = tf.summary.merge([avg_val_loss_scalar])

    # set tensorflow saver
    saver = tf.train.Saver()
    train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S',
                                     time.localtime(time.time()))
    model_name = 'shadownet_{:s}.ckpt'.format(str(train_start_time))
    model_save_path = ops.join(model_save_dir, model_name)

    # set sess config
    sess_config = tf.ConfigProto(device_count={'GPU': CFG.TRAIN.GPU_NUM},
                                 allow_soft_placement=True)
    sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH
    sess_config.gpu_options.allocator_type = 'BFC'

    # Set the training parameters
    train_epochs = CFG.TRAIN.EPOCHS

    logger.info('Global configuration is as follows:')
    logger.info(CFG)

    sess = tf.Session(config=sess_config)

    summary_writer.add_graph(sess.graph)

    try:
        with sess.as_default():
            epoch = 0
            tf.train.write_graph(
                graph_or_graph_def=sess.graph,
                logdir='',
                name='{:s}/shadownet_model.pb'.format(model_save_dir))

            if weights_path is None or not os.path.exists(weights_path) or len(
                    os.listdir(weights_path)) < 5:
                logger.info('Training from scratch')
                init = tf.global_variables_initializer()
                sess.run(init)
            else:
                ckpt_path = tf.train.latest_checkpoint(weights_path)
                if ckpt_path is None:
                    raise ValueError(
                        'No checkpoint found in {}'.format(weights_path))
                logger.info(
                    'Restore model from last model checkpoint {:s}'.format(
                        ckpt_path))
                saver.restore(sess=sess, save_path=ckpt_path)
                # Resume the step counter from the restored global step.
                epoch = sess.run(global_step)

            while epoch < train_epochs:
                # `epoch` is the 1-based index of the step about to run.
                epoch += 1

                # training part
                _, train_loss_value, train_summary, lr = \
                    sess.run(fetches=[train_op,
                                      avg_train_loss,
                                      train_merge_summary_op,
                                      learning_rate])

                if math.isnan(train_loss_value):
                    raise ValueError('Train loss is nan')

                summary_writer.add_summary(summary=train_summary,
                                           global_step=epoch)

                if epoch % CFG.TRAIN.DISPLAY_STEP == 0:
                    logger.info(
                        'lr={:.5f}   epoch:{:6d}   total_loss={:.5f} '.format(
                            lr,
                            epoch,
                            train_loss_value,
                        ))

                if epoch % CFG.TRAIN.VAL_DISPLAY_STEP == 0:
                    # validation part
                    val_loss_value, val_summary = \
                        sess.run(fetches=[avg_val_loss,
                                          val_merge_summary_op])

                    summary_writer.add_summary(val_summary, global_step=epoch)

                    logger.info(
                        'Valid-----   epoch:{:6d}   total_loss={:.5f} '.format(
                            epoch, val_loss_value))

                    # Checkpoint on the same schedule as validation.
                    saver.save(sess=sess,
                               save_path=model_save_path,
                               global_step=epoch)
    finally:
        # Flush pending TensorBoard events and release the session even when
        # training aborts (e.g. on a NaN loss).
        summary_writer.close()
        sess.close()
def build_saved_model(ckpt_path, export_dir):
    """Convert a checkpoint into a TensorFlow SavedModel.

    Builds the 'test' phase ShadowNet inference graph with a CTC beam-search
    decoder, restores the weights from ``ckpt_path`` and exports graph and
    variables under the SERVING tag.

    :param ckpt_path: checkpoint path passed to ``tf.train.Saver.restore``
    :param export_dir: directory to write the SavedModel to; must not exist
    :return: None
    :raises ValueError: if ``export_dir`` already exists, or if the directory
        containing ``ckpt_path`` does not exist
    """
    if ops.exists(export_dir):
        raise ValueError('Export dir must be a dir path that does not exist')

    # A bare checkpoint name has no directory part; ops.split() then returns
    # '' (which ops.exists() reports as missing), so fall back to the cwd.
    ckpt_dir = ops.split(ckpt_path)[0] or '.'
    if not ops.exists(ckpt_dir):
        raise ValueError('Checkpoint dir {} does not exist'.format(ckpt_dir))

    # build inference tensorflow graph
    image_size = tuple(CFG.ARCH.INPUT_SIZE)
    # CFG.ARCH.INPUT_SIZE is indexed as (width, height) here.
    image_tensor = tf.placeholder(dtype=tf.float32,
                                  shape=[1, image_size[1], image_size[0], 3],
                                  name='input_tensor')

    # set crnn net
    net = crnn_model.ShadowNet(phase='test',
                               hidden_nums=CFG.ARCH.HIDDEN_UNITS,
                               layers_nums=CFG.ARCH.HIDDEN_LAYERS,
                               num_classes=CFG.ARCH.NUM_CLASSES)

    # compute inference logits
    inference_ret = net.inference(inputdata=image_tensor,
                                  name='shadow_net',
                                  reuse=False)

    # beam search decode
    decodes, _ = tf.nn.ctc_beam_search_decoder(
        inputs=inference_ret,
        sequence_length=CFG.ARCH.SEQ_LENGTH * np.ones(1),
        merge_repeated=False)

    saver = tf.train.Saver()

    # Set sess configuration
    sess_config = tf.ConfigProto(allow_soft_placement=True)
    sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH
    sess_config.gpu_options.allocator_type = 'BFC'

    # The context manager makes the session default and closes it on exit.
    with tf.Session(config=sess_config) as sess:
        saver.restore(sess=sess, save_path=ckpt_path)

        # set model save builder
        saved_builder = sm.builder.SavedModelBuilder(export_dir)

        # add tensor need to be saved
        saved_input_tensor = sm.utils.build_tensor_info(image_tensor)
        saved_prediction_tensor = sm.utils.build_tensor_info(decodes[0])

        # build SignatureDef protobuf
        signature_def = sm.signature_def_utils.build_signature_def(
            inputs={'input_tensor': saved_input_tensor},
            outputs={'prediction': saved_prediction_tensor},
            method_name=sm.signature_constants.PREDICT_METHOD_NAME)

        # add graph into MetaGraphDef protobuf
        # NOTE(review): the signature is registered under PREDICT_OUTPUTS
        # rather than DEFAULT_SERVING_SIGNATURE_DEF_KEY; clients must request
        # that signature name explicitly — confirm this is intended.
        saved_builder.add_meta_graph_and_variables(
            sess,
            tags=[sm.tag_constants.SERVING],
            signature_def_map={
                sm.signature_constants.PREDICT_OUTPUTS: signature_def
            })

        # save model
        saved_builder.save()
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Restoring and running multiple TensorFlow network models in the same process requires a workaround; see:
https://stackoverflow.com/questions/41607144/loading-two-models-from-saver-in-the-same-tensorflow-session
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
"""

os.chdir(CRNN_DIR)
import tools.demo_shadownet as crnn
import crnn_model.crnn_model as crnn_model
crnn_weights_path = 'model/shadownet/shadownet_2017-10-17-11-47-46.ckpt-199999'

# Build the CRNN inference graph in its own tf.Graph so it can coexist with
# other models loaded in the same process.
crnn_graph = tf.Graph()
with crnn_graph.as_default():
    crnn_net = crnn_model.ShadowNet(phase='Test',
                                    hidden_nums=256,
                                    layers_nums=2,
                                    seq_length=25,
                                    num_classes=37)
    with tf.variable_scope('shadow'):
        crnn_inputdata = tf.placeholder(dtype=tf.float32,
                                        shape=[1, 32, 100, 3],
                                        name='input')
        crnn_net_out = crnn_net.build_shadownet(inputdata=crnn_inputdata)

    # The decoder must live in the same graph as crnn_net_out; creating it
    # outside this context would place it in the default graph, and
    # TensorFlow rejects ops that mix tensors from different graphs.
    crnn_decodes, _ = tf.nn.ctc_beam_search_decoder(inputs=crnn_net_out,
                                                    sequence_length=25 *
                                                    np.ones(1),
                                                    merge_repeated=False)
crnn_decoder = crnn.data_utils.TextFeatureIO()

# config tf session
def train_shadownet(cfg: EasyDict,
                    weights_path: str = None,
                    decode: bool = False,
                    num_threads: int = 4) -> np.ndarray:
    """Train the CRNN (ShadowNet) model.

    :param cfg: configuration EasyDict (e.g. global_config.config.cfg)
    :param weights_path: checkpoint to restore before training; trains from
        scratch when None
    :param decode: whether to run CTC beam-search decoding every step to
        report sequence distance and training accuracy
    :param num_threads: number of input threads passed to
        ``decoder.read_features``
    :return: history of the cost value, one entry per completed training step
    """
    # decode the tf records to get the training data
    decoder = data_utils.TextFeatureIO(
        char_dict_path=ops.join(cfg.PATH.CHAR_DICT_DIR, 'char_dict.json'),
        ord_map_dict_path=ops.join(cfg.PATH.CHAR_DICT_DIR,
                                   'ord_map.json')).reader

    input_images, input_labels, _ = decoder.read_features(
        cfg, cfg.TRAIN.BATCH_SIZE, num_threads)

    # +1 class for the CTC blank label.
    shadownet = crnn_model.ShadowNet(phase='Train',
                                     hidden_nums=cfg.ARCH.HIDDEN_UNITS,
                                     layers_nums=cfg.ARCH.HIDDEN_LAYERS,
                                     num_classes=len(decoder.char_dict) + 1)

    with tf.variable_scope('shadow', reuse=False):
        net_out = shadownet.build_shadownet(inputdata=input_images)

    # Every sample in the batch uses the full configured sequence length.
    seq_lengths = cfg.ARCH.SEQ_LENGTH * np.ones(cfg.TRAIN.BATCH_SIZE)

    cost = tf.reduce_mean(
        tf.nn.ctc_loss(labels=input_labels,
                       inputs=net_out,
                       sequence_length=seq_lengths))

    decoded, _ = tf.nn.ctc_beam_search_decoder(net_out,
                                               seq_lengths,
                                               merge_repeated=False)

    sequence_dist = tf.reduce_mean(
        tf.edit_distance(tf.cast(decoded[0], tf.int32), input_labels))

    global_step = tf.Variable(0, name='global_step', trainable=False)

    learning_rate = tf.train.exponential_decay(
        cfg.TRAIN.LEARNING_RATE,
        global_step,
        cfg.TRAIN.LR_DECAY_STEPS,
        cfg.TRAIN.LR_DECAY_RATE,
        staircase=cfg.TRAIN.LR_STAIRCASE)

    # Run batch-norm (and other) update ops before each optimizer step.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        optimizer = tf.train.AdadeltaOptimizer(
            learning_rate=learning_rate).minimize(loss=cost,
                                                  global_step=global_step)

    # Set tf summary
    os.makedirs(cfg.PATH.TBOARD_SAVE_DIR, exist_ok=True)
    tf.summary.scalar(name='Cost', tensor=cost)
    tf.summary.scalar(name='Learning_Rate', tensor=learning_rate)
    if decode:
        tf.summary.scalar(name='Seq_Dist', tensor=sequence_dist)
    merge_summary_op = tf.summary.merge_all()

    # Set saver configuration. Checkpoints are written to MODEL_SAVE_DIR, so
    # that directory must exist before saver.save() is called.
    saver = tf.train.Saver()
    os.makedirs(cfg.PATH.MODEL_SAVE_DIR, exist_ok=True)
    train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S',
                                     time.localtime(time.time()))
    model_name = 'shadownet_{:s}.ckpt'.format(str(train_start_time))
    model_save_path = ops.join(cfg.PATH.MODEL_SAVE_DIR, model_name)

    # Set sess configuration
    sess_config = tf.ConfigProto()
    sess_config.gpu_options.per_process_gpu_memory_fraction = cfg.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = cfg.TRAIN.TF_ALLOW_GROWTH

    # Set the training parameters
    train_epochs = cfg.TRAIN.EPOCHS

    # Sentinel first entry so the early-stopping comparison has a baseline.
    cost_history = [np.inf]

    summary_writer = tf.summary.FileWriter(cfg.PATH.TBOARD_SAVE_DIR)
    try:
        # The session context manager makes it default and closes it on exit.
        with tf.Session(config=sess_config) as sess:
            summary_writer.add_graph(sess.graph)

            if weights_path is None:
                logger.info('Training from scratch')
                init = tf.global_variables_initializer()
                sess.run(init)
            else:
                logger.info('Restore model from {:s}'.format(weights_path))
                saver.restore(sess=sess, save_path=weights_path)

            patience_counter = 1
            for epoch in range(train_epochs):
                if epoch > 1 and cfg.TRAIN.EARLY_STOPPING:
                    # We always compare to the first point where cost didn't improve
                    if cost_history[-1 - patience_counter] - cost_history[
                            -1] > cfg.TRAIN.PATIENCE_DELTA:
                        patience_counter = 1
                    else:
                        patience_counter += 1
                    if patience_counter > cfg.TRAIN.PATIENCE_EPOCHS:
                        logger.info(
                            "Cost didn't improve beyond {:f} for {:d} epochs, stopping early."
                            .format(cfg.TRAIN.PATIENCE_DELTA, patience_counter))
                        break
                if decode:
                    _, c, seq_distance, predictions, labels, summary = sess.run([
                        optimizer, cost, sequence_dist, decoded, input_labels,
                        merge_summary_op
                    ])

                    labels = decoder.sparse_tensor_to_str(labels)
                    predictions = decoder.sparse_tensor_to_str(predictions[0])
                    accuracy = compute_accuracy(labels, predictions)

                    if epoch % cfg.TRAIN.DISPLAY_STEP == 0:
                        logger.info(
                            'Epoch: {:d} cost= {:9f} seq distance= {:9f} train accuracy= {:9f}'
                            .format(epoch + 1, c, seq_distance, accuracy))

                else:
                    _, c, summary = sess.run([optimizer, cost, merge_summary_op])
                    if epoch % cfg.TRAIN.DISPLAY_STEP == 0:
                        logger.info('Epoch: {:d} cost= {:9f}'.format(epoch + 1, c))

                cost_history.append(c)
                summary_writer.add_summary(summary=summary, global_step=epoch)
                saver.save(sess=sess, save_path=model_save_path, global_step=epoch)
    finally:
        # Flush pending TensorBoard events even if training fails.
        summary_writer.close()

    return np.array(cost_history[1:])  # Don't return the first np.inf
Example #11
0
def _sequence_accuracy(gt_label, pred):
    """Return the position-wise character accuracy of one prediction.

    Characters are compared index by index over the ground-truth label;
    positions beyond the end of ``pred`` count as wrong. An empty label
    scores 1 when the prediction is also empty and 0 otherwise.
    """
    if len(gt_label) == 0:
        return 1 if len(pred) == 0 else 0
    correct_count = sum(1 for gt_char, pred_char in zip(gt_label, pred)
                        if gt_char == pred_char)
    return correct_count / len(gt_label)


def test_shadownet(dataset_dir, weights_path, is_vis=False, is_recursive=True):
    """Evaluate a trained CRNN model on a .tfrecords file.

    :param dataset_dir: path of the .tfrecords file; it is decoded with
        ``data_utils.read_features`` and also iterated once to count records
    :param weights_path: checkpoint path passed to ``tf.train.Saver.restore``
    :param is_vis: if True, display each evaluated image with matplotlib
    :param is_recursive: if False, evaluate a single shuffled batch; if True,
        evaluate the whole file in order, batch by batch
    :return: None
    """
    batch_size = 32

    images_t, labels_t, imagenames_t = data_utils.read_features(
        dataset_dir, num_epochs=None)  # read the .tfrecords file

    if not is_recursive:
        # Shuffled batches: capacity bounds the queue size; a larger
        # min_after_dequeue mixes the samples more thoroughly.
        images_sh, labels_sh, imagenames_sh = tf.train.shuffle_batch(
            tensors=[images_t, labels_t, imagenames_t],
            batch_size=batch_size,
            capacity=1000 + batch_size * 2,
            min_after_dequeue=2,
            num_threads=4)
    else:
        # Batches in reading order (no shuffling).
        images_sh, labels_sh, imagenames_sh = tf.train.batch(
            tensors=[images_t, labels_t, imagenames_t],
            batch_size=batch_size,
            capacity=1000 + batch_size * 2,
            num_threads=4)

    images_sh = tf.cast(x=images_sh, dtype=tf.float32)

    # declare the network
    net = crnn_model.ShadowNet(phase='Test',
                               hidden_nums=256,
                               layers_nums=2,
                               seq_length=25,
                               num_classes=37)

    with tf.variable_scope('shadow'):
        net_out = net.build_shadownet(inputdata=images_sh)  # build the network on the input batch

    decoded, _ = tf.nn.ctc_beam_search_decoder(net_out,
                                               25 * np.ones(batch_size),
                                               merge_repeated=False)  # CTC decoding

    # session configuration
    sess_config = tf.ConfigProto()
    sess_config.gpu_options.per_process_gpu_memory_fraction = config.cfg.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = config.cfg.TRAIN.TF_ALLOW_GROWTH

    # saver used to restore the weights
    saver = tf.train.Saver()

    test_sample_count = sum(
        1 for _ in tf.python_io.tf_record_iterator(dataset_dir))
    loops_nums = int(math.ceil(test_sample_count / batch_size))

    # The session context manager makes it default and closes it on exit.
    with tf.Session(config=sess_config) as sess:
        # load the network weights
        saver.restore(sess=sess, save_path=weights_path)

        coord = tf.train.Coordinator()  # coordinates the queue-runner threads
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            print('开始预测文字......')
            if not is_recursive:
                # Evaluate a single shuffled batch.
                predictions, images, labels, imagenames = sess.run(
                    [decoded, images_sh, labels_sh, imagenames_sh])
                imagenames = np.reshape(imagenames, newshape=imagenames.shape[0])
                imagenames = [tmp.decode('utf-8') for tmp in imagenames]
                preds_res = data_utils.sparse_tensor_to_str(predictions[0])  # predicted text
                gt_res = data_utils.sparse_tensor_to_str(labels)  # ground-truth text
                accuracy = [_sequence_accuracy(gt_label, preds_res[index])
                            for index, gt_label in enumerate(gt_res)]

                accuracy = np.mean(np.array(accuracy).astype(np.float32), axis=0)
                print(' test accuracy 为 %f' % (accuracy))

                for index, image in enumerate(images):
                    print('预测图片 %s 准确的label为: %s **** 预测的 label: %s' %
                          (imagenames[index], gt_res[index], preds_res[index]))
                    if is_vis:
                        # OpenCV-style BGR -> RGB for matplotlib.
                        plt.imshow(image[:, :, (2, 1, 0)])
                        plt.show()
            else:
                # Evaluate the whole file in reading order.
                accuracy = []
                remaining = test_sample_count
                for _ in range(loops_nums):
                    predictions, images, labels, imagenames = sess.run(
                        [decoded, images_sh, labels_sh, imagenames_sh])
                    imagenames = np.reshape(imagenames,
                                            newshape=imagenames.shape[0])
                    imagenames = [tmp.decode('utf-8') for tmp in imagenames]
                    preds_res = data_utils.sparse_tensor_to_str(predictions[0])
                    gt_res = data_utils.sparse_tensor_to_str(labels)

                    # The input queue repeats forever (num_epochs=None), so the
                    # last batch may wrap around to the start of the file;
                    # keep only the samples not evaluated yet.
                    valid = min(batch_size, remaining)
                    remaining -= valid

                    accuracy.extend(
                        _sequence_accuracy(gt_label, preds_res[index])
                        for index, gt_label in enumerate(gt_res[:valid]))

                    for index, image in enumerate(images[:valid]):
                        print('预测图片 %s 准确的label为: %s **** 预测的label: %s' %
                              (imagenames[index], gt_res[index], preds_res[index]))
                        if is_vis:
                            plt.imshow(image[:, :, (2, 1, 0)])
                            plt.show()

                accuracy = np.mean(np.array(accuracy).astype(np.float32), axis=0)
                print('Test accuracy is %f' % (accuracy))
        finally:
            # Stop the queue-runner threads even if evaluation fails.
            coord.request_stop()
            coord.join(threads=threads)
Example #12
0
def train_shadownet(dataset_dir, weights_path=None):
    """Train the CRNN (ShadowNet) model from a shuffled TFRecords queue.

    Every iteration runs one optimizer step on a batch of 32 images, computes a
    per-character training accuracy for that batch, writes a TensorBoard
    summary and saves a checkpoint under ``model/shadownet``.

    :param dataset_dir: path handed to ``data_utils.read_features`` to read the
        training ``.tfrecords`` data
    :param weights_path: optional checkpoint to restore before training; when
        ``None`` all variables are freshly initialized
    :return: None
    """
    print("读取训练数据")
    # Read the .tfrecords data; num_epochs=None lets the input queue cycle forever.
    images, labels, imagenames = data_utils.read_features(
        dataset_dir, num_epochs=None)
    # Shuffled batch queue used for training (image names are not needed here).
    inputdata, input_labels, _ = tf.train.shuffle_batch(
        tensors=[images, labels, imagenames],
        batch_size=32,
        capacity=1000 + 2 * 32,
        min_after_dequeue=100,
        num_threads=1)
    inputdata = tf.cast(x=inputdata, dtype=tf.float32)

    print("初始化网络")
    shadownet = crnn_model.ShadowNet(phase='Train',
                                     hidden_nums=256,
                                     layers_nums=2,
                                     seq_length=25,
                                     num_classes=37)

    # Variable scope under which the network weights are created.
    with tf.variable_scope('shadow', reuse=False):
        net_out = shadownet.build_shadownet(inputdata=inputdata)

    # Batch-mean CTC loss; every sample uses the fixed sequence length 25.
    cost = tf.reduce_mean(
        tf.nn.ctc_loss(labels=input_labels,
                       inputs=net_out,
                       sequence_length=25 * np.ones(32)))
    decoded, _ = tf.nn.ctc_beam_search_decoder(
        net_out, 25 * np.ones(32), merge_repeated=False)
    # Mean edit distance between the best decoded path and the ground truth.
    sequence_dist = tf.reduce_mean(
        tf.edit_distance(tf.cast(decoded[0], tf.int32), input_labels))

    global_step = tf.Variable(0, name='global_step', trainable=False)
    # Staircase exponential decay of the configured initial learning rate.
    learning_rate = tf.train.exponential_decay(config.cfg.TRAIN.LEARNING_RATE,
                                               global_step,
                                               config.cfg.TRAIN.LR_DECAY_STEPS,
                                               config.cfg.TRAIN.LR_DECAY_RATE,
                                               staircase=True)
    # Ops registered in UPDATE_OPS (if any) must run before each optimizer step.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        optimizer = tf.train.AdadeltaOptimizer(
            learning_rate=learning_rate).minimize(loss=cost,
                                                  global_step=global_step)

    # TensorBoard summaries.
    tboard_save_path = 'tboard/shadownet'
    os.makedirs(tboard_save_path, exist_ok=True)
    tf.summary.scalar(name='Cost', tensor=cost)
    tf.summary.scalar(name='Learning_Rate', tensor=learning_rate)
    tf.summary.scalar(name='Seq_Dist', tensor=sequence_dist)
    merge_summary_op = tf.summary.merge_all()

    # Checkpoint location; the file name carries the training start time.
    saver = tf.train.Saver()
    model_save_dir = 'model/shadownet'
    os.makedirs(model_save_dir, exist_ok=True)
    train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S',
                                     time.localtime(time.time()))
    model_name = 'shadownet_{:s}.ckpt'.format(str(train_start_time))
    model_save_path = ops.join(model_save_dir, model_name)

    # GPU memory settings.
    sess_config = tf.ConfigProto()
    sess_config.gpu_options.per_process_gpu_memory_fraction = config.cfg.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = config.cfg.TRAIN.TF_ALLOW_GROWTH

    sess = tf.Session(config=sess_config)

    summary_writer = tf.summary.FileWriter(tboard_save_path)
    summary_writer.add_graph(sess.graph)

    # Number of training iterations (one batch each).
    train_epochs = config.cfg.TRAIN.EPOCHS

    print("开始训练")
    # Close the session even if restoring or a training step fails.
    try:
        with sess.as_default():
            if weights_path is None:
                print('完全重新开始训练')
                sess.run(tf.global_variables_initializer())
            else:
                # Report the real checkpoint path being resumed from.
                print('在之前的模型:' + weights_path + '上继续训练')
                saver.restore(sess=sess, save_path=weights_path)

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            # Always stop the queue-runner threads, otherwise the process may hang.
            try:
                for epoch in range(train_epochs):
                    _, c, seq_distance, preds, gt_labels, summary = sess.run([
                        optimizer, cost, sequence_dist, decoded, input_labels,
                        merge_summary_op
                    ])

                    preds = data_utils.sparse_tensor_to_str(preds[0])
                    gt_labels = data_utils.sparse_tensor_to_str(gt_labels)

                    # Per-sample accuracy: fraction of ground-truth characters
                    # matched at the same position (comparison stops at the
                    # shorter string). An empty label scores 1 only when the
                    # prediction is empty as well.
                    accuracy = []
                    for gt_label, pred in zip(gt_labels, preds):
                        if gt_label:
                            matched = sum(1 for gt_char, pred_char in zip(gt_label, pred)
                                          if gt_char == pred_char)
                            accuracy.append(matched / len(gt_label))
                        else:
                            accuracy.append(1 if not pred else 0)
                    accuracy = np.mean(np.array(accuracy, dtype=np.float32), axis=0)

                    if epoch % config.cfg.TRAIN.DISPLAY_STEP == 0:
                        print(
                            'Epoch: %d cost= %f seq distance= %f train accuracy= %f' %
                            (epoch + 1, c, seq_distance, accuracy))

                    summary_writer.add_summary(summary=summary, global_step=epoch)
                    saver.save(sess=sess, save_path=model_save_path, global_step=epoch)
            finally:
                coord.request_stop()
                coord.join(threads=threads)
    finally:
        sess.close()

    return
Example #13
0
def recognize_jmz(image_path, weights_path, char_dict_path, txt_file_path,
                  test_number):
    """Run the CRNN model on images listed in a text file and print predictions.

    :param image_path: prefix prepended to every image name by plain string
        concatenation (so it should end with a path separator)
    :param weights_path: directory holding checkpoints; the latest one is restored
    :param char_dict_path: UTF-8 JSON character dictionary; the number of
        classes is its size plus one (presumably the CTC blank label)
    :param txt_file_path: text file whose lines are space separated; the first
        field is the image name and the second the ground-truth label
    :param test_number: how many leading lines of ``txt_file_path`` to process;
        values larger than the file length use every line
    :return: None
    :raises ValueError: if ``weights_path`` contains no checkpoint
    """
    with open(char_dict_path, 'r', encoding='utf-8') as dict_file:
        char_map_dict = json.load(dict_file)
    num_classes = len(char_map_dict) + 1
    print('num_classes: ', num_classes)

    with open(txt_file_path, 'r') as f1:
        linelist = f1.readlines()

    # (full image path, label) pairs; the label has \r, \n and \t removed.
    image_list = []
    for line in linelist[:max(test_number, 0)]:
        fields = line.split(' ')
        label = fields[1].replace('\r', '').replace('\n', '').replace('\t', '')
        image_list.append((image_path + fields[0], label))

    tf.reset_default_graph()

    # Batch of one image with variable width.
    inputdata = tf.placeholder(
        dtype=tf.float32,
        shape=[1, CFG.ARCH.INPUT_SIZE[1], None, CFG.ARCH.INPUT_CHANNELS],
        name='input')
    input_sequence_length = tf.placeholder(tf.int32,
                                           shape=[1],
                                           name='input_sequence_length')

    net = crnn_model.ShadowNet(phase='test',
                               hidden_nums=CFG.ARCH.HIDDEN_UNITS,
                               layers_nums=CFG.ARCH.HIDDEN_LAYERS,
                               num_classes=num_classes)

    inference_ret = net.inference(inputdata=inputdata,
                                  name='shadow_net',
                                  reuse=False)

    # beam_width=1 keeps only the single best path; sequence length is fed per image.
    decodes, _ = tf.nn.ctc_beam_search_decoder(
        inputs=inference_ret,
        sequence_length=input_sequence_length,
        merge_repeated=False,
        beam_width=1)

    saver = tf.train.Saver()

    sess_config = tf.ConfigProto(allow_soft_placement=True)
    sess_config.gpu_options.allow_growth = True
    sess = tf.Session(config=sess_config)

    checkpoint_path = tf.train.latest_checkpoint(weights_path)
    if checkpoint_path is None:
        sess.close()
        raise ValueError('No checkpoint found in {}'.format(weights_path))
    print('Restore model from last model checkpoint {:s}'.format(checkpoint_path))

    try:
        with sess.as_default():
            saver.restore(sess=sess, save_path=checkpoint_path)

            for image_name, label in image_list:
                image = cv2.imread(image_name, cv2.IMREAD_COLOR)
                if image is None:
                    print(image_name + ' does not exist')
                    continue

                image = cv2.resize(image,
                                   dsize=tuple(CFG.ARCH.INPUT_SIZE),
                                   interpolation=cv2.INTER_LINEAR)
                # Scale pixel values from [0, 255] to [-1, 1].
                image = np.array(image, np.float32) / 127.5 - 1.0
                # One output time step per 4 input columns (truncated to int).
                seq_len = np.array([image.shape[1] / 4], dtype=np.int32)
                preds = sess.run(decodes,
                                 feed_dict={
                                     inputdata: [image],
                                     input_sequence_length: seq_len
                                 })
                preds = _sparse_matrix_to_list(preds[0], char_map_dict)
                print('Label: [{:20s}]'.format(label))
                print('Pred : [{}]\n'.format(preds[0]))
    finally:
        sess.close()

    return
def train_shadownet(dataset_dir_train, dataset_dir_val, weights_path,
                    char_dict_path, model_save_dir):
    """Train the CRNN network with color augmentation and Adadelta.

    Based on https://github.com/MaybeShewill-CV/CRNN_Tensorflow

    :param dataset_dir_train: TFRecords directory of the training split
    :param dataset_dir_val: validation TFRecords directory; currently unused,
        kept for interface compatibility
    :param weights_path: directory whose latest checkpoint is restored, or
        ``None`` to train from scratch
    :param char_dict_path: character dictionary file
    :param model_save_dir: directory for checkpoints and TensorBoard summaries
        (created if missing)
    :return: numpy array of the training CTC losses recorded at each display step
    :raises ValueError: if ``weights_path`` is given but contains no checkpoint
    """
    # Prepare the training input pipeline.
    train_dataset = read_tfrecord.CrnnDataFeeder(dataset_dir=dataset_dir_train,
                                                 char_dict_path=char_dict_path,
                                                 flags='train')
    train_images, train_labels, _ = train_dataset.inputs(
        batch_size=CFG.TRAIN.BATCH_SIZE)

    # ---- Data augmentation ----
    # Log the raw input batch to TensorBoard for visual inspection.
    tf.summary.image('original_image', train_images)
    # Apply one of two randomly selected color-distortion orderings.
    images = apply_with_random_selector(
        train_images,
        lambda x, ordering: distort_color(x, ordering),
        num_cases=2)
    # Map to [-1, 1] and clip values the distortion pushed out of range.
    images = tf.subtract(tf.divide(images, 127.5), 1.0)
    train_images = tf.clip_by_value(images, -1.0, 1.0)
    tf.summary.image('distord_turned_image', train_images)

    num_classes = get_num_class(char_dict_path)

    shadownet = crnn_model.ShadowNet(phase='train',
                                     hidden_nums=CFG.ARCH.HIDDEN_UNITS,
                                     layers_nums=CFG.ARCH.HIDDEN_LAYERS,
                                     num_classes=num_classes)

    # Training graph (soft placement falls back to CPU if no GPU is present).
    with tf.device('/gpu:0'):
        _, train_ctc_loss = shadownet.compute_loss(inputdata=train_images,
                                                   labels=train_labels,
                                                   name='shadow_net',
                                                   reuse=False)

        global_step = tf.Variable(0, name='global_step', trainable=False)
        learning_rate = tf.train.exponential_decay(
            learning_rate=CFG.TRAIN.LEARNING_RATE,
            global_step=global_step,
            decay_steps=CFG.TRAIN.LR_DECAY_STEPS,
            decay_rate=CFG.TRAIN.LR_DECAY_RATE,
            staircase=True)

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            # Adadelta (as in the CRNN paper) instead of the upstream Momentum optimizer.
            optimizer = tf.train.AdadeltaOptimizer(
                learning_rate=learning_rate).minimize(loss=train_ctc_loss,
                                                      global_step=global_step)

    # Summaries and checkpoints both go to model_save_dir; make sure it exists.
    os.makedirs(model_save_dir, exist_ok=True)
    tf.summary.scalar(name='train_ctc_loss', tensor=train_ctc_loss)
    tf.summary.scalar(name='learning_rate', tensor=learning_rate)
    merge_summary_op = tf.summary.merge_all()

    saver = tf.train.Saver()
    train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S',
                                     time.localtime(time.time()))
    model_name = 'shadownet_{:s}.ckpt'.format(str(train_start_time))
    model_save_path = ops.join(model_save_dir, model_name)

    sess_config = tf.ConfigProto(allow_soft_placement=True)
    sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH
    sess = tf.Session(config=sess_config)

    summary_writer = tf.summary.FileWriter(model_save_dir)
    summary_writer.add_graph(sess.graph)

    # Total number of optimizer steps to reach.
    train_epochs = CFG.TRAIN.EPOCHS
    cost_history = []

    try:
        with sess.as_default():
            if weights_path is None:
                print('Training from scratch')
                sess.run(tf.global_variables_initializer())
                step = 0
            else:
                checkpoint_path = tf.train.latest_checkpoint(weights_path)
                if checkpoint_path is None:
                    raise ValueError(
                        'No checkpoint found in {}'.format(weights_path))
                print('Restore model from last model checkpoint {:s}'.format(
                    checkpoint_path))
                saver.restore(sess=sess, save_path=checkpoint_path)
                # Resume counting from the restored global step.
                step = int(sess.run(global_step))

            while step < train_epochs:
                _, train_ctc_loss_value, merge_summary_value, learning_rate_value = sess.run(
                    [optimizer, train_ctc_loss, merge_summary_op, learning_rate])
                # `step` now equals the number of optimizer steps completed.
                step += 1

                if step % CFG.TRAIN.DISPLAY_STEP == 0:
                    current_time = time.strftime('%m-%d-%H-%M-%S',
                                                 time.localtime(time.time()))
                    print('{} lr={:.5f}  step:{:6d}   train_loss={:.4f}'.format(
                        current_time, learning_rate_value, step,
                        train_ctc_loss_value))

                    cost_history.append(train_ctc_loss_value)
                    summary_writer.add_summary(summary=merge_summary_value,
                                               global_step=step)

                if step % CFG.TRAIN.SAVE_STEPS == 0:
                    saver.save(sess=sess,
                               save_path=model_save_path,
                               global_step=step)
    finally:
        summary_writer.close()
        sess.close()

    return np.array(cost_history)
def train_shadownet(dataset_dir, weights_path=None):
    """Train CRNN (ShadowNet) on batches fed through placeholders.

    Batches come from ``get_training_data()``; each must fit a (32, 32, 100, 3)
    float image placeholder and an int32 sparse label placeholder.

    :param dataset_dir: not read by this body (batches come from
        ``get_training_data()``); kept for interface compatibility
    :param weights_path: optional checkpoint to restore; ``None`` trains from scratch
    :return: None
    """
    shadownet = crnn_model.ShadowNet(phase='Train',
                                     hidden_nums=256,
                                     layers_nums=2,
                                     seq_length=25,
                                     num_classes=37)
    inputdata = tf.placeholder(dtype=tf.float32, shape=(32, 32, 100, 3))
    input_labels = tf.sparse_placeholder(tf.int32, shape=(None, -1))
    with tf.variable_scope('shadow', reuse=False):
        net_out = shadownet.build_shadownet(inputdata=inputdata)

    # Batch-mean CTC loss with a fixed sequence length of 25 per sample.
    cost = tf.reduce_mean(
        tf.nn.ctc_loss(labels=input_labels,
                       inputs=net_out,
                       sequence_length=25 * np.ones(32)))

    decoded, _ = tf.nn.ctc_beam_search_decoder(net_out,
                                               25 * np.ones(32),
                                               merge_repeated=False)

    sequence_dist = tf.reduce_mean(
        tf.edit_distance(tf.cast(decoded[0], tf.int32), input_labels))

    global_step = tf.Variable(0, name='global_step', trainable=False)

    learning_rate = tf.train.exponential_decay(config.cfg.TRAIN.LEARNING_RATE,
                                               global_step,
                                               config.cfg.TRAIN.LR_DECAY_STEPS,
                                               config.cfg.TRAIN.LR_DECAY_RATE,
                                               staircase=True)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    with tf.control_dependencies(update_ops):
        optimizer = tf.train.AdadeltaOptimizer(
            learning_rate=learning_rate).minimize(loss=cost,
                                                  global_step=global_step)

    # TensorBoard summaries.
    tboard_save_path = 'tboard/shadownet'
    os.makedirs(tboard_save_path, exist_ok=True)
    tf.summary.scalar(name='Cost', tensor=cost)
    tf.summary.scalar(name='Learning_Rate', tensor=learning_rate)
    tf.summary.scalar(name='Seq_Dist', tensor=sequence_dist)
    merge_summary_op = tf.summary.merge_all()

    # Checkpoint location.
    saver = tf.train.Saver()
    model_save_dir = 'model/shadownet'
    os.makedirs(model_save_dir, exist_ok=True)
    train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S',
                                     time.localtime(time.time()))
    model_name = 'shadownet_{:s}.ckpt'.format(str(train_start_time))
    model_save_path = ops.join(model_save_dir, model_name)

    sess_config = tf.ConfigProto()
    sess_config.gpu_options.per_process_gpu_memory_fraction = config.cfg.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = config.cfg.TRAIN.TF_ALLOW_GROWTH
    sess = tf.Session(config=sess_config)
    summary_writer = tf.summary.FileWriter(tboard_save_path)
    summary_writer.add_graph(sess.graph)

    train_epochs = config.cfg.TRAIN.EPOCHS
    try:
        with sess.as_default():
            if weights_path is None:
                logger.info('Training from scratch')
                sess.run(tf.global_variables_initializer())
            else:
                logger.info('Restore model from {:s}'.format(weights_path))
                saver.restore(sess=sess, save_path=weights_path)

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                for epoch in range(train_epochs):
                    # Best effort: a failing step is logged and training goes on.
                    try:
                        training_img, training_label = get_training_data()
                        feed_dict = {
                            inputdata: training_img,
                            input_labels: training_label
                        }
                        _, c, seq_distance, preds, gt_labels, summary = sess.run(
                            [
                                optimizer, cost, sequence_dist, decoded,
                                input_labels, merge_summary_op
                            ],
                            feed_dict=feed_dict)

                        # Position-wise agreement of the flattened prediction and
                        # label streams, relative to the number of predicted
                        # characters (0.0 when nothing was decoded).
                        preds_sequence = preds[0].values.tolist()
                        gt_value = gt_labels.values.tolist()
                        accu_num = sum(
                            1 for gt_char, pred_char in zip(gt_value, preds_sequence)
                            if gt_char == pred_char)
                        accuracy = (accu_num / len(preds_sequence)
                                    if preds_sequence else 0.0)

                        if epoch % config.cfg.TRAIN.DISPLAY_STEP == 0:
                            logger.info(
                                'Epoch: {:d} cost= {:9f} seq distance= {:9f} train accuracy= {:9f}'
                                .format(epoch + 1, c, seq_distance, accuracy))
                        if epoch % 1000 == 0 and epoch != 0:
                            summary_writer.add_summary(summary=summary,
                                                       global_step=epoch)
                            saver.save(sess=sess,
                                       save_path=model_save_path,
                                       global_step=epoch)
                            logger.info(
                                'save_model!!!!!!!!!!!!!!!!!!!___________________')
                    except Exception as e:
                        logger.exception(e)
            finally:
                coord.request_stop()
                coord.join(threads=threads)
    finally:
        sess.close()

    return
def test_shadownet(dataset_dir, weights_path, is_vis=True):
    """Restore a CRNN checkpoint and print predictions for one shuffled test batch.

    :param dataset_dir: directory containing ``test_feature.tfrecords``
    :param weights_path: checkpoint path passed to ``tf.train.Saver.restore``
    :param is_vis: when True, show each image of the batch with matplotlib
        (channels reversed for display, i.e. BGR -> RGB if the data is BGR)
    :return: None
    """
    decoder = data_utils.TextFeatureIO().reader
    images_t, labels_t, imagenames_t = decoder.read_features(
        ops.join(dataset_dir, 'test_feature.tfrecords'), num_epochs=None)
    images_sh, labels_sh, imagenames_sh = tf.train.shuffle_batch(
        tensors=[images_t, labels_t, imagenames_t],
        batch_size=32,
        capacity=1000 + 32 * 2,
        min_after_dequeue=2,
        num_threads=4)

    images_sh = tf.cast(x=images_sh, dtype=tf.float32)

    net = crnn_model.ShadowNet(phase='Test',
                               hidden_nums=256,
                               layers_nums=2,
                               seq_length=25,
                               num_classes=37)

    with tf.variable_scope('shadow'):
        net_out = net.build_shadownet(inputdata=images_sh)

    # Fixed sequence length of 25 for each of the 32 samples.
    decoded, _ = tf.nn.ctc_beam_search_decoder(net_out,
                                               25 * np.ones(32),
                                               merge_repeated=False)

    sess_config = tf.ConfigProto()
    sess_config.gpu_options.per_process_gpu_memory_fraction = config.cfg.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = config.cfg.TRAIN.TF_ALLOW_GROWTH

    saver = tf.train.Saver()
    sess = tf.Session(config=sess_config)

    try:
        with sess.as_default():
            saver.restore(sess=sess, save_path=weights_path)

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            # Stop the queue-runner threads even if prediction fails.
            try:
                print('Start predicting ......')
                predictions, images, labels, imagenames = sess.run(
                    [decoded, images_sh, labels_sh, imagenames_sh])
                imagenames = np.reshape(imagenames, newshape=imagenames.shape[0])
                imagenames = [tmp.decode('utf-8') for tmp in imagenames]
                preds_res = decoder.sparse_tensor_to_str(predictions[0])
                gt_res = decoder.sparse_tensor_to_str(labels)
                for index, image in enumerate(images):
                    print(
                        'Predict {:s} image with gt label: {:s} **** predict label: {:s}'
                        .format(imagenames[index], gt_res[index], preds_res[index]))
                    if is_vis:
                        plt.imshow(image[:, :, (2, 1, 0)])
                        plt.show()
            finally:
                coord.request_stop()
                coord.join(threads=threads)
    finally:
        sess.close()
    return
def evaluate_shadownet(dataset_dir,
                       weights_path,
                       char_dict_path,
                       ord_map_dict_path,
                       is_visualize=False,
                       is_process_all_data=False):
    """Evaluate a trained CRNN checkpoint on the test TFRecords split.

    Logs every prediction next to its ground truth, the mean per-character and
    full-sequence accuracy over the batches actually evaluated, and plots a
    normalized character confusion matrix.

    :param dataset_dir: directory containing the test TFRecords
    :param weights_path: checkpoint path to restore
    :param char_dict_path: character dictionary used by the feeder and decoder
    :param ord_map_dict_path: ord-map dictionary used by the feeder and decoder
    :param is_visualize: show each test image with matplotlib
    :param is_process_all_data: evaluate every test batch instead of only the first
    :return: None
    """
    test_dataset = shadownet_data_feed_pipline.CrnnDataFeeder(
        dataset_dir=dataset_dir,
        char_dict_path=char_dict_path,
        ord_map_dict_path=ord_map_dict_path,
        flags='test')
    test_images, test_labels, test_images_paths = test_dataset.inputs(
        batch_size=CFG.TEST.BATCH_SIZE, num_epochs=1)

    # Number of batches to evaluate.
    if is_process_all_data:
        log.info('Start computing test dataset sample counts')
        t_start = time.time()
        test_sample_count = test_dataset.sample_counts()
        log.info(
            'Computing test dataset sample counts finished, cost time: {:.5f}'.
            format(time.time() - t_start))
        num_iterations = int(math.ceil(test_sample_count /
                                       CFG.TEST.BATCH_SIZE))
    else:
        num_iterations = 1

    shadownet = crnn_model.ShadowNet(phase='test',
                                     hidden_nums=CFG.ARCH.HIDDEN_UNITS,
                                     layers_nums=CFG.ARCH.HIDDEN_LAYERS,
                                     num_classes=CFG.ARCH.NUM_CLASSES)
    decoder = tf_io_pipline_tools.TextFeatureIO(
        char_dict_path=char_dict_path,
        ord_map_dict_path=ord_map_dict_path).reader

    test_inference_ret = shadownet.inference(inputdata=test_images,
                                             name='shadow_net',
                                             reuse=False)
    test_decoded, _ = tf.nn.ctc_beam_search_decoder(
        test_inference_ret,
        CFG.ARCH.SEQ_LENGTH * np.ones(CFG.TEST.BATCH_SIZE),
        beam_width=1,
        merge_repeated=False)

    # Recover images from [-1.0, 1.0] to [0.0, 255.0] for visualization.
    test_images = tf.multiply(tf.add(test_images, 1.0),
                              127.5,
                              name='recoverd_test_images')

    saver = tf.train.Saver()

    sess_config = tf.ConfigProto(allow_soft_placement=True)
    sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH

    sess = tf.Session(config=sess_config)

    try:
        with sess.as_default():
            saver.restore(sess=sess, save_path=weights_path)

            log.info('Start predicting...')

            per_char_accuracy = 0.0
            full_sequence_accuracy = 0.0
            # Averages are divided by the batches actually evaluated, which can
            # be fewer than num_iterations if the input runs out early.
            evaluated_batches = 0

            total_labels_char_list = []
            total_predictions_char_list = []

            try:
                for _ in range(num_iterations):
                    (test_predictions_value, test_images_value, test_labels_value,
                     test_images_paths_value) = sess.run(
                        [test_decoded, test_images, test_labels, test_images_paths])
                    test_images_paths_value = np.reshape(
                        test_images_paths_value,
                        newshape=test_images_paths_value.shape[0])
                    test_images_names_value = [
                        ops.split(tmp.decode('utf-8'))[1]
                        for tmp in test_images_paths_value
                    ]
                    test_labels_value = decoder.sparse_tensor_to_str(
                        test_labels_value)
                    test_predictions_value = decoder.sparse_tensor_to_str(
                        test_predictions_value[0])

                    per_char_accuracy += evaluation_tools.compute_accuracy(
                        test_labels_value,
                        test_predictions_value,
                        display=False,
                        mode='per_char')
                    full_sequence_accuracy += evaluation_tools.compute_accuracy(
                        test_labels_value,
                        test_predictions_value,
                        display=False,
                        mode='full_sequence')
                    evaluated_batches += 1

                    for index, test_image in enumerate(test_images_value):
                        log.info(
                            'Predict {:s} image with gt label: {:s} **** predicted label: {:s}'
                            .format(test_images_names_value[index],
                                    test_labels_value[index],
                                    test_predictions_value[index]))

                        if is_visualize:
                            plt.imshow(
                                np.array(test_image, np.uint8)[:, :, (2, 1, 0)])
                            plt.show()

                        label_chars = list(test_labels_value[index])
                        pred_chars = list(test_predictions_value[index])
                        if not label_chars or not pred_chars:
                            continue

                        # Align label and prediction position-wise by truncating
                        # both to their common length for the confusion matrix.
                        min_length = min(len(label_chars), len(pred_chars))
                        total_labels_char_list.extend(label_chars[:min_length])
                        total_predictions_char_list.extend(pred_chars[:min_length])
            except tf.errors.OutOfRangeError:
                log.info('End of tfrecords sequence')
            except Exception as err:
                log.error(err)

            if evaluated_batches == 0:
                log.error('No test batch was evaluated')
                return

            avg_per_char_accuracy = per_char_accuracy / evaluated_batches
            avg_full_sequence_accuracy = full_sequence_accuracy / evaluated_batches
            log.info('Mean test per char accuracy is {:5f}'.format(
                avg_per_char_accuracy))
            log.info('Mean test full sequence accuracy is {:5f}'.format(
                avg_full_sequence_accuracy))

            if total_labels_char_list:
                cnf_matrix = confusion_matrix(total_labels_char_list,
                                              total_predictions_char_list)
                np.set_printoptions(precision=2)
                evaluation_tools.plot_confusion_matrix(cm=cnf_matrix,
                                                       normalize=True)
                plt.show()
            else:
                log.warning('No aligned characters collected; '
                            'skipping confusion matrix')
    finally:
        sess.close()
Example #18
0
def train_shadownet(cfg: EasyDict,
                    weights_path: str = None,
                    decode: bool = False,
                    num_threads: int = 4):
    """Train the CRNN ShadowNet with a CTC loss on train_feature.tfrecords.

    Builds a shuffled input queue, the ShadowNet graph, the CTC loss and a
    beam-search decoder, and an Adadelta optimizer with a staircase
    exponentially decayed learning rate. It also sets up TensorBoard
    summaries. It then runs ``cfg.TRAIN.EPOCHS`` optimization steps (one
    batch each) and writes a checkpoint after every step.

    :param cfg: configuration EasyDict (e.g. global_config.config.cfg)
    :param weights_path: Path to stored weights to resume from; when None the
        variables are initialized from scratch
    :param decode: Whether to perform CTC decoding to report progress during training
    :param num_threads: Number of threads to use in tf.train.shuffle_batch
    """
    # decode the tf records to get the training data
    decoder = data_utils.TextFeatureIO(
        char_dict_path=ops.join(cfg.PATH.CHAR_DICT_DIR, 'char_dict.json'),
        ord_map_dict_path=ops.join(cfg.PATH.CHAR_DICT_DIR,
                                   'ord_map.json')).reader
    images, labels, imagenames = decoder.read_features(
        ops.join(cfg.PATH.TFRECORDS_DIR, 'train_feature.tfrecords'),
        num_epochs=None,
        input_size=cfg.ARCH.INPUT_SIZE,
        input_channels=cfg.ARCH.INPUT_CHANNELS)
    inputdata, input_labels, _ = tf.train.shuffle_batch(
        tensors=[images, labels, imagenames],
        batch_size=cfg.TRAIN.BATCH_SIZE,
        capacity=1000 + 2 * cfg.TRAIN.BATCH_SIZE,
        min_after_dequeue=100,
        num_threads=num_threads)

    inputdata = tf.cast(x=inputdata, dtype=tf.float32)

    # initialise the net model; the extra class is the CTC blank label
    shadownet = crnn_model.ShadowNet(phase='Train',
                                     hidden_nums=cfg.ARCH.HIDDEN_UNITS,
                                     layers_nums=cfg.ARCH.HIDDEN_LAYERS,
                                     num_classes=len(decoder.char_dict) + 1)

    with tf.variable_scope('shadow', reuse=False):
        net_out = shadownet.build_shadownet(inputdata=inputdata)

    # every sample in the batch is treated as spanning SEQ_LENGTH time steps
    seq_lengths = cfg.ARCH.SEQ_LENGTH * np.ones(cfg.TRAIN.BATCH_SIZE)

    cost = tf.reduce_mean(
        tf.nn.ctc_loss(labels=input_labels,
                       inputs=net_out,
                       sequence_length=seq_lengths))

    decoded, _ = tf.nn.ctc_beam_search_decoder(net_out,
                                               seq_lengths,
                                               merge_repeated=False)

    sequence_dist = tf.reduce_mean(
        tf.edit_distance(tf.cast(decoded[0], tf.int32), input_labels))

    global_step = tf.Variable(0, name='global_step', trainable=False)

    learning_rate = tf.train.exponential_decay(cfg.TRAIN.LEARNING_RATE,
                                               global_step,
                                               cfg.TRAIN.LR_DECAY_STEPS,
                                               cfg.TRAIN.LR_DECAY_RATE,
                                               staircase=True)

    # run any ops registered in UPDATE_OPS (e.g. batch-norm statistics)
    # together with every optimization step
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    with tf.control_dependencies(update_ops):
        optimizer = tf.train.AdadeltaOptimizer(
            learning_rate=learning_rate).minimize(loss=cost,
                                                  global_step=global_step)

    # Set tf summary
    os.makedirs(cfg.PATH.TBOARD_SAVE_DIR, exist_ok=True)
    tf.summary.scalar(name='Cost', tensor=cost)
    tf.summary.scalar(name='Learning_Rate', tensor=learning_rate)
    tf.summary.scalar(name='Seq_Dist', tensor=sequence_dist)
    merge_summary_op = tf.summary.merge_all()

    # Set saver configuration. Saver.save requires the checkpoint directory
    # to exist, so MODEL_SAVE_DIR (not the tensorboard dir) is created here.
    saver = tf.train.Saver()
    os.makedirs(cfg.PATH.MODEL_SAVE_DIR, exist_ok=True)
    train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S',
                                     time.localtime(time.time()))
    model_name = 'shadownet_{:s}.ckpt'.format(str(train_start_time))
    model_save_path = ops.join(cfg.PATH.MODEL_SAVE_DIR, model_name)

    # Set sess configuration
    sess_config = tf.ConfigProto()
    sess_config.gpu_options.per_process_gpu_memory_fraction = cfg.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = cfg.TRAIN.TF_ALLOW_GROWTH

    sess = tf.Session(config=sess_config)

    summary_writer = tf.summary.FileWriter(cfg.PATH.TBOARD_SAVE_DIR)
    summary_writer.add_graph(sess.graph)

    # Set the training parameters
    train_epochs = cfg.TRAIN.EPOCHS

    try:
        with sess.as_default():
            if weights_path is None:
                logger.info('Training from scratch')
                sess.run(tf.global_variables_initializer())
            else:
                logger.info('Restore model from {:s}'.format(weights_path))
                saver.restore(sess=sess, save_path=weights_path)

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            try:
                for epoch in range(train_epochs):
                    if decode:
                        _, c, seq_distance, predictions, labels, summary = sess.run([
                            optimizer, cost, sequence_dist, decoded,
                            input_labels, merge_summary_op
                        ])

                        labels = decoder.sparse_tensor_to_str(labels)
                        predictions = decoder.sparse_tensor_to_str(
                            predictions[0])
                        accuracy = compute_accuracy(labels, predictions)

                        if epoch % cfg.TRAIN.DISPLAY_STEP == 0:
                            logger.info(
                                'Epoch: {:d} cost= {:9f} seq distance= {:9f} train accuracy= {:9f}'
                                .format(epoch + 1, c, seq_distance, accuracy))
                    else:
                        _, c, summary = sess.run(
                            [optimizer, cost, merge_summary_op])
                        if epoch % cfg.TRAIN.DISPLAY_STEP == 0:
                            logger.info('Epoch: {:d} cost= {:9f}'.format(
                                epoch + 1, c))

                    summary_writer.add_summary(summary=summary,
                                               global_step=epoch)
                    saver.save(sess=sess,
                               save_path=model_save_path,
                               global_step=epoch)
            finally:
                # stop the input-queue threads even if a training step failed
                coord.request_stop()
                coord.join(threads=threads)
    finally:
        # flush pending events and release the session's resources
        summary_writer.close()
        sess.close()
Example #19
0
def test_shadownet(weights_path: str,
                   cfg: EasyDict,
                   visualize: bool,
                   process_all_data: bool = True,
                   num_threads: int = 4,
                   num_classes: int = 0):
    """Evaluate a trained ShadowNet on test_feature.tfrecords and print its accuracy.

    The test records are read from ``cfg.PATH.TFRECORDS_DIR``. The character
    dictionaries (char_dict.json and ord_map.json, generated with
    write_text_features.py) are read from ``cfg.PATH.CHAR_DICT_DIR``.

    :param weights_path: Path to stored weights
    :param cfg: configuration EasyDict (e.g. global_config.config.cfg)
    :param visualize: whether to display the images; only honoured when
        ``process_all_data`` is False, to avoid a window per test sample
    :param process_all_data: when True, run enough batches to cover every record
        in test_feature.tfrecords; otherwise run a single batch
    :param num_threads: Number of threads for tf.train.(shuffle_)batch
    :param num_classes: Number of different characters in the dataset; 0 means
        ``len(char_dict) + 1``
    """
    decoder = data_utils.TextFeatureIO(
        char_dict_path=ops.join(cfg.PATH.CHAR_DICT_DIR, 'char_dict.json'),
        ord_map_dict_path=ops.join(cfg.PATH.CHAR_DICT_DIR,
                                   'ord_map.json')).reader
    input_images, input_labels, input_image_names = decoder.read_features(
        cfg, cfg.TEST.BATCH_SIZE, num_threads, False)

    if num_classes == 0:
        num_classes = len(decoder.char_dict) + 1
    net = crnn_model.ShadowNet(phase='Test',
                               hidden_nums=cfg.ARCH.HIDDEN_UNITS,
                               layers_nums=cfg.ARCH.HIDDEN_LAYERS,
                               num_classes=num_classes)

    with tf.variable_scope('shadow'):
        net_out = net.build_shadownet(inputdata=input_images)

    decoded, _ = tf.nn.ctc_beam_search_decoder(net_out,
                                               cfg.ARCH.SEQ_LENGTH *
                                               np.ones(cfg.TEST.BATCH_SIZE),
                                               merge_repeated=False)

    # config tf session
    sess_config = tf.ConfigProto()
    sess_config.gpu_options.per_process_gpu_memory_fraction = cfg.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = cfg.TRAIN.TF_ALLOW_GROWTH

    # config tf saver
    saver = tf.train.Saver()

    tfrecords_path = ops.join(cfg.PATH.TFRECORDS_DIR, 'test_feature.tfrecords')
    test_sample_count = sum(
        1 for _ in tf.python_io.tf_record_iterator(tfrecords_path))
    if process_all_data:
        num_iterations = int(math.ceil(test_sample_count /
                                       cfg.TEST.BATCH_SIZE))
    else:
        num_iterations = 1

    if num_iterations == 0:
        # an empty record file would otherwise hit a division by zero below
        print('No test samples found in {}'.format(tfrecords_path))
        return

    # the session is closed automatically when the block exits
    with tf.Session(config=sess_config) as sess:
        saver.restore(sess=sess, save_path=weights_path)

        print('Start predicting...')

        accuracy = 0
        for _ in range(num_iterations):
            predictions, images, labels, image_names = sess.run(
                [decoded, input_images, input_labels, input_image_names])
            image_names = np.reshape(image_names,
                                     newshape=image_names.shape[0])
            image_names = [tmp.decode('utf-8') for tmp in image_names]

            labels = decoder.sparse_tensor_to_str(labels)
            predictions = decoder.sparse_tensor_to_str(predictions[0])

            accuracy += compute_accuracy(labels, predictions, display=False)

            for index, image in enumerate(images):
                print(
                    'Predict {:s} image with gt label: {:s} **** predicted label: {:s}'
                    .format(image_names[index], labels[index],
                            predictions[index]))
                # avoid accidentally displaying for the whole dataset
                if visualize and not process_all_data:
                    plt.imshow(image[:, :, (2, 1, 0)])
                    plt.show()

    # We compute a mean of means, so we need the sample sizes to be constant
    # (BATCH_SIZE) for this to equal the actual mean
    accuracy /= num_iterations
    print('Mean test accuracy is {:5f}'.format(accuracy))
Example #20
0
def train_shadownet(dataset_dir, weights_path=None):
    """Train ShadowNet on ``<dataset_dir>/train_feature.tfrecords`` with a CTC loss.

    Uses fixed hyper-parameters: batch size 32, 25 time steps, 37 classes, and
    a 2-layer recurrent part with 256 hidden units. Learning-rate and session
    settings come from ``config.cfg``. TensorBoard events are written to
    ``tboard/shadownet`` and a checkpoint is saved to ``model/shadownet``
    after every step.

    :param dataset_dir: directory containing train_feature.tfrecords
    :param weights_path: checkpoint to resume from; None trains from scratch
    :return: None
    """

    def _char_accuracy(gt_label, pred):
        """Return the fraction of gt_label positions whose char matches pred.

        Only the overlapping prefix is compared. An empty ground truth scores
        1 when the prediction is empty too, otherwise 0.
        """
        if len(gt_label) == 0:
            return 1 if len(pred) == 0 else 0
        correct_count = sum(1 for gt_char, pred_char in zip(gt_label, pred)
                            if gt_char == pred_char)
        return correct_count / len(gt_label)

    batch_size = 32
    seq_length = 25

    # decode the tf records to get the training data
    decoder = data_utils.TextFeatureIO().reader
    images, labels, imagenames = decoder.read_features(
        ops.join(dataset_dir, 'train_feature.tfrecords'), num_epochs=None)
    inputdata, input_labels, _ = tf.train.shuffle_batch(
        tensors=[images, labels, imagenames],
        batch_size=batch_size,
        capacity=1000 + 2 * batch_size,
        min_after_dequeue=100,
        num_threads=1)

    inputdata = tf.cast(x=inputdata, dtype=tf.float32)

    # initialize the net model
    shadownet = crnn_model.ShadowNet(phase='Train',
                                     hidden_nums=256,
                                     layers_nums=2,
                                     seq_length=seq_length,
                                     num_classes=37)

    with tf.variable_scope('shadow', reuse=False):
        net_out = shadownet.build_shadownet(inputdata=inputdata)

    # every sample in the batch is treated as spanning seq_length time steps
    seq_lengths = seq_length * np.ones(batch_size)

    cost = tf.reduce_mean(
        tf.nn.ctc_loss(labels=input_labels,
                       inputs=net_out,
                       sequence_length=seq_lengths))

    decoded, _ = tf.nn.ctc_beam_search_decoder(net_out,
                                               seq_lengths,
                                               merge_repeated=False)

    sequence_dist = tf.reduce_mean(
        tf.edit_distance(tf.cast(decoded[0], tf.int32), input_labels))

    global_step = tf.Variable(0, name='global_step', trainable=False)

    learning_rate = tf.train.exponential_decay(config.cfg.TRAIN.LEARNING_RATE,
                                               global_step,
                                               config.cfg.TRAIN.LR_DECAY_STEPS,
                                               config.cfg.TRAIN.LR_DECAY_RATE,
                                               staircase=True)

    # run any ops registered in UPDATE_OPS (e.g. batch-norm statistics)
    # together with every optimization step
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        optimizer = tf.train.AdadeltaOptimizer(
            learning_rate=learning_rate).minimize(loss=cost,
                                                  global_step=global_step)

    # Set tf summary
    tboard_save_path = 'tboard/shadownet'
    os.makedirs(tboard_save_path, exist_ok=True)
    tf.summary.scalar(name='Cost', tensor=cost)
    tf.summary.scalar(name='Learning_Rate', tensor=learning_rate)
    tf.summary.scalar(name='Seq_Dist', tensor=sequence_dist)
    merge_summary_op = tf.summary.merge_all()

    # Set saver configuration
    saver = tf.train.Saver()
    model_save_dir = 'model/shadownet'
    os.makedirs(model_save_dir, exist_ok=True)
    train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S',
                                     time.localtime(time.time()))
    model_name = 'shadownet_{:s}.ckpt'.format(str(train_start_time))
    model_save_path = ops.join(model_save_dir, model_name)

    # Set sess configuration
    sess_config = tf.ConfigProto()
    sess_config.gpu_options.per_process_gpu_memory_fraction = config.cfg.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = config.cfg.TRAIN.TF_ALLOW_GROWTH

    sess = tf.Session(config=sess_config)

    summary_writer = tf.summary.FileWriter(tboard_save_path)
    summary_writer.add_graph(sess.graph)

    # Set the training parameters
    train_epochs = config.cfg.TRAIN.EPOCHS

    try:
        with sess.as_default():
            if weights_path is None:
                logger.info('Training from scratch')
                sess.run(tf.global_variables_initializer())
            else:
                logger.info('Restore model from {:s}'.format(weights_path))
                saver.restore(sess=sess, save_path=weights_path)

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            try:
                for epoch in range(train_epochs):
                    _, c, seq_distance, preds, gt_labels, summary = sess.run([
                        optimizer, cost, sequence_dist, decoded, input_labels,
                        merge_summary_op
                    ])

                    # calculate the per-character precision of this batch
                    preds = decoder.sparse_tensor_to_str(preds[0])
                    gt_labels = decoder.sparse_tensor_to_str(gt_labels)
                    per_sample = [
                        _char_accuracy(gt_label, preds[index])
                        for index, gt_label in enumerate(gt_labels)
                    ]
                    accuracy = np.mean(np.array(per_sample).astype(np.float32),
                                       axis=0)

                    if epoch % config.cfg.TRAIN.DISPLAY_STEP == 0:
                        logger.info(
                            'Epoch: {:d} cost= {:9f} seq distance= {:9f} train accuracy= {:9f}'
                            .format(epoch + 1, c, seq_distance, accuracy))

                    summary_writer.add_summary(summary=summary,
                                               global_step=epoch)
                    saver.save(sess=sess,
                               save_path=model_save_path,
                               global_step=epoch)
            finally:
                # stop the input-queue threads even if a training step failed
                coord.request_stop()
                coord.join(threads=threads)
    finally:
        summary_writer.close()
        sess.close()

    return
Example #21
0
def train_shadownet():
    """Train ShadowNet from ``FLAGS.dataset_dir``, optionally resuming from ``FLAGS.weights_path``.

    Hyper-parameters come from ``config.cfg``. TensorBoard events and
    checkpoints are written to ``/data/output``. The logged summaries are:
    training cost, per-character accuracy, sequence edit distance, learning
    rate, some conv feature maps, and histograms of the conv weights. A
    checkpoint is saved every 500 steps and after the final step.

    :return: None
    """

    def _char_accuracy(gt_label, pred):
        """Return the fraction of gt_label positions whose char matches pred.

        Only the overlapping prefix is compared. An empty ground truth scores
        1 when the prediction is empty too, otherwise 0.
        """
        if len(gt_label) == 0:
            return 1 if len(pred) == 0 else 0
        correct_count = sum(1 for gt_char, pred_char in zip(gt_label, pred)
                            if gt_char == pred_char)
        return correct_count / len(gt_label)

    # decode the tf records to get the training data
    decoder = data_utils.TextFeatureIO().reader
    images, labels, imagenames = decoder.read_features(FLAGS.dataset_dir,
                                                       num_epochs=None,
                                                       flag='Train')
    inputdata, input_labels, _ = tf.train.shuffle_batch(
        tensors=[images, labels, imagenames],
        batch_size=config.cfg.TRAIN.BATCH_SIZE,
        capacity=1000 + 2 * config.cfg.TRAIN.BATCH_SIZE,
        min_after_dequeue=100,
        num_threads=1)

    inputdata = tf.cast(x=inputdata, dtype=tf.float32)
    phase_tensor = tf.placeholder(dtype=tf.string, shape=None, name='phase')
    accuracy_tensor = tf.placeholder(dtype=tf.float32,
                                     shape=None,
                                     name='accuracy_tensor')

    # NOTE(review): the model is built with seq_length=15 while the CTC loss
    # and decoder use 20 time steps. Both values are kept unchanged; confirm
    # which one matches the time dimension of net_out.
    model_seq_length = 15
    ctc_seq_length = 20

    # initialize the net model
    shadownet = crnn_model.ShadowNet(phase=phase_tensor,
                                     hidden_nums=256,
                                     layers_nums=2,
                                     seq_length=model_seq_length,
                                     num_classes=config.cfg.TRAIN.CLASSES_NUMS,
                                     rnn_cell_type='lstm')

    with tf.variable_scope('shadow', reuse=False):
        net_out, tensor_dict = shadownet.build_shadownet(inputdata=inputdata)

    seq_lengths = ctc_seq_length * np.ones(config.cfg.TRAIN.BATCH_SIZE)

    cost = tf.reduce_mean(
        tf.nn.ctc_loss(labels=input_labels,
                       inputs=net_out,
                       sequence_length=seq_lengths))

    decoded, _ = tf.nn.ctc_beam_search_decoder(net_out,
                                               seq_lengths,
                                               merge_repeated=False)

    sequence_dist = tf.reduce_mean(
        tf.edit_distance(tf.cast(decoded[0], tf.int32), input_labels))

    global_step = tf.Variable(0, name='global_step', trainable=False)

    learning_rate = tf.train.exponential_decay(config.cfg.TRAIN.LEARNING_RATE,
                                               global_step,
                                               config.cfg.TRAIN.LR_DECAY_STEPS,
                                               config.cfg.TRAIN.LR_DECAY_RATE,
                                               staircase=True)

    # run any ops registered in UPDATE_OPS (e.g. batch-norm statistics)
    # together with every optimization step
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        optimizer = tf.train.AdadeltaOptimizer(
            learning_rate=learning_rate).minimize(loss=cost,
                                                  global_step=global_step)

    # Set tf summary
    tboard_save_path = '/data/output/'
    os.makedirs(tboard_save_path, exist_ok=True)

    visualizer = tensorboard_vis_summary.CNNVisualizer()

    # Summaries of the training step. They depend on the input batch, so they
    # are fetched in the same sess.run as the optimizer. Evaluating them in a
    # separate run would dequeue and forward a second, different batch.
    train_cost_scalar = tf.summary.scalar(name='train_cost', tensor=cost)
    train_seq_scalar = tf.summary.scalar(name='train_seq_dist',
                                         tensor=sequence_dist)
    lr_scalar = tf.summary.scalar(name='Learning_Rate', tensor=learning_rate)
    train_conv1_image = visualizer.merge_conv_image(
        feature_map=tensor_dict['conv1'], scope='conv1_image')
    train_conv2_image = visualizer.merge_conv_image(
        feature_map=tensor_dict['conv2'], scope='conv2_image')
    train_conv3_image = visualizer.merge_conv_image(
        feature_map=tensor_dict['conv3'], scope='conv3_image')
    train_conv7_image = visualizer.merge_conv_image(
        feature_map=tensor_dict['conv7'], scope='conv7_image')

    # variable names end in ':0'; strip it to get the histogram names
    weights_tensor_dict = {
        vv.name[:-2]: vv
        for vv in tf.trainable_variables() if 'conv' in vv.name
    }
    train_weights_hist_dict = visualizer.merge_weights_hist(
        weights_tensor_dict=weights_tensor_dict,
        scope='weights_histogram',
        is_merge=False)

    train_summary_merge_list = [
        train_cost_scalar, train_seq_scalar, lr_scalar, train_conv1_image,
        train_conv2_image, train_conv3_image, train_conv7_image
    ]
    train_summary_merge_list.extend(train_weights_hist_dict.values())
    train_summary_op_merge = tf.summary.merge(inputs=train_summary_merge_list)

    # The accuracy is computed in Python from the decoded predictions, so its
    # summary depends only on the fed placeholder and is evaluated on its own.
    # The 'train_accuray' tag is kept as-is so earlier runs stay comparable.
    train_accuracy_summary_op = tf.summary.scalar(name='train_accuray',
                                                  tensor=accuracy_tensor)

    # Set saver configuration
    saver = tf.train.Saver()
    model_save_dir = '/data/output'
    os.makedirs(model_save_dir, exist_ok=True)
    train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S',
                                     time.localtime(time.time()))
    model_name = 'shadownet_{:s}.ckpt'.format(str(train_start_time))
    model_save_path = ops.join(model_save_dir, model_name)

    # Set sess configuration
    sess_config = tf.ConfigProto()
    sess_config.gpu_options.per_process_gpu_memory_fraction = config.cfg.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = config.cfg.TRAIN.TF_ALLOW_GROWTH
    sess_config.gpu_options.allocator_type = 'BFC'

    sess = tf.Session(config=sess_config)

    summary_writer = tf.summary.FileWriter(tboard_save_path)
    summary_writer.add_graph(sess.graph)

    # Set the training parameters
    train_epochs = config.cfg.TRAIN.EPOCHS

    print('Global configuration is as follows:')
    pprint.pprint(config.cfg)

    try:
        with sess.as_default():
            if FLAGS.weights_path is None:
                logger.info('Training from scratch')
                sess.run(tf.global_variables_initializer())
            else:
                logger.info(
                    'Restore model from last crnn check point{:s}'.format(
                        FLAGS.weights_path))
                saver.restore(sess=sess, save_path=FLAGS.weights_path)

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            try:
                for epoch in range(train_epochs):
                    _, c, seq_distance, preds, gt_labels, train_summary = sess.run(
                        [
                            optimizer, cost, sequence_dist, decoded,
                            input_labels, train_summary_op_merge
                        ],
                        feed_dict={phase_tensor: 'train'})

                    # calculate the per-character precision of this batch
                    preds = decoder.sparse_tensor_to_str(preds[0])
                    gt_labels = decoder.sparse_tensor_to_str(gt_labels)
                    per_sample = [
                        _char_accuracy(gt_label, preds[index])
                        for index, gt_label in enumerate(gt_labels)
                    ]
                    accuracy = np.mean(np.array(per_sample).astype(np.float32),
                                       axis=0)

                    accuracy_summary = sess.run(
                        train_accuracy_summary_op,
                        feed_dict={accuracy_tensor: accuracy})
                    summary_writer.add_summary(summary=train_summary,
                                               global_step=epoch)
                    summary_writer.add_summary(summary=accuracy_summary,
                                               global_step=epoch)

                    if epoch % config.cfg.TRAIN.DISPLAY_STEP == 0:
                        logger.info(
                            'Epoch: {:d} cost= {:9f} seq distance= {:9f} train accuracy= {:9f}'
                            .format(epoch + 1, c, seq_distance, accuracy))

                    # TODO: add periodic validation on a held-out split.

                    # also save after the last step so its training isn't lost
                    if epoch % 500 == 0 or epoch == train_epochs - 1:
                        saver.save(sess=sess,
                                   save_path=model_save_path,
                                   global_step=epoch)
            finally:
                # stop the input-queue threads even if a training step failed
                coord.request_stop()
                coord.join(threads=threads)
    finally:
        summary_writer.close()
        sess.close()

    return
Example #22
0
def recognize(image_path, weights_path, char_dict_path, ord_map_dict_path,
              is_vis):
    """Recognize the text in a single image with a trained ShadowNet.

    The image is resized to ``CFG.ARCH.INPUT_SIZE`` and scaled to [-1, 1]
    before inference. The beam-search decoded string is then logged.

    :param image_path: path of the image to recognize
    :param weights_path: checkpoint to restore the network weights from
    :param char_dict_path: char dict json used to turn predictions into text
    :param ord_map_dict_path: ord map json used to turn predictions into text
    :param is_vis: whether to show the resized input image with matplotlib
    :return: None
    :raises ValueError: if OpenCV cannot read ``image_path``
    """
    image = cv2.imread(image_path, cv2.IMREAD_COLOR)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising
        raise ValueError('Failed to read image {}'.format(image_path))
    # cv2.resize takes dsize as (width, height)
    image = cv2.resize(image,
                       tuple(CFG.ARCH.INPUT_SIZE),
                       interpolation=cv2.INTER_LINEAR)
    image_vis = image
    # map pixel values from [0, 255] to [-1, 1]
    image = np.array(image, np.float32) / 127.5 - 1.0

    image_width, image_height = tuple(CFG.ARCH.INPUT_SIZE)
    inputdata = tf.placeholder(
        dtype=tf.float32,
        shape=[1, image_height, image_width, CFG.ARCH.INPUT_CHANNELS],
        name='input')

    codec = tf_io_pipline_tools.TextFeatureIO(
        char_dict_path=char_dict_path,
        ord_map_dict_path=ord_map_dict_path).reader

    net = crnn_model.ShadowNet(phase='test',
                               hidden_nums=CFG.ARCH.HIDDEN_UNITS,
                               layers_nums=CFG.ARCH.HIDDEN_LAYERS,
                               num_classes=CFG.ARCH.NUM_CLASSES)

    inference_ret = net.inference(inputdata=inputdata,
                                  name='shadow_net',
                                  reuse=False)

    decodes, _ = tf.nn.ctc_beam_search_decoder(
        inputs=inference_ret,
        sequence_length=CFG.ARCH.SEQ_LENGTH * np.ones(1),
        merge_repeated=False)

    # config tf saver
    saver = tf.train.Saver()

    # config tf session
    sess_config = tf.ConfigProto(allow_soft_placement=True)
    sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TEST.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = CFG.TEST.TF_ALLOW_GROWTH

    # the session is the default inside the block and is closed on exit,
    # including when restore or inference raises
    with tf.Session(config=sess_config) as sess:
        saver.restore(sess=sess, save_path=weights_path)

        preds = sess.run(decodes, feed_dict={inputdata: [image]})

        preds = codec.sparse_tensor_to_str(preds[0])

        logger.info('Predict image {:s} result {:s}'.format(
            ops.split(image_path)[1], preds[0]))

        if is_vis:
            plt.figure('CRNN Model Demo')
            plt.imshow(image_vis[:, :, (2, 1, 0)])
            plt.show()

    return
Example #23
0
def recognize(image_path: str,
              weights_path: str,
              cfg: EasyDict,
              is_vis: bool = True,
              num_classes: int = 0):
    """Recognize the text in a single image with a trained ShadowNet.

    :param image_path: path of the image to recognize
    :param weights_path: Path to stored weights
    :param cfg: configuration EasyDict (e.g. global_config.config.cfg)
    :param is_vis: whether to show the original (un-resized) image with matplotlib
    :param num_classes: number of output classes; 0 means ``len(char_dict) + 1``
    :raises ValueError: if OpenCV cannot read ``image_path``
    """
    raw_image = cv2.imread(image_path, cv2.IMREAD_COLOR)
    if raw_image is None:
        # cv2.imread signals failure by returning None instead of raising
        raise ValueError('Failed to read image {}'.format(image_path))
    image = cv2.resize(raw_image, tuple(cfg.ARCH.INPUT_SIZE))
    # NOTE(review): pixels are fed as raw 0-255 floats (no normalization);
    # confirm this matches the preprocessing used during training.
    image = np.expand_dims(image, axis=0).astype(np.float32)

    w, h = cfg.ARCH.INPUT_SIZE
    inputdata = tf.placeholder(dtype=tf.float32,
                               shape=[1, h, w, cfg.ARCH.INPUT_CHANNELS],
                               name='input')

    codec = data_utils.TextFeatureIO(
        char_dict_path=ops.join(cfg.PATH.CHAR_DICT_DIR, 'char_dict.json'),
        ord_map_dict_path=ops.join(cfg.PATH.CHAR_DICT_DIR, 'ord_map.json'))

    if num_classes == 0:
        num_classes = len(codec.reader.char_dict) + 1

    net = crnn_model.ShadowNet(phase='Test',
                               hidden_nums=cfg.ARCH.HIDDEN_UNITS,
                               layers_nums=cfg.ARCH.HIDDEN_LAYERS,
                               num_classes=num_classes)

    with tf.variable_scope('shadow'):
        net_out = net.build_shadownet(inputdata=inputdata)

    decodes, _ = tf.nn.ctc_beam_search_decoder(
        inputs=net_out,
        sequence_length=cfg.ARCH.SEQ_LENGTH * np.ones(1),
        merge_repeated=False)

    # config tf session
    sess_config = tf.ConfigProto()
    sess_config.gpu_options.per_process_gpu_memory_fraction = cfg.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = cfg.TRAIN.TF_ALLOW_GROWTH

    # config tf saver
    saver = tf.train.Saver()

    # the session is the default inside the block and is closed on exit,
    # including when restore or inference raises
    with tf.Session(config=sess_config) as sess:
        saver.restore(sess=sess, save_path=weights_path)

        preds = sess.run(decodes, feed_dict={inputdata: image})

        # decode with the reader, as every other decoder in this module does
        preds = codec.reader.sparse_tensor_to_str(preds[0])

        logger.info('Predict image {:s} label {:s}'.format(
            ops.split(image_path)[1], preds[0]))

    if is_vis:
        # show the image as read from disk, converted from BGR to RGB
        plt.figure('CRNN Model Demo')
        plt.imshow(raw_image[:, :, (2, 1, 0)])
        plt.show()


def _shadownet_char_accuracy(gt_label, pred):
    """Return the position-wise character accuracy of *pred* against *gt_label*.

    Characters are compared index by index up to the length of the shorter
    string, and the number of matches is divided by ``len(gt_label)``. An
    empty ground-truth label scores 1 when the prediction is empty too,
    otherwise 0.
    """
    if len(gt_label) == 0:
        return 1 if len(pred) == 0 else 0
    correct_count = sum(1 for gt_char, pred_char in zip(gt_label, pred)
                        if gt_char == pred_char)
    return correct_count / len(gt_label)


def _shadownet_fetch_test_batch(sess, decoder, fetches):
    """Run one test batch and convert its outputs to python strings.

    :param sess: session used to run *fetches*
    :param decoder: feature reader exposing ``sparse_tensor_to_str``
    :param fetches: ``[decoded, images, labels, imagenames]`` tensors
    :return: ``(images, imagenames, gt_res, preds_res)``
    """
    predictions, images, labels, imagenames = sess.run(fetches)
    imagenames = np.reshape(imagenames, newshape=imagenames.shape[0])
    imagenames = [tmp.decode('utf-8') for tmp in imagenames]
    preds_res = decoder.sparse_tensor_to_str(predictions[0])
    gt_res = decoder.sparse_tensor_to_str(labels)
    return images, imagenames, gt_res, preds_res


def test_shadownet(dataset_dir, weights_path, is_vis=False, is_recursive=True):
    """Evaluate a trained ShadowNet checkpoint on ``test_feature.tfrecords``.

    Restores the weights, decodes predictions with CTC beam search and prints
    one line per image plus the mean position-wise character accuracy.

    :param dataset_dir: directory containing ``test_feature.tfrecords``
    :param weights_path: checkpoint path passed to ``tf.train.Saver.restore``
    :param is_vis: show each image with matplotlib; only honoured when
        ``is_recursive`` is False
    :param is_recursive: if True, evaluate ``ceil(num_records / 32)`` ordered
        batches; if False, evaluate a single shuffled batch
    :return: None
    """
    batch_size = 32
    tfrecords_path = ops.join(dataset_dir, 'test_feature.tfrecords')

    # Initialize the record decoder
    decoder = data_utils.TextFeatureIO().reader
    images_t, labels_t, imagenames_t = decoder.read_features(tfrecords_path,
                                                             num_epochs=None)
    if not is_recursive:
        images_sh, labels_sh, imagenames_sh = tf.train.shuffle_batch(
            tensors=[images_t, labels_t, imagenames_t],
            batch_size=batch_size,
            capacity=1000 + batch_size * 2,
            min_after_dequeue=2,
            num_threads=4)
    else:
        images_sh, labels_sh, imagenames_sh = tf.train.batch(
            tensors=[images_t, labels_t, imagenames_t],
            batch_size=batch_size,
            capacity=1000 + batch_size * 2,
            num_threads=4)

    images_sh = tf.cast(x=images_sh, dtype=tf.float32)

    # build shadownet
    net = crnn_model.ShadowNet(phase='Test',
                               hidden_nums=256,
                               layers_nums=2,
                               seq_length=25,
                               num_classes=37)

    with tf.variable_scope('shadow'):
        net_out = net.build_shadownet(inputdata=images_sh)

    # Every sample in the batch is decoded over the full 25-step sequence.
    decoded, _ = tf.nn.ctc_beam_search_decoder(net_out,
                                               25 * np.ones(batch_size),
                                               merge_repeated=False)
    fetches = [decoded, images_sh, labels_sh, imagenames_sh]

    # config tf session
    sess_config = tf.ConfigProto()
    sess_config.gpu_options.per_process_gpu_memory_fraction = config.cfg.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = config.cfg.TRAIN.TF_ALLOW_GROWTH

    # config tf saver
    saver = tf.train.Saver()

    sess = tf.Session(config=sess_config)

    # Count the records once to know how many batches cover the test set.
    test_sample_count = sum(
        1 for _ in tf.python_io.tf_record_iterator(tfrecords_path))
    loops_nums = int(math.ceil(test_sample_count / batch_size))

    try:
        with sess.as_default():

            # restore the model weights
            saver.restore(sess=sess, save_path=weights_path)

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                print('Start predicting ......')
                if not is_recursive:
                    images, imagenames, gt_res, preds_res = \
                        _shadownet_fetch_test_batch(sess, decoder, fetches)

                    accuracy = [
                        _shadownet_char_accuracy(gt_label, preds_res[index])
                        for index, gt_label in enumerate(gt_res)
                    ]
                    accuracy = np.mean(np.array(accuracy).astype(np.float32),
                                       axis=0)
                    print('Mean test accuracy is {:5f}'.format(accuracy))

                    for index, image in enumerate(images):
                        print(
                            'Predict {:s} image with gt label: {:s} **** predict label: {:s}'
                            .format(imagenames[index], gt_res[index],
                                    preds_res[index]))
                        if is_vis:
                            plt.imshow(image[:, :, (2, 1, 0)])
                            plt.show()
                else:
                    # NOTE: is_vis is not applied in this mode.
                    accuracy = []
                    for _ in range(loops_nums):
                        images, imagenames, gt_res, preds_res = \
                            _shadownet_fetch_test_batch(sess, decoder, fetches)

                        accuracy.extend(
                            _shadownet_char_accuracy(gt_label, preds_res[index])
                            for index, gt_label in enumerate(gt_res))

                        for index in range(len(images)):
                            print(
                                'Predict {:s} image with gt label: {:s} **** predict label: {:s}'
                                .format(imagenames[index], gt_res[index],
                                        preds_res[index]))

                    accuracy = np.mean(np.array(accuracy).astype(np.float32),
                                       axis=0)
                    print('Test accuracy is {:5f}'.format(accuracy))
            finally:
                # Stop the input queue threads even when evaluation fails.
                coord.request_stop()
                coord.join(threads=threads)
    finally:
        sess.close()
    return
Example #25
0
def train_shadownet(dataset_dir,
                    weights_path,
                    char_dict_path,
                    ord_map_dict_path,
                    need_decode=False):
    """Train a ShadowNet (CRNN) model with CTC loss on the train/val splits.

    Builds train and validation graphs that share the ``shadow_net`` variables
    and minimises the training CTC loss with momentum SGD and an exponentially
    decaying learning rate. Summaries go to ``tboard/crnn_syn90k`` and
    checkpoints to ``model/crnn_syn90k`` every 2000 iterations.

    :param dataset_dir: directory handed to ``CrnnDataFeeder`` for both the
        ``'train'`` and ``'val'`` splits
    :param weights_path: checkpoint to restore before training, or ``None`` to
        initialise all global variables from scratch
    :param char_dict_path: character dictionary json used by the feeders and
        by the label decoder
    :param ord_map_dict_path: ord map json used by the feeders and decoder
    :param need_decode: if True, every 500 iterations also beam-search decode
        a train and a val batch and log sequence distance and accuracy
    :return: numpy array with the training CTC loss of every iteration run
        (the initial ``np.inf`` sentinel is dropped)
    """
    # prepare dataset
    train_dataset = shadownet_data_feed_pipline.CrnnDataFeeder(
        dataset_dir=dataset_dir,
        char_dict_path=char_dict_path,
        ord_map_dict_path=ord_map_dict_path,
        flags='train')
    val_dataset = shadownet_data_feed_pipline.CrnnDataFeeder(
        dataset_dir=dataset_dir,
        char_dict_path=char_dict_path,
        ord_map_dict_path=ord_map_dict_path,
        flags='val')
    train_images, train_labels, train_images_paths = train_dataset.inputs(
        batch_size=CFG.TRAIN.BATCH_SIZE, num_epochs=1)
    val_images, val_labels, val_images_paths = val_dataset.inputs(
        batch_size=CFG.TRAIN.BATCH_SIZE, num_epochs=1)

    # declare crnn net
    shadownet = crnn_model.ShadowNet(phase='train',
                                     hidden_nums=CFG.ARCH.HIDDEN_UNITS,
                                     layers_nums=CFG.ARCH.HIDDEN_LAYERS,
                                     num_classes=CFG.ARCH.NUM_CLASSES)
    shadownet_val = crnn_model.ShadowNet(phase='test',
                                         hidden_nums=CFG.ARCH.HIDDEN_UNITS,
                                         layers_nums=CFG.ARCH.HIDDEN_LAYERS,
                                         num_classes=CFG.ARCH.NUM_CLASSES)

    # set up decoder
    decoder = tf_io_pipline_tools.TextFeatureIO(
        char_dict_path=char_dict_path,
        ord_map_dict_path=ord_map_dict_path).reader

    # set up training graph
    with tf.device('/gpu:1'):

        # compute loss and seq distance; the val net reuses the train weights
        train_inference_ret, train_ctc_loss = shadownet.compute_loss(
            inputdata=train_images,
            labels=train_labels,
            name='shadow_net',
            reuse=False)
        val_inference_ret, val_ctc_loss = shadownet_val.compute_loss(
            inputdata=val_images,
            labels=val_labels,
            name='shadow_net',
            reuse=True)

        train_decoded, train_log_prob = tf.nn.ctc_beam_search_decoder(
            train_inference_ret,
            CFG.ARCH.SEQ_LENGTH * np.ones(CFG.TRAIN.BATCH_SIZE),
            merge_repeated=False)
        val_decoded, val_log_prob = tf.nn.ctc_beam_search_decoder(
            val_inference_ret,
            CFG.ARCH.SEQ_LENGTH * np.ones(CFG.TRAIN.BATCH_SIZE),
            merge_repeated=False)

        train_sequence_dist = tf.reduce_mean(tf.edit_distance(
            tf.cast(train_decoded[0], tf.int32), train_labels),
                                             name='train_edit_distance')
        val_sequence_dist = tf.reduce_mean(tf.edit_distance(
            tf.cast(val_decoded[0], tf.int32), val_labels),
                                           name='val_edit_distance')

        # set learning rate
        global_step = tf.Variable(0, name='global_step', trainable=False)
        learning_rate = tf.train.exponential_decay(
            learning_rate=CFG.TRAIN.LEARNING_RATE,
            global_step=global_step,
            decay_steps=CFG.TRAIN.LR_DECAY_STEPS,
            decay_rate=CFG.TRAIN.LR_DECAY_RATE,
            staircase=CFG.TRAIN.LR_STAIRCASE)

        # run UPDATE_OPS (e.g. batch-norm statistics) before each train step
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                                   momentum=0.9).minimize(
                                                       loss=train_ctc_loss,
                                                       global_step=global_step)

    # Set tf summary
    tboard_save_dir = 'tboard/crnn_syn90k'
    os.makedirs(tboard_save_dir, exist_ok=True)
    tf.summary.scalar(name='train_ctc_loss', tensor=train_ctc_loss)
    tf.summary.scalar(name='val_ctc_loss', tensor=val_ctc_loss)
    tf.summary.scalar(name='learning_rate', tensor=learning_rate)

    if need_decode:
        tf.summary.scalar(name='train_seq_distance',
                          tensor=train_sequence_dist)
        tf.summary.scalar(name='val_seq_distance', tensor=val_sequence_dist)

    merge_summary_op = tf.summary.merge_all()

    # Set saver configuration
    saver = tf.train.Saver()
    model_save_dir = 'model/crnn_syn90k'
    os.makedirs(model_save_dir, exist_ok=True)
    train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S',
                                     time.localtime(time.time()))
    model_name = 'shadownet_{:s}.ckpt'.format(str(train_start_time))
    model_save_path = ops.join(model_save_dir, model_name)

    # Set sess configuration
    sess_config = tf.ConfigProto(allow_soft_placement=True)
    sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH

    sess = tf.Session(config=sess_config)

    summary_writer = tf.summary.FileWriter(tboard_save_dir)
    summary_writer.add_graph(sess.graph)

    # Set the training parameters
    train_epochs = CFG.TRAIN.EPOCHS

    patience_counter = 1
    cost_history = [np.inf]
    try:
        with sess.as_default():
            if weights_path is None:
                logger.info('Training from scratch')
                init = tf.global_variables_initializer()
                sess.run(init)
            else:
                logger.info('Restore model from {:s}'.format(weights_path))
                saver.restore(sess=sess, save_path=weights_path)

            for epoch in range(train_epochs):

                # setup early stopping
                if epoch > 1 and CFG.TRAIN.EARLY_STOPPING:
                    # We always compare to the first point where cost didn't improve
                    if cost_history[-1 - patience_counter] - cost_history[
                            -1] > CFG.TRAIN.PATIENCE_DELTA:
                        patience_counter = 1
                    else:
                        patience_counter += 1
                    if patience_counter > CFG.TRAIN.PATIENCE_EPOCHS:
                        logger.info(
                            "Cost didn't improve beyond {:f} for {:d} epochs, stopping early."
                            .format(CFG.TRAIN.PATIENCE_DELTA, patience_counter))
                        break

                if need_decode and epoch % 500 == 0:
                    # train part. Fetched values get their own names so the
                    # ``train_labels`` tensor is not overwritten and can be
                    # fetched again on the next decode iteration.
                    (_, train_ctc_loss_value, train_seq_dist_value,
                     train_predictions_sparse, train_labels_sparse,
                     merge_summary_value) = sess.run(
                         [optimizer, train_ctc_loss, train_sequence_dist,
                          train_decoded, train_labels, merge_summary_op])

                    train_labels_str = decoder.sparse_tensor_to_str(
                        train_labels_sparse)
                    train_predictions_str = decoder.sparse_tensor_to_str(
                        train_predictions_sparse[0])
                    avg_train_accuracy = evaluation_tools.compute_accuracy(
                        train_labels_str, train_predictions_str)

                    if epoch % CFG.TRAIN.DISPLAY_STEP == 0:
                        logger.info(
                            'Epoch_Train: {:d} cost= {:9f} seq distance= {:9f} train accuracy= {:9f}'
                            .format(epoch + 1, train_ctc_loss_value,
                                    train_seq_dist_value, avg_train_accuracy))

                    # validation part; same renaming keeps ``val_labels``
                    # bound to the graph tensor.
                    (val_ctc_loss_value, val_seq_dist_value,
                     val_predictions_sparse, val_labels_sparse) = sess.run(
                         [val_ctc_loss, val_sequence_dist, val_decoded,
                          val_labels])

                    val_labels_str = decoder.sparse_tensor_to_str(
                        val_labels_sparse)
                    val_predictions_str = decoder.sparse_tensor_to_str(
                        val_predictions_sparse[0])
                    avg_val_accuracy = evaluation_tools.compute_accuracy(
                        val_labels_str, val_predictions_str)

                    if epoch % CFG.TRAIN.DISPLAY_STEP == 0:
                        logger.info(
                            'Epoch_Val: {:d} cost= {:9f} seq distance= {:9f} val accuracy= {:9f}'
                            .format(epoch + 1, val_ctc_loss_value,
                                    val_seq_dist_value, avg_val_accuracy))
                else:
                    _, train_ctc_loss_value, merge_summary_value = sess.run(
                        [optimizer, train_ctc_loss, merge_summary_op])

                    if epoch % CFG.TRAIN.DISPLAY_STEP == 0:
                        logger.info('Epoch_Train: {:d} cost= {:9f}'.format(
                            epoch + 1, train_ctc_loss_value))

                # record history train ctc loss
                cost_history.append(train_ctc_loss_value)
                # add training summary
                summary_writer.add_summary(summary=merge_summary_value,
                                           global_step=epoch)

                if epoch % 2000 == 0:
                    saver.save(sess=sess,
                               save_path=model_save_path,
                               global_step=epoch)
    finally:
        # Flush pending summaries to disk and release session resources,
        # including when training aborts with an exception.
        summary_writer.close()
        sess.close()

    return np.array(cost_history[1:])  # Don't return the first np.inf