示例#1
0
def train_shadownet(dataset_dir, weights_path=None):
    """Train the CRNN ``ShadowNet`` text recognizer with a CTC loss.

    Builds a shuffled input pipeline over ``dataset_dir``, the network under
    the ``shadow`` variable scope, a CTC loss minimised by Adadelta with an
    exponentially decayed learning rate, and TensorBoard summaries written to
    ``tboard/shadownet``. It then runs ``config.cfg.TRAIN.EPOCHS`` training
    steps -- one batch per step, despite the ``epoch`` naming -- printing
    progress every ``config.cfg.TRAIN.DISPLAY_STEP`` steps and saving a
    checkpoint under ``model/shadownet`` after every step.

    :param dataset_dir: source passed to ``data_utils.read_features`` (the
        training tfrecords).
    :param weights_path: checkpoint to restore before training; when ``None``
        all variables are freshly initialised.
    """
    # Shapes shared by the input pipeline, the network and the CTC ops.
    batch_size = 32
    seq_length = 25

    def _char_accuracy(gt_label, pred):
        """Return the share of ``gt_label`` positions matched by ``pred`` at the same index.

        Characters of ``gt_label`` past the end of ``pred`` count as misses.
        An empty ``gt_label`` scores 1 if ``pred`` is empty too, else 0.
        """
        if len(gt_label) == 0:
            return 1 if len(pred) == 0 else 0
        correct = sum(1 for expected, got in zip(gt_label, pred)
                      if expected == got)
        return correct / len(gt_label)

    print("读取训练数据")
    # num_epochs=None is forwarded to read_features; presumably the input
    # queue then cycles over the dataset indefinitely -- TODO confirm.
    images, labels, imagenames = data_utils.read_features(
        dataset_dir, num_epochs=None)
    # Shuffled mini-batches for training.
    inputdata, input_labels, _ = tf.train.shuffle_batch(
        tensors=[images, labels, imagenames],
        batch_size=batch_size,
        capacity=1000 + 2 * batch_size,
        min_after_dequeue=100,
        num_threads=1)
    inputdata = tf.cast(x=inputdata, dtype=tf.float32)

    print("初始化网络")
    shadownet = crnn_model.ShadowNet(phase='Train',
                                     hidden_nums=256,
                                     layers_nums=2,
                                     seq_length=seq_length,
                                     num_classes=37)

    with tf.variable_scope('shadow', reuse=False):
        net_out = shadownet.build_shadownet(inputdata=inputdata)

    # Every sample in the batch is treated as spanning all seq_length steps.
    sequence_lengths = seq_length * np.ones(batch_size)

    cost = tf.reduce_mean(
        tf.nn.ctc_loss(labels=input_labels,
                       inputs=net_out,
                       sequence_length=sequence_lengths))
    decoded, _ = tf.nn.ctc_beam_search_decoder(
        net_out, sequence_lengths, merge_repeated=False)
    # Mean edit distance between the top beam and the ground-truth labels.
    sequence_dist = tf.reduce_mean(
        tf.edit_distance(tf.cast(decoded[0], tf.int32), input_labels))

    global_step = tf.Variable(0, name='global_step', trainable=False)
    learning_rate = tf.train.exponential_decay(config.cfg.TRAIN.LEARNING_RATE,
                                               global_step,
                                               config.cfg.TRAIN.LR_DECAY_STEPS,
                                               config.cfg.TRAIN.LR_DECAY_RATE,
                                               staircase=True)
    # Ops registered in UPDATE_OPS (if the network adds any) must run before
    # each optimizer step.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        optimizer = tf.train.AdadeltaOptimizer(
            learning_rate=learning_rate).minimize(loss=cost,
                                                  global_step=global_step)

    # TensorBoard summaries.
    tboard_save_path = 'tboard/shadownet'
    if not ops.exists(tboard_save_path):
        os.makedirs(tboard_save_path)
    tf.summary.scalar(name='Cost', tensor=cost)
    tf.summary.scalar(name='Learning_Rate', tensor=learning_rate)
    tf.summary.scalar(name='Seq_Dist', tensor=sequence_dist)
    merge_summary_op = tf.summary.merge_all()

    # Checkpoint location, stamped with the training start time.
    saver = tf.train.Saver()
    model_save_dir = 'model/shadownet'
    if not ops.exists(model_save_dir):
        os.makedirs(model_save_dir)
    train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S',
                                     time.localtime(time.time()))
    model_name = 'shadownet_{:s}.ckpt'.format(str(train_start_time))
    model_save_path = ops.join(model_save_dir, model_name)

    # GPU options.
    sess_config = tf.ConfigProto()
    sess_config.gpu_options.per_process_gpu_memory_fraction = config.cfg.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = config.cfg.TRAIN.TF_ALLOW_GROWTH

    train_epochs = config.cfg.TRAIN.EPOCHS

    # The with-statement installs sess as the default session and closes it
    # on exit, even when a training step raises.
    with tf.Session(config=sess_config) as sess:
        summary_writer = tf.summary.FileWriter(tboard_save_path)
        try:
            summary_writer.add_graph(sess.graph)

            print("开始训练")
            if weights_path is None:
                print('完全重新开始训练')
                sess.run(tf.global_variables_initializer())
            else:
                # Report the actual checkpoint path being resumed from.
                print('在之前的模型:{}上继续训练'.format(weights_path))
                saver.restore(sess=sess, save_path=weights_path)

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                for epoch in range(train_epochs):
                    _, c, seq_distance, preds, gt_labels, summary = sess.run([
                        optimizer, cost, sequence_dist, decoded, input_labels,
                        merge_summary_op
                    ])

                    # Character-level training accuracy on this batch.
                    preds = data_utils.sparse_tensor_to_str(preds[0])
                    gt_labels = data_utils.sparse_tensor_to_str(gt_labels)
                    scores = [_char_accuracy(gt_label, preds[index])
                              for index, gt_label in enumerate(gt_labels)]
                    accuracy = np.mean(np.array(scores).astype(np.float32),
                                       axis=0)

                    if epoch % config.cfg.TRAIN.DISPLAY_STEP == 0:
                        print(
                            'Epoch: %d cost= %f seq distance= %f train accuracy= %f' %
                            (epoch + 1, c, seq_distance, accuracy))

                    summary_writer.add_summary(summary=summary,
                                               global_step=epoch)
                    saver.save(sess=sess, save_path=model_save_path,
                               global_step=epoch)
            finally:
                # Always stop the input threads, even if a step failed.
                coord.request_stop()
                coord.join(threads=threads)
        finally:
            # Flush pending events so the last summaries are not lost.
            summary_writer.close()

    return
示例#2
0
def test_shadownet(dataset_dir, weights_path, is_vis=False, is_recursive=True):
    """Evaluate a trained ShadowNet checkpoint and print its predictions.

    Restores ``weights_path`` and decodes batches with CTC beam search. It
    prints the ground-truth and predicted label of every scored image together
    with the mean character-level accuracy.

    :param dataset_dir: tfrecords file. It is read through
        ``data_utils.read_features`` and also iterated with
        ``tf.python_io.tf_record_iterator`` to count its records.
    :param weights_path: checkpoint to restore.
    :param is_vis: if True, display every scored image with matplotlib
        (channel order reversed with ``(2, 1, 0)``).
    :param is_recursive: if True, score unshuffled batches covering the
        record count once. If False, score a single shuffled batch.
    """
    # Shapes shared by the input pipeline, the network and the decoder.
    batch_size = 32
    seq_length = 25

    def _char_accuracy(gt_label, pred):
        """Return the share of ``gt_label`` positions matched by ``pred`` at the same index.

        Characters of ``gt_label`` past the end of ``pred`` count as misses.
        An empty ``gt_label`` scores 1 if ``pred`` is empty too, else 0.
        """
        if len(gt_label) == 0:
            return 1 if len(pred) == 0 else 0
        correct = sum(1 for expected, got in zip(gt_label, pred)
                      if expected == got)
        return correct / len(gt_label)

    images_t, labels_t, imagenames_t = data_utils.read_features(
        dataset_dir, num_epochs=None)

    if not is_recursive:
        # Shuffled batches: capacity bounds the queue size; a larger
        # min_after_dequeue gives better mixing.
        images_sh, labels_sh, imagenames_sh = tf.train.shuffle_batch(
            tensors=[images_t, labels_t, imagenames_t],
            batch_size=batch_size,
            capacity=1000 + batch_size * 2,
            min_after_dequeue=2,
            num_threads=4)
    else:
        # Unshuffled batches (with num_threads=4 the order is not strictly
        # guaranteed).
        images_sh, labels_sh, imagenames_sh = tf.train.batch(
            tensors=[images_t, labels_t, imagenames_t],
            batch_size=batch_size,
            capacity=1000 + batch_size * 2,
            num_threads=4)

    images_sh = tf.cast(x=images_sh, dtype=tf.float32)

    net = crnn_model.ShadowNet(phase='Test',
                               hidden_nums=256,
                               layers_nums=2,
                               seq_length=seq_length,
                               num_classes=37)

    with tf.variable_scope('shadow'):
        net_out = net.build_shadownet(inputdata=images_sh)

    decoded, _ = tf.nn.ctc_beam_search_decoder(net_out,
                                               seq_length * np.ones(batch_size),
                                               merge_repeated=False)

    # Session / GPU options.
    sess_config = tf.ConfigProto()
    sess_config.gpu_options.per_process_gpu_memory_fraction = config.cfg.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = config.cfg.TRAIN.TF_ALLOW_GROWTH

    saver = tf.train.Saver()

    # Record count and the number of batches needed to cover it once
    # (integer ceiling division).
    test_sample_count = sum(
        1 for _ in tf.python_io.tf_record_iterator(dataset_dir))
    loops_nums = (test_sample_count + batch_size - 1) // batch_size

    # The with-statement installs sess as the default session and closes it
    # on exit, even when evaluation raises.
    with tf.Session(config=sess_config) as sess:
        saver.restore(sess=sess, save_path=weights_path)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            print('开始预测文字......')
            if not is_recursive:
                predictions, images, labels, imagenames = sess.run(
                    [decoded, images_sh, labels_sh, imagenames_sh])
                imagenames = np.reshape(imagenames, [imagenames.shape[0]])
                imagenames = [tmp.decode('utf-8') for tmp in imagenames]
                preds_res = data_utils.sparse_tensor_to_str(predictions[0])
                gt_res = data_utils.sparse_tensor_to_str(labels)

                scores = [_char_accuracy(gt_label, preds_res[index])
                          for index, gt_label in enumerate(gt_res)]
                accuracy = np.mean(np.array(scores).astype(np.float32), axis=0)
                print(' test accuracy 为 %f' % (accuracy))

                for index, image in enumerate(images):
                    print('预测图片 %s 准确的label为: %s **** 预测的 label: %s' %
                          (imagenames[index], gt_res[index], preds_res[index]))
                    if is_vis:
                        plt.imshow(image[:, :, (2, 1, 0)])
                        plt.show()
            else:
                scores = []
                for step in range(loops_nums):
                    predictions, images, labels, imagenames = sess.run(
                        [decoded, images_sh, labels_sh, imagenames_sh])
                    imagenames = np.reshape(imagenames, [imagenames.shape[0]])
                    imagenames = [tmp.decode('utf-8') for tmp in imagenames]
                    preds_res = data_utils.sparse_tensor_to_str(predictions[0])
                    gt_res = data_utils.sparse_tensor_to_str(labels)

                    # read_features was called with num_epochs=None (presumably
                    # an endlessly repeating queue), and tf.train.batch always
                    # returns full batches. The last batch can therefore hold
                    # samples beyond the record count, so score only those
                    # that belong to this single pass.
                    valid = min(batch_size,
                                test_sample_count - step * batch_size)
                    for index, gt_label in enumerate(gt_res[:valid]):
                        scores.append(
                            _char_accuracy(gt_label, preds_res[index]))

                    for index, image in enumerate(images[:valid]):
                        print('预测图片 %s 准确的label为: %s **** 预测的label: %s' %
                              (imagenames[index], gt_res[index], preds_res[index]))
                        if is_vis:
                            plt.imshow(image[:, :, (2, 1, 0)])
                            plt.show()

                accuracy = np.mean(np.array(scores).astype(np.float32), axis=0)
                print('Test accuracy is %f' % (accuracy))
        finally:
            # Always stop the input threads, even if evaluation failed.
            coord.request_stop()
            coord.join(threads=threads)

    return
示例#3
0
def test_shadownet(dataset_dir, weights_path, is_vis=False, is_recursive=True):
    """Evaluate a trained ShadowNet checkpoint and print its predictions.

    Restores ``weights_path`` and decodes batches with CTC beam search. It
    prints the ground-truth and predicted label of every scored image together
    with the mean character-level accuracy.

    :param dataset_dir: tfrecords file. It is read through
        ``data_utils.read_features`` and also iterated with
        ``tf.python_io.tf_record_iterator`` to count its records.
    :param weights_path: checkpoint to restore.
    :param is_vis: if True, display every scored image with matplotlib
        (channel order reversed with ``(2, 1, 0)``).
    :param is_recursive: if True, score unshuffled batches covering the
        record count once. If False, score a single shuffled batch.
    """
    # Shapes shared by the input pipeline, the network and the decoder.
    batch_size = 32
    seq_length = 25

    def _char_accuracy(gt_label, pred):
        """Return the share of ``gt_label`` positions matched by ``pred`` at the same index.

        Characters of ``gt_label`` past the end of ``pred`` count as misses.
        An empty ``gt_label`` scores 1 if ``pred`` is empty too, else 0.
        """
        if len(gt_label) == 0:
            return 1 if len(pred) == 0 else 0
        correct = sum(1 for expected, got in zip(gt_label, pred)
                      if expected == got)
        return correct / len(gt_label)

    def _fetch_batch(session):
        """Run one batch through the decoder.

        Returns ``(images, names, gt_res, preds_res)``: the raw image batch,
        the UTF-8 decoded image names, and the ground-truth / predicted label
        strings produced by ``data_utils.sparse_tensor_to_str``.
        """
        predictions, images, labels, imagenames = session.run(
            [decoded, images_sh, labels_sh, imagenames_sh])
        imagenames = np.reshape(imagenames, [imagenames.shape[0]])
        names = [tmp.decode('utf-8') for tmp in imagenames]
        preds_res = data_utils.sparse_tensor_to_str(predictions[0])
        gt_res = data_utils.sparse_tensor_to_str(labels)
        return images, names, gt_res, preds_res

    def _show(images, names, gt_res, preds_res):
        """Print each image's labels and optionally display the image."""
        for index, image in enumerate(images):
            print(
                'Predicting image: %s \n \t Correct label is:   %s \n \t Predicted label is: %s'
                % (names[index], gt_res[index], preds_res[index]))
            if is_vis:
                plt.imshow(image[:, :, (2, 1, 0)])
                plt.show()

    images_t, labels_t, imagenames_t = data_utils.read_features(
        dataset_dir, num_epochs=None)

    if not is_recursive:
        # Shuffled batches.
        images_sh, labels_sh, imagenames_sh = tf.train.shuffle_batch(
            tensors=[images_t, labels_t, imagenames_t],
            batch_size=batch_size,
            capacity=1000 + batch_size * 2,
            min_after_dequeue=2,
            num_threads=4)
    else:
        # Unshuffled batches (with num_threads=4 the order is not strictly
        # guaranteed).
        images_sh, labels_sh, imagenames_sh = tf.train.batch(
            tensors=[images_t, labels_t, imagenames_t],
            batch_size=batch_size,
            capacity=1000 + batch_size * 2,
            num_threads=4)

    images_sh = tf.cast(x=images_sh, dtype=tf.float32)

    net = crnn_model.ShadowNet(phase='Test',
                               hidden_nums=256,
                               layers_nums=2,
                               seq_length=seq_length,
                               num_classes=37)

    with tf.variable_scope('shadow'):
        net_out = net.build_shadownet(inputdata=images_sh)

    decoded, _ = tf.nn.ctc_beam_search_decoder(net_out,
                                               seq_length * np.ones(batch_size),
                                               merge_repeated=False)

    sess_config = tf.ConfigProto()
    sess_config.gpu_options.per_process_gpu_memory_fraction = config.cfg.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = config.cfg.TRAIN.TF_ALLOW_GROWTH
    saver = tf.train.Saver()

    # Record count and the number of batches needed to cover it once
    # (integer ceiling division).
    test_sample_count = sum(
        1 for _ in tf.python_io.tf_record_iterator(dataset_dir))
    loops_nums = (test_sample_count + batch_size - 1) // batch_size

    # The with-statement installs sess as the default session and closes it
    # on exit, even when evaluation raises.
    with tf.Session(config=sess_config) as sess:
        saver.restore(sess=sess, save_path=weights_path)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            print('Start predicting......')
            if not is_recursive:
                images, names, gt_res, preds_res = _fetch_batch(sess)
                scores = [_char_accuracy(gt_label, preds_res[index])
                          for index, gt_label in enumerate(gt_res)]
                accuracy = np.mean(np.array(scores).astype(np.float32), axis=0)
                print(' test accuracy 为 %f' % (accuracy))
                _show(images, names, gt_res, preds_res)
            else:
                scores = []
                for step in range(loops_nums):
                    images, names, gt_res, preds_res = _fetch_batch(sess)
                    # read_features was called with num_epochs=None (presumably
                    # an endlessly repeating queue), and tf.train.batch always
                    # returns full batches. The last batch can therefore hold
                    # samples beyond the record count, so score only those
                    # that belong to this single pass.
                    valid = min(batch_size,
                                test_sample_count - step * batch_size)
                    scores.extend(
                        _char_accuracy(gt_label, preds_res[index])
                        for index, gt_label in enumerate(gt_res[:valid]))
                    _show(images[:valid], names, gt_res, preds_res)

                accuracy = np.mean(np.array(scores).astype(np.float32), axis=0)
                print('Test accuracy is %f' % (accuracy))
        finally:
            # Always stop the input threads, even if evaluation failed.
            coord.request_stop()
            coord.join(threads=threads)

    return