예제 #1
0
    def _load_model(self, model_dir):
        """Create a ReversiModel from ``self.config`` and load it from ``model_dir``.

        A session is created on the model before loading.

        :param model_dir: directory handed unchanged to ``ReversiModel.load``.
        :return: the ``ReversiModel`` instance that was loaded.
        """
        from reversi_zero.agent.model import ReversiModel

        loaded = ReversiModel(self.config)
        loaded.create_session()
        loaded.load(model_dir)
        return loaded
예제 #2
0
 def load_model(self):
     """Create a ReversiModel and load it from ``config.resource.model_dir``.

     A session is created on the model before loading.

     :return: the loaded ``ReversiModel`` instance.
     """
     from reversi_zero.agent.model import ReversiModel

     instance = ReversiModel(self.config)
     resource = self.config.resource
     instance.create_session()
     instance.load(resource.model_dir)
     return instance
예제 #3
0
 def load_next_generation_model(self):
     """Block until a next-generation model exists, then load the first one.

     Polls ``get_next_generation_model_dirs`` every 60 seconds for as long as
     it returns an empty result, logging each time it finds nothing.

     :return: tuple ``(model, model_dir)``: the ``ReversiModel`` loaded from
         the first returned directory, and that directory.
     """
     rc = self.config.resource
     # Wait for a candidate to appear; an empty result means nothing to evaluate yet.
     while True:
         dirs = get_next_generation_model_dirs(rc)
         if dirs:
             break
         logger.info("There is no next generation model to evaluate")
         sleep(60)
     model_dir = dirs[0]
     config_path = os.path.join(model_dir, rc.next_generation_model_config_filename)
     weight_path = os.path.join(model_dir, rc.next_generation_model_weight_filename)
     model = ReversiModel(self.config)
     model.load(config_path, weight_path)
     return model, model_dir
예제 #4
0
    def load_model(self):
        """Build a ReversiModel, preferring the newest next-generation model.

        If ``get_next_generation_model_dirs`` returns any directories, the last
        one is loaded. Otherwise the best-model weights are loaded through
        ``load_best_model_weight``.

        :return: the loaded ``ReversiModel``.
        :raises RuntimeError: if there is no next-generation model and
            ``load_best_model_weight`` reports failure.
        """
        from reversi_zero.agent.model import ReversiModel
        model = ReversiModel(self.config)
        resource = self.config.resource

        candidate_dirs = get_next_generation_model_dirs(resource)
        if candidate_dirs:
            newest_dir = candidate_dirs[-1]
            logger.debug("loading latest model")
            model.load(
                os.path.join(newest_dir, resource.next_generation_model_config_filename),
                os.path.join(newest_dir, resource.next_generation_model_weight_filename),
            )
            return model

        logger.debug("loading best model")
        if not load_best_model_weight(model):
            raise RuntimeError("Best model can not loaded!")
        return model
예제 #5
0
 def load_model(self):
     """Create a training-ready ReversiModel and restore it from disk.

     Builds the training graph with ``config.resource.tensor_log_dir``, opens a
     session, then loads from ``config.resource.model_dir``.

     :return: ``(model, steps)``. ``steps`` is the value returned by
         ``model.load``, or ``0`` when that value is ``None``.
     """
     from reversi_zero.agent.model import ReversiModel

     trainee = ReversiModel(self.config)
     trainee.build_train(self.config.resource.tensor_log_dir)
     trainee.create_session()
     logger.debug("loading model")
     restored_steps = trainee.load(self.config.resource.model_dir)
     # Only None is mapped to 0; any other value (including falsy ones) is kept.
     return trainee, (0 if restored_steps is None else restored_steps)
예제 #6
0
def reload_newest_next_generation_model_if_changed(model, clear_session=False, config=None):
    """Reload the newest next-generation model if its weights have changed.

    The newest candidate is the last entry of ``get_next_generation_model_dirs``.
    A reload happens only when ``model.fetch_digest`` of its weight file is
    truthy and differs from ``model.digest``. Loading is attempted up to 5
    times, sleeping 3 seconds after each failure.

    :param reversi_zero.agent.model.ReversiModel model: model whose digest is
        compared; it is also the load target when ``config`` is None.
    :param bool clear_session: call ``K.clear_session()`` before reloading.
    :param config: when not None, a fresh ``ReversiModel(config)`` is created
        and loaded instead of ``model``.
    :return: False if no next-generation model exists or its digest is
        unchanged; otherwise the value returned by ``ReversiModel.load``.
    :raises RuntimeError: if all 5 load attempts raise.
    """
    from reversi_zero.lib.data_helper import get_next_generation_model_dirs
    from reversi_zero.agent.model import ReversiModel

    rc = model.config.resource
    dirs = get_next_generation_model_dirs(rc)
    if not dirs:
        logger.debug("No next generation model exists.")
        return False
    model_dir = dirs[-1]
    config_path = os.path.join(model_dir, rc.next_generation_model_config_filename)
    weight_path = os.path.join(model_dir, rc.next_generation_model_weight_filename)
    digest = model.fetch_digest(weight_path)
    if not digest or digest == model.digest:
        logger.debug(f"The newest model is not changed: digest={digest}")
        return False

    logger.debug(f"Loading weight from {model_dir}")
    if clear_session:
        K.clear_session()
    # Create the replacement only once a reload is really needed, and after
    # any session clearing, in case the constructor allocates backend state.
    # (Previously a `del model` sat inside the retry loop: after one failed
    # attempt, every further retry died on UnboundLocalError instead of
    # re-trying the load.)
    target = model if config is None else ReversiModel(config)
    # NOTE(review): when `config` is given, the freshly loaded model is not
    # handed back to the caller -- only the result of `load` is returned.
    for _ in range(5):
        try:
            return target.load(config_path, weight_path)
        except Exception as e:
            logger.warning(f"error in load model: #{e}")
            sleep(3)
    raise RuntimeError("Cannot Load Model!")
예제 #7
0
 def load_model(self, config_path, weight_path):
     """Return a ReversiModel built from ``self.config`` with the given files loaded.

     :param config_path: model config file passed to ``ReversiModel.load``.
     :param weight_path: weight file passed to ``ReversiModel.load``.
     :return: the ``ReversiModel`` instance; ``load``'s own return value is ignored.
     """
     fresh = ReversiModel(self.config)
     fresh.load(config_path, weight_path)
     return fresh