예제 #1
0
        url = img_file.replace(IMG_PATH, IMG_URL)
        return {
            'code': 0,
            'msg': 'success',
            'data': {
                'url': url,
                'objects': res,
                'caption': caption,
                'tts_caption': tts_url
            }
        }


if __name__ == '__main__':
    # Service entry point: initialise logging, register the object-detection
    # route and serve it over HTTP with Tornado on a fixed port.
    port = 8026

    # Log init: the log file is named after the handler class, lower-cased.
    log_file = ApiObjectDetect.__name__.lower()
    utils.init_logging(log_file=log_file, log_path=APP_PATH)
    print("log_file: {}".format(log_file))

    # Routing: a single endpoint handled by ApiObjectDetect.
    app = tornado.web.Application(handlers=[(r'/piglab/image/object_detect',
                                             ApiObjectDetect)])

    # Start the server. Per the Tornado docs, xheaders=True makes the server
    # honour X-Real-Ip / X-Forwarded-For style headers (reverse-proxy setups).
    http_server = tornado.httpserver.HTTPServer(app, xheaders=True)
    http_server.listen(port)
    # IOLoop.instance() is a deprecated alias since Tornado 5.0;
    # IOLoop.current() is the supported way to get the thread's event loop.
    tornado.ioloop.IOLoop.current().start()
예제 #2
0
                if (i + 1) % self.check_freq == 0:
                    logging.info("TRAIN Current self-play batch: {}".format(i + 1))
                    # 策略胜率评估:模型与纯MCTS玩家对战n局看胜率
                    win_ratio = self.policy_evaluate(self.policy_evaluate_size)
                    self.policy_value_net.save_model(CUR_PATH + '/model/current_policy_{}x{}.model'.format(self.board_width, self.board_height))
                    if win_ratio > self.best_win_ratio:  # 胜率超过历史最优模型
                        logging.info("TRAIN New best policy!!!!!!!!batch:{} win_ratio:{}->{} pure_mcts_playout_num:{}".format(i + 1, self.best_win_ratio, win_ratio, self.pure_mcts_playout_num))
                        self.best_win_ratio = win_ratio
                        # 保存当前模型为最优模型best_policy
                        self.policy_value_net.save_model(CUR_PATH + '/model/best_policy_{}x{}.model'.format(self.board_width, self.board_height))
                        # 如果胜率=100%,则增加纯MCT的模拟数 (<6000的限制视mem情况)
                        if self.best_win_ratio == 1.0: # and self.pure_mcts_playout_num < 6000:
                            self.pure_mcts_playout_num += 1000
                            self.best_win_ratio = 0.0
        except KeyboardInterrupt:
            logging.info('\n\rquit')


if __name__ == '__main__':
    # Log init
    utils.init_logging(log_file='train', log_path=CUR_PATH)

    # Train. Resume from the "current policy" checkpoint for this board size
    # (the same path the pipeline saves to) when it exists; otherwise pass an
    # empty model path (presumably "start from scratch" — confirm against
    # GomokuTrainPipeline).
    size = 16  # board size (size x size)
    model_file = '{}/model/current_policy_{}x{}.model'.format(CUR_PATH, size, size)
    # Check the filesystem once so the printed status and the decision agree.
    model_exists = os.path.exists(model_file)
    print("model file exists: {}".format(model_exists))
    if not model_exists:
        model_file = ''
    training_pipeline = GomokuTrainPipeline(init_model=model_file, size=size)
    training_pipeline.run()
예제 #3
0
파일: da.py 프로젝트: linruohan/pigrobot
# -*- coding: utf-8 -*-
"""
File: da.py
Desc: 需求识别基类
Author:yanjingang([email protected])
Date: 2019/2/21 23:34
"""

import logging
from dp import utils
from dp.da import Da
import constants

if __name__ == '__main__':
    """test"""
    # Manual smoke test: run intent classification (get_trigger) on a sample
    # query, then query parsing (get_parser) for the first detected intent,
    # printing both results.
    utils.init_logging(log_file='da', log_path=constants.APP_PATH)

    # Da is constructed with its dictionary directory <APP_PATH>/data/da/.
    da = Da(dict_path=constants.APP_PATH + "/data/da/")
    # Intent classification.
    #res = da.get_trigger("宝马3系报价")
    res = da.get_trigger("哪个恐龙跑的最快?")
    print("trigger: {}".format(res))
    # Query parsing within the detected intent category.
    # NOTE(review): assumes get_trigger returned a non-empty sequence whose
    # first item has a 'type' key; an empty result would raise here — confirm.
    #res = da.get_parser(res[0]['type'], "宝马3系报价")
    res = da.get_parser(res[0]['type'], "哪个恐龙跑的最快?")
    #res = da.get_parser(res[0]['type'], "跑的最快的恐龙")
    print("parser: {}".format(res))
    '''
    # tree test
    tree = Node(name='root')
    ctree = tree.add_child(Node('car'))