def model_weights_path(weights, segment_task, model_struc="densenet_gru"):
    """Resolve a model weights (.h5) file for the given segmentation task.

    Args:
        weights: Either a str — an existing path, a file name inside the
            task's checkpoint directory, or any other string (which falls
            back to the single ".h5" in the directory or the
            "<task>_segment_<struc>_finished.h5" default) — or an int epoch
            id, which selects "<task>_segment_<struc>_<id:04d>.h5".
        segment_task: Task name passed to get_segment_task_path to locate
            the checkpoint directory.
        model_struc: Model structure tag embedded in default file names.

    Returns:
        Path to an existing weights file.

    Raises:
        FileNotFoundError: If the resolved weights file does not exist.
        TypeError: If `weights` is neither str nor int.
    """
    _, ckpt_dir, _ = get_segment_task_path(segment_task)

    if isinstance(weights, str):
        # 1) Treat the string as a full path.
        if os.path.exists(weights):
            return weights

        # 2) Treat it as a file name inside the checkpoint directory.
        candidate = os.path.join(ckpt_dir, weights)
        if os.path.exists(candidate):
            return candidate

        # 3) A single .h5 in the directory is unambiguous — use it.
        files = os.listdir(ckpt_dir)
        if len(files) == 1 and files[0].endswith(".h5"):
            return os.path.join(ckpt_dir, files[0])

        # 4) Fall back to the "finished" checkpoint name.
        weights_path = os.path.join(
            ckpt_dir,
            segment_task + "_segment_" + model_struc + "_finished.h5")
        # Explicit error instead of `assert`: asserts are stripped under -O
        # and give no useful message.
        if not os.path.exists(weights_path):
            raise FileNotFoundError(
                "model weights not found: %s" % weights_path)
        return weights_path

    if not isinstance(weights, int):
        raise TypeError(
            "weights must be str or int, got %r" % type(weights).__name__)

    # Integer epoch id -> zero-padded checkpoint file name.
    weights_id = "{:04d}".format(weights)
    weights_path = os.path.join(
        ckpt_dir, segment_task + "_segment_" + model_struc + "_" +
        weights_id + ".h5")
    if not os.path.exists(weights_path):
        raise FileNotFoundError("model weights not found: %s" % weights_path)
    return weights_path
# --- Example #2 (scraped listing separator; kept as a comment so the file parses) ---
def main(img_path,
         dest_dir,
         segment_task="book_page",
         text_type="horizontal",
         model_struc="densenet_gru",
         weights_path=""):
    """Run segmentation prediction on images and save visualized results.

    Args:
        img_path: Path consumed by load_images (file or directory of images).
        dest_dir: Output directory for visualized ".jpg" results; created
            if missing.
        segment_task: Task name ("book_page", ...); selects model params
            and checkpoint directory.
        text_type: Text orientation, e.g. "horizontal" or "vertical".
        model_struc: Backbone/structure tag of the model.
        weights_path: Optional explicit weights file; falls back to
            "<struc>_ctpn_finished.h5" in the checkpoint directory.

    Raises:
        FileNotFoundError: If no weights file can be located.
    """
    check_or_makedirs(dest_dir)
    K.set_learning_phase(False)
    _, fixed_shape, feat_stride = get_segment_task_params(segment_task)
    _, ckpt_dir, logs_dir = get_segment_task_path(segment_task)
    if not os.path.exists(weights_path):
        weights_path = os.path.join(ckpt_dir,
                                    model_struc + "_ctpn_finished.h5")
        # Explicit error instead of `assert` (asserts vanish under -O).
        if not os.path.exists(weights_path):
            raise FileNotFoundError(
                "model weights not found: %s" % weights_path)

    # Load the model.
    segment_model = work_net(stage="predict",
                             segment_task=segment_task,
                             text_type=text_type,
                             model_struc=model_struc)
    segment_model.load_weights(weights_path, by_name=True)
    print("\nLoad model weights from %s\n" % weights_path)
    # ctpn_model.summary()

    # BUGFIX: compute the "h"/"v" abbreviation ONCE in a separate local.
    # The original rebound `text_type` inside the loop, so from the second
    # image onward adjust_img_to_fixed_shape received "h"/"v" instead of
    # the full orientation name.
    orientation = text_type[0].lower()

    count = 0
    for raw_np_img, img_name in load_images(img_path):
        count += 1

        np_img, _, scale_ratio = adjust_img_to_fixed_shape(
            raw_np_img,
            fixed_shape=fixed_shape,
            feat_stride=feat_stride,
            segment_task=segment_task,
            text_type=text_type)
        batch_images = np_img[np.newaxis, :, :, :]

        split_positions, scores = segment_model.predict(x=batch_images)  # model prediction

        # For these (task, orientation) pairs the prediction presumably runs
        # on a rotated image; map split positions back to the original angle.
        if (segment_task,
                orientation) in (("book_page", "h"), ("double_line", "h"),
                                 ("text_line", "v"), ("mix_line", "v")):
            _, split_positions = restore_original_angle(
                np_img=None, pred_split_positions=split_positions)

        # Undo the resize applied by adjust_img_to_fixed_shape.
        split_positions = split_positions / scale_ratio
        image = visualize.draw_split_lines(raw_np_img, split_positions,
                                           scores)  # visualization

        PIL_img = Image.fromarray(image)
        dest_path = os.path.join(dest_dir,
                                 os.path.splitext(img_name)[0] + ".jpg")
        PIL_img.save(dest_path, format="jpeg")
        print(count, "Finished: " + dest_path)
# --- Example #3 (scraped listing separator; the example's function header was lost in extraction — the indented fragment below is incomplete) ---
        if (segment_task,
                text_type) in (("book_page", "h"), ("double_line", "h"),
                               ("text_line", "v"), ("mix_line", "v")):
            _, split_positions = restore_original_angle(
                np_img=None, pred_split_positions=split_positions)

        split_positions = split_positions / scale_ratio
        image = visualize.draw_split_lines(raw_np_img, split_positions,
                                           scores)  # 可视化

        PIL_img = Image.fromarray(image)
        dest_path = os.path.join(dest_dir,
                                 os.path.splitext(img_name)[0] + ".jpg")
        PIL_img.save(dest_path, format="jpeg")
        print(count, "Finished: " + dest_path)


if __name__ == '__main__':
    # parse = argparse.ArgumentParser()
    # parse.add_argument("--image_path", type=str, default="", help="image path")
    # parse.add_argument("--dest_path", type=str, default="", help="detected result path")
    # parse.add_argument("--text_type", type=str, default="vertical", help="horizontal or vertical text")
    # parse.add_argument("--weight_path", type=str, default="", help="model weight path")
    # parse.add_argument("--use_side_refine", type=int, default=1, help="1: use side refine; 0 not use")
    # args = parse.parse_args(sys.argv[1:])

    # Results are written under "<task_root>/samples".
    segment_task = "book_page"
    root_dir, _, _ = get_segment_task_path(segment_task)
    dest_dir = os.path.join(root_dir, "samples")
    # NOTE(review): `img_path` is never defined in this file — running this
    # entry point raises NameError. It was presumably meant to come from the
    # commented-out argparse block above (--image_path); confirm and restore.
    main(img_path, dest_dir, segment_task="book_page", text_type="horizontal")
# --- Example #4 (scraped listing separator; kept as a comment so the file parses) ---
def main(data_file,
         src_type,
         text_type,
         segment_task,
         epochs,
         init_epochs=0,
         model_struc="densenet_gru",
         weights_path=""):
    _, ckpt_dir, logs_dir = get_segment_task_path(segment_task)
    tf_config()
    K.set_learning_phase(True)

    # 加载模型
    train_model, val_model = work_net(stage="train",
                                      segment_task=segment_task,
                                      text_type=text_type,
                                      model_struc=model_struc)
    compile(train_model,
            loss_names=['segment_class_loss', 'segment_regress_loss'])

    # 增加度量汇总
    total_acc, pos_acc, neg_acc = train_model.get_layer('accuracy').output
    num_pos, num_neg = train_model.get_layer('segment_target').output[4:]
    add_metrics(
        train_model,
        metric_name_list=[
            'total_acc', 'pos_acc', 'neg_acc', 'num_pos', 'num_neg'
        ],
        metric_val_list=[total_acc, pos_acc, neg_acc, num_pos, num_neg])
    train_model.summary()

    # for layer in train_model.layers:
    #     print(layer.name, " trainable: ", layer.trainable)

    # load model
    load_path = os.path.join(
        ckpt_dir, segment_task + "_segment_" + model_struc +
        "_{:04d}.h5".format(init_epochs))
    weights_path = weights_path if os.path.exists(weights_path) else load_path
    if os.path.exists(weights_path):
        train_model.load_weights(weights_path, by_name=True)
        print("\nLoad model weights from %s\n" % weights_path)

    training_generator, validation_generator = data_generator(
        data_file=data_file,
        segment_task=segment_task,
        src_type=src_type,
        text_type=text_type)

    summary_writer = tf.summary.create_file_writer(logs_dir)
    callbacks = get_callbacks(segment_task, model_struc)
    steps_per_epoch = 200
    for epoch in range(init_epochs, init_epochs + epochs, 5):
        # 开始训练
        train_model.fit_generator(generator=training_generator,
                                  steps_per_epoch=steps_per_epoch,
                                  epochs=epoch + 5,
                                  initial_epoch=epoch,
                                  verbose=1,
                                  validation_data=validation_generator,
                                  validation_steps=10,
                                  callbacks=callbacks,
                                  max_queue_size=100)

        for i in range(5):  # 汇总图片
            x = next(validation_generator
                     ) if src_type == "images" else validation_generator
            summary_images = val_model.predict_on_batch(x=x).numpy()
            with summary_writer.as_default():
                tf.summary.image("image_%d" % i,
                                 summary_images.astype("uint8"),
                                 step=epoch * steps_per_epoch,
                                 max_outputs=20)
        summary_writer.flush()

    summary_writer.close()
    train_model.save_weights(
        os.path.join(ckpt_dir, segment_task + "_segment_" + model_struc +
                     "_finished.h5"))  # 保存模型

    print("Done !")