示例#1
0
def load_model(model_path, with_gpu):
    """Load a ``DataParallel``-wrapped FOTS model from a checkpoint file.

    Args:
        model_path: Path of a checkpoint readable by ``torch.load`` that
            contains the keys ``'config'`` and ``'state_dict'``.
        with_gpu: When true, move the model to CUDA after loading.

    Returns:
        The ``torch.nn.DataParallel``-wrapped model with the weights loaded.

    Raises:
        RuntimeError: If the checkpoint deserializes to an empty object.
    """
    logger.info("Loading checkpoint: {} ...".format(model_path))
    # Deserialize onto the CPU so a checkpoint saved from GPU tensors can also
    # be loaded on a host without CUDA; the model is moved to the GPU below
    # only when requested.
    checkpoints = torch.load(model_path, map_location='cpu')
    if not checkpoints:
        raise RuntimeError('No checkpoint found.')
    config = checkpoints['config']
    state_dict = checkpoints['state_dict']
    model = FOTSModel(config)
    # Wrapped *before* loading, so the saved keys are expected to carry the
    # 'module.' prefix that DataParallel adds — TODO confirm against the saver.
    model = torch.nn.DataParallel(model)
    model.load_state_dict(state_dict)
    if with_gpu:
        model = model.cuda()
    return model
示例#2
0
 def __init__(self, config):
     """Initialize the service: build the model, load weights, set up decoding.

     Args:
         config: Parsed service configuration. Keys read here:
             ``'model_path'`` (checkpoint containing ``'state_dict'``),
             ``'cuda'`` (move the model to ``cuda:0`` when truthy) and
             ``config['model']['keys']`` (name of an attribute of
             ``common_str`` passed to ``strLabelConverter``).
     """
     # NOTE(review): meaning of the second FOTSModel argument is defined by
     # FOTSModel itself — TODO confirm.
     self.model = FOTSModel(config, False)
     self.model.eval()
     self.config = config
     # Deserialize onto the CPU so a checkpoint saved from GPU tensors also
     # loads on hosts without CUDA; the model moves to cuda:0 below if enabled.
     self.model.load_state_dict(
         torch.load(config['model_path'], map_location='cpu')['state_dict'])
     self.label_converter = strLabelConverter(
         getattr(common_str, self.config['model']['keys']))
     if config['cuda']:
         self.model.to(torch.device("cuda:0"))
     print('init finish')
示例#3
0
def load_model(model_path, with_gpu):
    """Load a FOTS model from a checkpoint and put it in evaluation mode.

    Args:
        model_path: Path of a checkpoint readable by ``torch.load`` that
            contains the keys ``'config'`` and ``'state_dict'``.
        with_gpu: When true, move the model to ``cuda:0`` and call its
            ``parallelize()`` method.

    Returns:
        The model with weights loaded, in ``eval()`` mode.

    Raises:
        RuntimeError: If the checkpoint deserializes to an empty object.
    """
    logger.info("Loading checkpoint: {} ...".format(model_path))
    # Deserialize onto the CPU so GPU-saved checkpoints also load on hosts
    # without CUDA; the model is moved to the GPU below only when requested.
    checkpoints = torch.load(model_path, map_location='cpu')
    if not checkpoints:
        raise RuntimeError('No checkpoint found.')
    config = checkpoints['config']
    state_dict = checkpoints['state_dict']

    model = FOTSModel(config)
    model.load_state_dict(state_dict)

    if with_gpu:
        model.to(torch.device("cuda:0"))
        # parallelize() is a FOTSModel method; presumably it wraps submodules
        # for multi-GPU use — TODO confirm.
        model.parallelize()

    model.eval()
    return model
示例#4
0
class OCRServer(base_pb2_grpc.OCRServicer):
    """OCR servicer (base class ``base_pb2_grpc.OCRServicer``) backed by a FOTSModel."""

    def __init__(self, config):
        """Initialize the service: build the model, load weights, set up decoding.

        Args:
            config: Parsed service configuration. Keys read here:
                ``'model_path'`` (checkpoint containing ``'state_dict'``),
                ``'cuda'`` (move the model to ``cuda:0`` when truthy) and
                ``config['model']['keys']`` (name of an attribute of
                ``common_str`` passed to ``strLabelConverter``).
        """
        self.model = FOTSModel(config, False)
        self.model.eval()
        self.config = config
        # Deserialize onto the CPU so a checkpoint saved from GPU tensors also
        # loads on hosts without CUDA; the model moves to cuda:0 below if
        # config['cuda'] is set.
        self.model.load_state_dict(
            torch.load(config['model_path'], map_location='cpu')['state_dict'])
        self.label_converter = strLabelConverter(
            getattr(common_str, self.config['model']['keys']))
        if config['cuda']:
            self.model.to(torch.device("cuda:0"))
        print('init finish')

    def detect(self, request, context):
        """Handle the ``detect`` RPC.

        NOTE(review): currently a stub — ``request`` is ignored and the
        response JSON only contains ``{"mode": "detect"}``.
        """
        to_return = {'mode': 'detect'}
        return base_pb2.OCRResponse(message=json.dumps(to_return))

    def recognize(self, request, context):
        """Handle the ``recognize`` RPC.

        NOTE(review): currently a stub — ``request`` is ignored and the
        response JSON only contains ``{"mode": "recognize"}``.
        """
        to_return = {'mode': 'recognize'}
        return base_pb2.OCRResponse(message=json.dumps(to_return))

    def _area_by_shoelace(self, points):
        """Return the area of the polygon ``points`` via the shoelace formula.

        ``points`` is a sequence of (x, y) pairs in polygon order; the last
        vertex is implicitly connected back to the first.
        """
        x, y = [_[0] for _ in points], [_[1] for _ in points]
        return abs(
            sum(i * j for i, j in zip(x, y[1:] + y[:1])) -
            sum(i * j for i, j in zip(x[1:] + x[:1], y))) / 2

    def detect_and_recognize(self, request, context):
        """Detect text in ``request.image`` and return the largest region's text.

        ``request.image`` is base64-encoded image data. The response message is
        JSON with ``mode``, ``code`` (200 when at least one region was found,
        201 otherwise) and ``result`` (the text of the region whose polygon has
        the largest area, or a fixed "not recognized" string).
        """
        to_return = {'mode': 'detect_and_recognize'}
        to_process_img = Image.open(BytesIO(base64.b64decode(
            request.image))).convert('RGB')
        polys_and_texts, _, _ = Toolbox.predict(
            to_predict_img=to_process_img,
            model=self.model,
            with_img=False,
            output_dir=None,
            with_gpu=self.config['cuda'],
            output_txt_dir=None,
            labels=None,
            label_converter=self.label_converter)
        # Explicit length check kept rather than truthiness, since the
        # container type returned by Toolbox.predict is not known here.
        if polys_and_texts is not None and len(polys_and_texts) > 0:
            to_return['code'] = 200
            # Each item's [0] is the polygon used for area, [1] its text.
            to_return['result'] = max(
                polys_and_texts, key=lambda x: self._area_by_shoelace(x[0]))[1]
        else:
            to_return['code'] = 201
            to_return['result'] = '未识别出'
        return base_pb2.OCRResponse(message=json.dumps(to_return))
示例#5
0
def load_model(model_path, with_gpu):
    """Load a FOTS model and its trained weights from a PyTorch checkpoint.

    Args:
        model_path: Path of a checkpoint readable by ``torch.load`` that
            contains the keys ``'config'`` and ``'state_dict'``.
        with_gpu: When true, move the model to the CUDA device.

    Returns:
        The model with weights loaded, in ``eval()`` mode.

    Raises:
        RuntimeError: If the checkpoint deserializes to an empty object.
    """
    # FOTSModel is used as a PyTorch module (load_state_dict / to / eval), so
    # the checkpoint is read with torch.load. tf.saved_model.load accepts no
    # map_location argument and returns an object that is not a dict.
    import torch

    logger.info("Loading checkpoint: {} ...".format(model_path))
    # Deserialize onto the CPU so GPU-saved checkpoints load without CUDA.
    checkpoints = torch.load(model_path, map_location='cpu')
    if not checkpoints:
        raise RuntimeError('No checkpoint found.')
    config = checkpoints['config']
    state_dict = checkpoints['state_dict']
    model = FOTSModel(config)
    # Without this the returned model would carry untrained weights.
    model.load_state_dict(state_dict)
    if with_gpu:
        model.to(torch.device('cuda'))
    model.eval()
    return model
示例#6
0
def load_model(model_path, with_gpu):
    """Load a ``DataParallel``-wrapped FOTS model from ``model_path``.

    Args:
        model_path: Path of a checkpoint readable by ``torch.load`` that
            contains ``'config'`` (with a ``'model'`` sub-config) and
            ``'state_dict'``.
        with_gpu: When true, move the wrapped model to CUDA.

    Returns:
        The ``torch.nn.DataParallel``-wrapped model in ``eval()`` mode.

    Raises:
        RuntimeError: If the checkpoint deserializes to an empty object.
    """
    logger.info("Loading checkpoint: {} ...".format(model_path))
    # Load the checkpoint the caller asked for. map_location='cpu' lets
    # GPU-saved checkpoints load on hosts without CUDA.
    checkpoints = torch.load(model_path, map_location='cpu')
    if not checkpoints:
        raise RuntimeError('No checkpoint found.')
    config = checkpoints['config']
    state_dict = checkpoints['state_dict']

    model = FOTSModel(config['model'])

    # Loaded before wrapping, so the saved keys must match the bare model
    # (i.e. without DataParallel's 'module.' prefix).
    model.load_state_dict(state_dict)

    model = torch.nn.DataParallel(model)

    if with_gpu:
        model = model.cuda()
    model = model.eval()
    return model
示例#7
0
def load_model(model_path, with_gpu):  # load the model and its parameters
    """Build a FOTS model from a checkpoint, adapting partial checkpoints.

    If the checkpoint's state-dict keys equal those of the freshly built
    model, the weights are loaded directly. Otherwise the checkpoint is
    treated as a recognition-branch-only model: its entries ``'0'`` and
    ``'1'`` are converted with ``transfer_state_dict`` into the full model's
    entries ``'2'`` and ``'3'``, and the result is merged into the model's own
    state dict before loading.

    NOTE(review): indexing ``model_dict['2']`` assumes ``FOTSModel.state_dict()``
    is keyed by sub-model index ('0'..'3') — TODO confirm. Any other key
    mismatch raises ``KeyError`` on the adaptation path.

    Args:
        model_path: Path of a checkpoint readable by ``torch.load`` that
            contains ``'config'`` and ``'state_dict'``.
        with_gpu: When true, move the model to ``cuda:0`` and call its
            ``parallelize()`` method.

    Returns:
        The model with weights loaded, in ``eval()`` mode.

    Raises:
        RuntimeError: If the checkpoint deserializes to an empty object.
    """
    logger.info("Loading checkpoint: {} ...".format(model_path))
    # Deserialize onto the CPU so GPU-saved checkpoints load without CUDA;
    # the model is moved to the GPU below only when with_gpu is set.
    checkpoint = torch.load(model_path, map_location='cpu')
    if not checkpoint:
        raise RuntimeError('No checkpoint found.')
    config = checkpoint['config']

    model = FOTSModel(config)

    pretrained_dict = checkpoint['state_dict']  # state_dict of the pretrained checkpoint
    model_dict = model.state_dict()  # state_dict of the model being built now

    if pretrained_dict.keys() != model_dict.keys():  # parameters need adapting
        print('Parameters are inconsistant, adapting model parameters ...')
        # Before merging (update), remove entries from pretrained_dict that
        # are not needed. In a recognition-branch-only checkpoint, keys '0'
        # and '1' correspond to keys '2' and '3' of the full model.
        pretrained_dict['2'] = transfer_state_dict(pretrained_dict['0'],
                                                   model_dict['2'])
        pretrained_dict['3'] = transfer_state_dict(pretrained_dict['1'],
                                                   model_dict['3'])
        # Drop the original entries so they do not wrongly overwrite the
        # current model's '0'/'1' entries during the update.
        del pretrained_dict['0']
        del pretrained_dict['1']
        model_dict.update(pretrained_dict)  # merge the adapted parameters
        # Load into the model built above (this is a plain function; there
        # is no `self` in scope).
        model.load_state_dict(model_dict)
    else:
        print('Parameters are consistant, load state dict directly ...\n')
        model.load_state_dict(pretrained_dict)

    if with_gpu:
        model.to(torch.device("cuda:0"))
        model.parallelize()

    model.eval()
    return model
示例#8
0
                    for m_name, m_value in zip([
                        'left_top', 'right_top', 'right_bottom', 'left_bottom'
                    ], m_box)
                })
            self.write(json.dumps(to_return))


if __name__ == "__main__":
    # Command-line entry point: load the OCR model described by a JSON config
    # file (-c) and build the Tornado application that serves it.
    ag = ArgumentParser()
    ag.add_argument("-c", type=str, help='path to config file')
    args = ag.parse_args()

    # Config is read as UTF-8 JSON. Keys used below: 'model_path', 'cuda',
    # and config['model']['keys'].
    with open(args.c, mode='r', encoding='utf-8') as to_read:
        config = json.loads(to_read.read())

    model = FOTSModel(config, False)
    model.eval()
    # NOTE(review): torch.load is called without map_location, so a checkpoint
    # saved from GPU tensors fails to load on a host without CUDA even when
    # config['cuda'] is false — consider map_location='cpu'.
    model.load_state_dict(torch.load(config['model_path'])['state_dict'])
    # The alphabet is the attribute of common_str named by
    # config['model']['keys']; presumably the converter maps between text and
    # label indices — TODO confirm.
    label_converter = strLabelConverter(
        getattr(common_str, config['model']['keys']))
    with_gpu = config['cuda']
    if with_gpu:
        model.to(torch.device("cuda:0"))

    # NOTE(review): all three routes use DetectHandler; confirm the handler
    # dispatches on the request path, otherwise /recognize and
    # /detect_and_recognize behave exactly like /detect.
    routes = [
        ('/detect', DetectHandler),
        ('/recognize', DetectHandler),
        ('/detect_and_recognize', DetectHandler),
    ]

    application = tornado.web.Application(routes)