Exemplo n.º 1
0
    def __init__(self,
                 ctpn_weight_path,
                 densenet_weight_path,
                 dict_path,
                 ctpn_config_path=None,
                 densenet_config_path=None):
        """Load the character dictionary and build the CTPN and DenseNet models.

        The weight paths are forwarded to the model constructors whether or
        not a configuration file is supplied; a configuration file only
        contributes the remaining constructor keyword arguments.

        :param ctpn_weight_path:     path of the CTPN model weight file
        :param densenet_weight_path: path of the DenseNet model weight file
        :param dict_path:            path of the dictionary file
        :param ctpn_config_path:     optional path of the CTPN model config file
        :param densenet_config_path: optional path of the DenseNet model config file
        """

        # Result of load_dict; its length is used below as the default number
        # of output classes for the recognition model.
        self.id_to_char = load_dict(dict_path, encoding="utf-8")

        # Build the CTPN model. Without a config file the constructor defaults
        # are used, but the weight path is still passed on (it was previously
        # dropped in that case).
        if ctpn_config_path is None:
            ctpn_config = {}
        else:
            ctpn_config = CTPN.load_config(ctpn_config_path)
        ctpn_config["weight_path"] = ctpn_weight_path
        self.ctpn = CTPN(**ctpn_config)

        # Build the DenseNet model. Without a config file, the class count
        # is derived from the dictionary size.
        if densenet_config_path is None:
            densenet_config = {"num_classes": len(self.id_to_char)}
        else:
            densenet_config = DenseNetOCR.load_config(densenet_config_path)
        densenet_config["weight_path"] = densenet_weight_path
        self.ocr = DenseNetOCR(**densenet_config)
                        default=None)
    # Where to write the saved weights. The raw-string default keeps the
    # literal "{epoch:02d}" placeholder; it is not formatted here (presumably
    # it is filled in by the checkpoint object created below - TODO confirm).
    parser.add_argument("--save_weights_file_path", help="保存模型训练权重文件位置",
                        default=r'model/cv_weights-ctpnlstm-{epoch:02d}.hdf5')

    args = parser.parse_args()
    #movefile(args.anno_dir)

    # get_session is defined elsewhere; 0.8 is presumably a GPU memory
    # fraction - TODO confirm.
    K.set_session(get_session(0.8))
    config = CTPN.load_config(args.config_file_path)

    # Override the configured weight path only when one was given on the
    # command line; the GPU count always comes from the command line.
    weights_file_path = args.weights_file_path
    if weights_file_path is not None:
        config["weight_path"] = weights_file_path
    config['num_gpu'] = args.gpus

    ctpn = CTPN(**config)

    save_weigths_file_path = args.save_weights_file_path

    # NOTE(review): the argparse default above is a non-None string, so this
    # branch does not appear reachable from the command line; as a result the
    # "model" directory is not created here for the default path - confirm
    # whether that is intended.
    if save_weigths_file_path is None:
        try:
            if not os.path.exists("model"):
                os.makedirs("model")
            save_weigths_file_path = "model/weights-ctpnlstm-{epoch:02d}.hdf5"
        except OSError:
            # Only a message is printed; save_weigths_file_path stays None
            # and is still passed to SingleModelCK below.
            print('Error: Creating directory. ' + "model")

    # Both loaders read from the same images directory and differ only in
    # their first argument.
    # NOTE(review): valid_path is not defined in the visible part of this
    # script - presumably set earlier; verify.
    train_data_loader = DataLoader(args.anno_dir, args.images_dir)
    valid_data_loader = DataLoader(valid_path, args.images_dir)

    # The checkpoint object is given the model's parallel_model attribute and
    # save_weights_only=False.
    checkpoint = SingleModelCK(save_weigths_file_path, model=ctpn.parallel_model, save_weights_only=False)