示例#1
0
def load_model(config):
    """Create a CChessModel and make sure it carries weights.

    The network is built from scratch and stored with
    ``save_as_best_model`` when ``config.internet.distributed`` or
    ``config.opts.new`` is truthy, or when ``load_best_model_weight``
    returns a falsy value. Otherwise the loaded weights are kept.

    Returns:
        The CChessModel instance.
    """
    net = CChessModel(config)
    # Loading is only attempted when neither flag requests a fresh model.
    start_fresh = config.internet.distributed or config.opts.new
    if not start_fresh:
        start_fresh = not load_best_model_weight(net)
    if start_fresh:
        net.build()
        save_as_best_model(net)
    return net
示例#2
0
def load_model(config, weight_path, digest, config_file=None):
    """Load the weights identified by ``digest``, downloading them if needed.

    The local file at ``weight_path`` is tried first. If it cannot be loaded,
    or the loaded model's ``digest`` differs from the requested one, the
    weights are downloaded from ``config.internet.download_base_url`` and
    loaded again.

    Args:
        config: Project configuration object.
        weight_path: Local path of the weight file (also the download target).
        digest: Digest of the wanted weights; used to build the download URL.
        config_file: Optional model config file name inside
            ``config.resource.model_dir``. When falsy,
            ``config.resource.model_best_config_path`` is used.

    Returns:
        A ``(model, use_history)`` tuple. ``use_history`` is always ``False``
        in this function.

    Note:
        Calls ``sys.exit()`` if the freshly downloaded file still fails to
        load. A ``ValueError`` during that load retries with
        ``'model_192x10_config.json'``; any other exception deletes the file,
        waits 10 seconds and retries with the default config.
    """
    # Never changed below; returned so callers get the usual pair.
    use_history = False
    net = CChessModel(config)

    if config_file:
        config_path = os.path.join(config.resource.model_dir, config_file)
    else:
        config_path = config.resource.model_best_config_path
    logger.debug(f"config_path = {config_path}, digest = {digest}")

    loaded = load_model_weight(net, config_path, weight_path)
    if not loaded or net.digest != digest:
        logger.info(f"开始下载权重 {digest[0:8]}")
        download_file(config.internet.download_base_url + digest + '.h5',
                      weight_path)
        try:
            reloaded = load_model_weight(net, config_path, weight_path)
            if not reloaded:
                logger.info("待评测权重还未上传,请稍后再试")
                sys.exit()
        except ValueError as err:
            # Architecture mismatch: try again with the alternative config.
            logger.error(f"权重架构不匹配,自动重新加载 {err}")
            return load_model(config, weight_path, digest,
                              'model_192x10_config.json')
        except Exception as err:
            # Drop the file so the retry downloads it again.
            logger.error(f"加载权重发生错误:{err},10s后自动重试下载")
            os.remove(weight_path)
            sleep(10)
            return load_model(config, weight_path, digest)

    logger.info(f"加载权重 {digest[0:8]} 成功")
    return net, use_history
示例#3
0
 def load_model(self):
     """Configure a session, create ``self.model`` and attach the session.

     The session is created with ``set_session_config`` before the model is
     constructed. The model is built from scratch when
     ``self.config.opts.new`` is truthy or ``load_best_model_weight``
     returns a falsy value.
     """
     session = set_session_config(
         per_process_gpu_memory_fraction=1,
         allow_growth=True,
         device_list=self.config.opts.device_list,
     )
     net = CChessModel(self.config)
     self.model = net
     # Loading is skipped entirely when a new model is requested.
     needs_build = self.config.opts.new
     if not needs_build:
         needs_build = not load_best_model_weight(net)
     if needs_build:
         net.build()
     net.sess = session
示例#4
0
def plot_model():
    """Build a fresh model, save it as best model and draw it to 'model.png'.

    Uses the 'distribute' configuration. Imports are local to the function.
    """
    # Aliased so the keras helper does not shadow this function's own name.
    from keras.utils import plot_model as render_model
    from cchess_alphazero.agent.model import CChessModel
    from cchess_alphazero.config import Config
    from cchess_alphazero.lib.model_helper import save_as_best_model

    net = CChessModel(Config('distribute'))
    net.build()
    save_as_best_model(net)
    render_model(net.model, to_file='model.png', show_shapes=True,
                 show_layer_names=True)
示例#5
0
def load_model(config, model_file=None):
    """Load a CChessModel, retrying once with a fallback model file.

    Args:
        config: Project configuration object.
        model_file: Optional file name inside ``config.resource.model_dir``.
            When falsy, ``config.resource.model_best_path`` is used.

    Returns:
        A ``(model, use_history)`` tuple. ``use_history`` is ``True`` only
        when ``load_model_weight`` returned a falsy value and the model was
        saved via ``save_as_best_model``.

    Raises:
        Exception: Whatever loading/saving raised, if the attempt with the
            fallback file ``'model_192x10_config.json'`` fails as well.
    """
    fallback_file = 'model_192x10_config.json'
    use_history = False
    model = CChessModel(config)
    if model_file:
        config_path = os.path.join(config.resource.model_dir, model_file)
    else:
        config_path = config.resource.model_best_path
    try:
        if not load_model_weight(model, config_path):
            # NOTE(review): the model is saved without a preceding
            # model.build() call — confirm this is intended.
            save_as_best_model(model)
            use_history = True
    except Exception as e:
        if model_file == fallback_file:
            # Already on the fallback: retrying with the same arguments
            # would recurse until RecursionError and hide the real error.
            raise
        logger.info(f"Exception {e}, 重新加载权重")
        return load_model(config, model_file=fallback_file)
    return model, use_history
 def load_model(self, model_file=None):
     """Load a CChessModel, retrying once with a fallback model file.

     Args:
         model_file: Optional file name inside
             ``self.config.resource.model_dir``. When falsy,
             ``self.config.resource.model_best_path`` is used.

     Returns:
         A ``(model, use_history)`` tuple. ``use_history`` is ``True`` only
         when ``load_model_weight`` returned a falsy value and the model was
         saved via ``save_as_best_model``.

     Raises:
         Exception: Whatever loading/saving raised, if the attempt with the
             fallback file ``'model_192x10_config.json'`` fails as well.
     """
     fallback_file = 'model_192x10_config.json'
     use_history = False
     model = CChessModel(self.config)
     if model_file:
         model_path = os.path.join(self.config.resource.model_dir,
                                   model_file)
     else:
         model_path = self.config.resource.model_best_path
     try:
         if not load_model_weight(model, model_path):
             save_as_best_model(model)
             use_history = True
     except Exception as e:
         if model_file == fallback_file:
             # Already on the fallback: retrying with the same arguments
             # would recurse until RecursionError and hide the real error.
             raise
         logger.info(f"Exception {e}, 重新加载权重")
         return self.load_model(model_file=fallback_file)
     return model, use_history
示例#7
0
 def load_model(self, config_file=None):
     """Create ``self.model`` and load its weights, with one fallback retry.

     Args:
         config_file: Optional config file name inside
             ``self.config.resource.model_dir``. When falsy,
             ``self.config.resource.model_best_path`` is used.

     Returns:
         ``use_history``: ``False`` only when the default config path was
         used and ``load_model_weight`` succeeded; ``True`` otherwise.

     Raises:
         Exception: Whatever loading/building raised, if the attempt with
             the fallback file ``'model_128_l1_config.json'`` fails as well.
     """
     fallback_file = 'model_128_l1_config.json'
     use_history = True
     self.model = CChessModel(self.config)
     # The configuration lives on the instance; a bare ``config`` name is
     # not defined inside this method.
     resource = self.config.resource
     weight_path = resource.model_best_weight_path
     if not config_file:
         config_path = resource.model_best_path
         use_history = False
     else:
         config_path = os.path.join(resource.model_dir, config_file)
     try:
         if not load_model_weight(self.model, config_path, weight_path):
             self.model.build()
             use_history = True
     except Exception as e:
         if config_file == fallback_file:
             # Already on the fallback: retrying with the same arguments
             # would recurse until RecursionError and hide the real error.
             raise
         logger.info(f"Exception {e}, 重新加载权重")
         return self.load_model(config_file=fallback_file)
     logger.info(f"use_history = {use_history}")
     return use_history
示例#8
0
 def load_model(self):
     """Create ``self.model``, loading the best weights or building anew.

     The model is built when ``self.config.opts.new`` is truthy or
     ``load_best_model_weight`` returns a falsy value.
     """
     net = CChessModel(self.config)
     self.model = net
     # Short-circuit keeps the load from running when a new model is wanted.
     must_build = self.config.opts.new or not load_best_model_weight(net)
     if must_build:
         net.build()
示例#9
0
 def load_model(self):
     """Return a CChessModel with the SL best weights, or a freshly built one.

     When ``self.config.opts.new`` is truthy, or
     ``load_sl_best_model_weight`` returns a falsy value, the model is built
     and stored via ``save_as_sl_best_model`` before being returned.
     """
     net = CChessModel(self.config)
     # Fast path: existing weights were wanted and loaded successfully.
     if not self.config.opts.new and load_sl_best_model_weight(net):
         return net
     net.build()
     save_as_sl_best_model(net)
     return net
示例#10
0
def load_model(config, config_path, weight_path, name=None):
    """Create a CChessModel and load weights into it.

    Args:
        config: Project configuration object passed to ``CChessModel``.
        config_path: Forwarded to ``load_model_weight``.
        weight_path: Forwarded to ``load_model_weight``.
        name: Optional value forwarded to ``load_model_weight``.

    Returns:
        The model, or ``None`` when ``load_model_weight`` returns a falsy
        value.
    """
    net = CChessModel(config)
    loaded = load_model_weight(net, config_path, weight_path, name)
    return net if loaded else None