示例#1
0
    def __init__(self, config: Config, model, play_config=None):
        """
        Store the configuration and create the empty search state.

        :param Config config: configuration to use
        :param model: model handed to ``ChessModelAPI`` for predictions
        :param play_config: play settings; ``config.play`` is used when this is falsy
        """
        self.config = config
        self.model = model
        self.play_config = play_config or self.config.play
        self.api = ChessModelAPI(self.config, self.model)

        # Maps each move (parsed from its UCI label) to the index of that label.
        self.move_lookup = {
            chess.Move.from_uci(label): index
            for index, label in enumerate(self.config.labels)
        }

        self.labels_n = config.n_labels

        def zero_vector():
            # One slot per label; labels_n is read at call time.
            return np.zeros((self.labels_n, ))

        # Per-key arrays, created zero-filled the first time a key is accessed.
        self.var_n = defaultdict(zero_vector)
        self.var_w = defaultdict(zero_vector)
        self.var_q = defaultdict(zero_vector)
        self.var_u = defaultdict(zero_vector)
        self.var_p = defaultdict(zero_vector)
        self.expanded = set()
        self.now_expanding = set()
        self.prediction_queue = Queue(self.play_config.prediction_queue_size)
        self.sem = asyncio.Semaphore(self.play_config.parallel_search_num)

        self.moves = []
        self.loop = asyncio.get_event_loop()
        self.running_simulation_num = 0

        self.thinking_history = {}  # for fun
示例#2
0
    def get_pipes(self, num=1):
        """
        Creates a list of pipes on which observations of the game state will be listened for. Whenever
        an observation comes in, returns policy and value network predictions on that pipe.

        The API object is created and started on the first call and reused afterwards.

        :param int num: number of pipes to create
        :return list(Connection): a list of all connections to the pipes that were created
        """
        api = self.api
        if api is None:
            # Lazily create the API; it is stored on self before being started.
            api = self.api = ChessModelAPI(self)
            api.start()
        pipes = []
        for _ in range(num):
            pipes.append(api.create_pipe())
        return pipes
示例#3
0
 def get_api_queue(self):
     """
     Return the prediction queue, creating it on first use.

     On the first call a ``ChessModelAPI`` is built from ``self.config`` and ``self`` and its
     ``prediction_queue`` is cached on ``self.queue``; later calls return the cached queue.
     """
     if self.queue is not None:
         return self.queue
     self.queue = ChessModelAPI(self.config, self).prediction_queue
     return self.queue
示例#4
0
class ChessModel:
    """
    The model which can be trained to take observations of a game of chess and return value and policy
    predictions.

    Attributes:
        :ivar Config config: configuration to use
        :ivar Model model: the Keras model to use for predictions
        :ivar digest: basically just a hash of the file containing the weights being used by this model
        :ivar ChessModelAPI api: the api to use to listen for and then return this models predictions (on a pipe).
    """
    def __init__(self, config: Config):
        self.config = config
        self.model = None  # type: Model
        self.digest = None
        self.api = None

    def get_pipes(self, num=1):
        """
        Creates a list of pipes on which observations of the game state will be listened for. Whenever
        an observation comes in, returns policy and value network predictions on that pipe.

        :param int num: number of pipes to create
        :return str(Connection): a list of all connections to the pipes that were created
        """
        if self.api is None:
            self.api = ChessModelAPI(self)
            self.api.start()
        return [self.api.create_pipe() for _ in range(num)]

    def build(self):
        """
        Builds the full Keras model and stores it in self.model.

        Layout: an input convolution, ``mc.res_layer_num`` residual blocks, then two heads that
        both read the residual output: a softmax policy head with ``config.n_labels`` outputs
        and a tanh value head with a single output.
        """
        mc = self.config.model
        # Input is declared as (18, 8, 8); all convolutions use channels_first.
        in_x = x = Input((18, 8, 8))

        # (batch, channels, height, width)
        x = Conv2D(filters=mc.cnn_filter_num,
                   kernel_size=mc.cnn_first_filter_size,
                   padding="same",
                   data_format="channels_first",
                   use_bias=False,
                   kernel_regularizer=l2(mc.l2_reg),
                   name="input_conv-" + str(mc.cnn_first_filter_size) + "-" +
                   str(mc.cnn_filter_num))(x)
        # axis=1 is the channel axis for channels_first data.
        x = BatchNormalization(axis=1, name="input_batchnorm")(x)
        x = Activation("relu", name="input_relu")(x)

        # Residual blocks are numbered from 1 in their layer names.
        for i in range(mc.res_layer_num):
            x = self._build_residual_block(x, i + 1)

        # Both heads branch off this tensor.
        res_out = x

        # for policy output
        x = Conv2D(filters=2,
                   kernel_size=1,
                   data_format="channels_first",
                   use_bias=False,
                   kernel_regularizer=l2(mc.l2_reg),
                   name="policy_conv-1-2")(res_out)
        x = BatchNormalization(axis=1, name="policy_batchnorm")(x)
        x = Activation("relu", name="policy_relu")(x)
        x = Flatten(name="policy_flatten")(x)
        # no output for 'pass'
        policy_out = Dense(self.config.n_labels,
                           kernel_regularizer=l2(mc.l2_reg),
                           activation="softmax",
                           name="policy_out")(x)

        # for value output
        x = Conv2D(filters=4,
                   kernel_size=1,
                   data_format="channels_first",
                   use_bias=False,
                   kernel_regularizer=l2(mc.l2_reg),
                   name="value_conv-1-4")(res_out)
        x = BatchNormalization(axis=1, name="value_batchnorm")(x)
        x = Activation("relu", name="value_relu")(x)
        x = Flatten(name="value_flatten")(x)
        x = Dense(mc.value_fc_size,
                  kernel_regularizer=l2(mc.l2_reg),
                  activation="relu",
                  name="value_dense")(x)
        # tanh keeps the value output in [-1, 1].
        value_out = Dense(1,
                          kernel_regularizer=l2(mc.l2_reg),
                          activation="tanh",
                          name="value_out")(x)

        # Output order is [policy, value].
        self.model = Model(in_x, [policy_out, value_out], name="chess_model")

    def _build_residual_block(self, x, index):
        """
        Append one residual block: conv, batchnorm, relu, conv, batchnorm, skip-add, relu.

        :param x: input tensor of the block; it is also the skip connection
        :param index: block number, used only to build the layer names
        :return: output tensor of the block
        """
        mc = self.config.model
        prefix = "res" + str(index)
        size_tag = str(mc.cnn_filter_size) + "-" + str(mc.cnn_filter_num)

        def conv(tensor, label):
            # Both convolutions share every setting except their layer name.
            layer = Conv2D(filters=mc.cnn_filter_num,
                           kernel_size=mc.cnn_filter_size,
                           padding="same",
                           data_format="channels_first",
                           use_bias=False,
                           kernel_regularizer=l2(mc.l2_reg),
                           name=prefix + label + size_tag)
            return layer(tensor)

        out = conv(x, "_conv1-")
        out = BatchNormalization(axis=1, name=prefix + "_batchnorm1")(out)
        out = Activation("relu", name=prefix + "_relu1")(out)
        out = conv(out, "_conv2-")
        out = BatchNormalization(axis=1, name=prefix + "_batchnorm2")(out)
        out = Add(name=prefix + "_add")([x, out])
        return Activation("relu", name=prefix + "_relu2")(out)

    @staticmethod
    def fetch_digest(weight_path):
        if os.path.exists(weight_path):
            m = hashlib.sha256()
            with open(weight_path, "rb") as f:
                m.update(f.read())
            return m.hexdigest()

    def load(self, config_path, weight_path):
        """
        Load the model architecture and weights from disk.

        When the model is configured as distributed and ``config_path`` is the best-model config
        path, both files are first downloaded from the configured FTP server. The download is
        best effort: a failure is logged and loading continues with whatever is on disk.

        :param str config_path: path to the file containing the entire configuration
        :param str weight_path: path to the file containing the model weights
        :return: true iff successful in loading
        """
        mc = self.config.model
        resources = self.config.resource
        if mc.distributed and config_path == resources.model_best_config_path:
            try:
                logger.debug("loading model from server")
                # The context managers close the connection and flush/close the local files
                # even when a transfer fails part-way.
                with ftplib.FTP(
                        resources.model_best_distributed_ftp_server,
                        resources.model_best_distributed_ftp_user,
                        resources.model_best_distributed_ftp_password
                ) as ftp_connection:
                    ftp_connection.cwd(
                        resources.model_best_distributed_ftp_remote_path)
                    with open(config_path, 'wb') as fh:
                        ftp_connection.retrbinary(
                            "RETR model_best_config.json", fh.write)
                    with open(weight_path, 'wb') as fh:
                        ftp_connection.retrbinary(
                            "RETR model_best_weight.h5", fh.write)
            except Exception:
                # Still best effort, but no longer silent and no longer swallowing
                # KeyboardInterrupt/SystemExit.
                logger.warning("could not download model from server",
                               exc_info=True)
        if os.path.exists(config_path) and os.path.exists(weight_path):
            logger.debug(f"loading model from {config_path}")
            with open(config_path, "rt") as f:
                self.model = Model.from_config(json.load(f))
            self.model.load_weights(weight_path)
            self.model._make_predict_function()
            self.digest = self.fetch_digest(weight_path)
            logger.debug(f"loaded model digest = {self.digest}")
            return True
        else:
            logger.debug(
                f"model files does not exist at {config_path} and {weight_path}"
            )
            return False

    def save(self, config_path, weight_path):
        """
        Write the model configuration and weights to disk and refresh ``self.digest``.

        When the model is configured as distributed and ``config_path`` is the best-model config
        path, both files are then uploaded to the configured FTP server. The upload is best
        effort: a failure is logged and does not propagate.

        :param str config_path: path to save the entire configuration to
        :param str weight_path: path to save the model weights to
        """
        logger.debug(f"save model to {config_path}")
        with open(config_path, "wt") as f:
            json.dump(self.model.get_config(), f)
            self.model.save_weights(weight_path)
        self.digest = self.fetch_digest(weight_path)
        logger.debug(f"saved model digest {self.digest}")

        mc = self.config.model
        resources = self.config.resource
        if mc.distributed and config_path == resources.model_best_config_path:
            try:
                logger.debug("saving model to server")
                # The context managers close the connection and the local files even when
                # an upload fails part-way.
                with ftplib.FTP(
                        resources.model_best_distributed_ftp_server,
                        resources.model_best_distributed_ftp_user,
                        resources.model_best_distributed_ftp_password
                ) as ftp_connection:
                    ftp_connection.cwd(
                        resources.model_best_distributed_ftp_remote_path)
                    with open(config_path, 'rb') as fh:
                        ftp_connection.storbinary(
                            'STOR model_best_config.json', fh)
                    with open(weight_path, 'rb') as fh:
                        ftp_connection.storbinary(
                            'STOR model_best_weight.h5', fh)
            except Exception:
                # Still best effort, but no longer silent and no longer swallowing
                # KeyboardInterrupt/SystemExit.
                logger.warning("could not upload model to server",
                               exc_info=True)
示例#5
0
class ChessPlayer:
    def __init__(self, config: Config, model, play_config=None):
        """
        Store the configuration and create the empty search state.

        :param Config config: configuration to use
        :param model: model handed to ``ChessModelAPI`` for predictions
        :param play_config: play settings; ``config.play`` is used when this is falsy
        """
        self.config = config
        self.model = model
        self.play_config = play_config or self.config.play
        self.api = ChessModelAPI(self.config, self.model)

        # Maps each move (parsed from its UCI label) to the index of that label.
        self.move_lookup = {
            chess.Move.from_uci(label): index
            for index, label in enumerate(self.config.labels)
        }

        self.labels_n = config.n_labels

        def zero_vector():
            # One slot per label; labels_n is read at call time.
            return np.zeros((self.labels_n, ))

        # Per-key arrays, created zero-filled the first time a key is accessed.
        self.var_n = defaultdict(zero_vector)
        self.var_w = defaultdict(zero_vector)
        self.var_q = defaultdict(zero_vector)
        self.var_u = defaultdict(zero_vector)
        self.var_p = defaultdict(zero_vector)
        self.expanded = set()
        self.now_expanding = set()
        self.prediction_queue = Queue(self.play_config.prediction_queue_size)
        self.sem = asyncio.Semaphore(self.play_config.parallel_search_num)

        self.moves = []
        self.loop = asyncio.get_event_loop()
        self.running_simulation_num = 0

        self.thinking_history = {}  # for fun

    def action(self, board):
        """
        Search from the given position and choose the move to play.

        Repeats search + policy sampling up to ``play_config.thinking_loop`` times, stopping early
        once the sampled action matches the best-by-value action or while the turn is still below
        ``change_tau_turn``.

        :param board: position to move from (passed to ``ChessEnv().update``)
        :return: the chosen entry of ``config.labels``, or None to resign
        """
        pc = self.play_config
        env = ChessEnv().update(board)
        key = self.counter_key(env)

        for tl in range(pc.thinking_loop):
            if tl > 0 and pc.logging_thinking:
                # Reports the candidates of the previous iteration.
                logger.debug(
                    f"continue thinking: policy move=({action % 8}, {action // 8}), "
                    f"value move=({action_by_value % 8}, {action_by_value // 8})"
                )
            self.search_moves(board)
            policy = self.calc_policy(board)
            action = int(np.random.choice(range(self.labels_n), p=policy))
            # The +100 bonus restricts the argmax to actions that were visited at least once.
            visited_bonus = (self.var_n[key] > 0) * 100
            action_by_value = int(np.argmax(self.var_q[key] + visited_bonus))
            if action == action_by_value or env.turn < pc.change_tau_turn:
                break

        # this is for play_gui, not necessary when training.
        self.thinking_history[env.observation] = HistoryItem(
            action, policy, list(self.var_q[key]), list(self.var_n[key]))

        should_resign = (pc.resign_threshold is not None
                         and env.score_current() <= pc.resign_threshold
                         and pc.min_resign_turn < env.turn)
        if should_resign:
            return None  # means resign

        self.moves.append([env.observation, list(policy)])
        return self.config.labels[action]

    def ask_thought_about(self, board) -> HistoryItem:
        """Return the entry stored under ``board`` in ``thinking_history``, or None if absent."""
        history = self.thinking_history
        return history.get(board)

    @profile
    def search_moves(self, board):
        """
        Run ``play_config.simulation_num_per_move`` searches from ``board`` to completion.

        The searches and one prediction worker are gathered and driven on ``self.loop`` until
        all of them have finished.
        """
        start = time.time()  # only read by the commented-out timing log below
        self.running_simulation_num = 0

        simulations = [
            self.start_search_my_move(board)
            for _ in range(self.play_config.simulation_num_per_move)
        ]
        # The worker is scheduled last, after every search coroutine.
        self.loop.run_until_complete(
            asyncio.gather(*simulations, self.prediction_worker()))
        #logger.debug(f"Search time per move: {time.time()-start}")
        # uncomment to see profile result per move
        # raise

    async def start_search_my_move(self, board):
        """
        Run one search from ``board`` while holding the parallel-search semaphore.

        ``running_simulation_num`` is incremented on entry and always decremented on exit, even
        if the search raises, so ``prediction_worker`` cannot be left waiting on a stale count.

        :param board: position to search from (passed to ``ChessEnv().update``)
        :return: the leaf value returned by ``search_my_move``
        """
        self.running_simulation_num += 1
        try:
            # ``async with`` replaces ``with await lock``, which was removed in Python 3.9.
            async with self.sem:  # reduce parallel search number
                env = ChessEnv().update(board)
                return await self.search_my_move(env, is_root_node=True)
        finally:
            self.running_simulation_num -= 1

    async def search_my_move(self, env: ChessEnv, is_root_node=False):
        """
        Run one search step from ``env`` and return the resulting leaf value.

        Q, V is value for this Player(always white).
        P is value for the player of next_player (black or white)

        Finished games return 1 / -1 / 0 for a white win / black win / anything else. An
        unexpanded position is expanded and its value returned from white's point of view.
        Otherwise an action is selected, applied to ``env`` (which is mutated), and the search
        recurses; on the way back N, W and Q of the chosen action are updated.

        :param env: environment at the node to search from; stepped in place
        :param is_root_node: forwarded to ``select_action_q_and_u`` for this node only
        :return: leaf value from white's point of view
        """
        if env.done:
            if env.winner == Winner.white:
                return 1
            elif env.winner == Winner.black:
                return -1
            else:
                return 0

        key = self.counter_key(env)

        # Another coroutine is expanding this node: poll until it has finished.
        # NOTE(review): reads self.config.play, not self.play_config - confirm this is intended.
        while key in self.now_expanding:
            await asyncio.sleep(self.config.play.wait_for_expanding_sleep_sec)

        # is leaf?
        if key not in self.expanded:  # reach leaf node
            leaf_v = await self.expand_and_evaluate(env)
            if env.board.turn == chess.WHITE:
                return leaf_v  # Value for white
            else:
                return -leaf_v  # Value for white == -Value for black

        action_t = self.select_action_q_and_u(env, is_root_node)

        _, _ = env.step(self.config.labels[action_t])

        # Apply the virtual loss before awaiting so concurrent searches see this action as
        # visited and losing while it is still being explored.
        virtual_loss = self.config.play.virtual_loss
        self.var_n[key][action_t] += virtual_loss
        self.var_w[key][action_t] -= virtual_loss

        leaf_v = await self.search_my_move(env)  # next move

        # on returning search path
        # update: N, W, Q (undo the virtual loss, then count the real visit and value)
        n = self.var_n[key][
            action_t] = self.var_n[key][action_t] - virtual_loss + 1
        w = self.var_w[key][
            action_t] = self.var_w[key][action_t] + virtual_loss + leaf_v
        self.var_q[key][action_t] = w / n

        return leaf_v

    @profile
    async def expand_and_evaluate(self, env):
        """Expand a new leaf.

        Marks the node as being expanded, requests a prediction for it, stores the predicted
        policy in ``var_p`` and records the node as expanded.

        :param ChessEnv env:
        :return: leaf_v as a float
        """
        key = self.counter_key(env)
        self.now_expanding.add(key)

        black_ary, white_ary = env.black_and_white_plane()
        # The plane of the side to move goes first.
        if env.board.turn == chess.BLACK:
            planes = [black_ary, white_ary]
        else:
            planes = [white_ary, black_ary]

        pending = await self.predict(np.array(planes))  # type: Future
        leaf_p, leaf_v = await pending

        self.var_p[key] = leaf_p  # P is value for next_player (black or white)
        self.expanded.add(key)
        self.now_expanding.remove(key)
        return float(leaf_v)

    async def prediction_worker(self):
        """For better performance, queue prediction requests and predict them together in this worker.

        Speeds things up from about 45sec to 15sec, for example.

        The loop runs while any simulation is still running or the start-up margin has not been
        used up. Each pass drains everything currently queued, runs one batched prediction and
        resolves the future of every drained item.
        :return:
        """
        q = self.prediction_queue
        margin = 10  # avoid finishing before other searches starting.
        while self.running_simulation_num > 0 or margin > 0:
            if q.empty():
                # The margin is only consumed on idle passes.
                if margin > 0:
                    margin -= 1
                await asyncio.sleep(
                    self.config.play.prediction_worker_sleep_sec)
                continue
            # Take exactly the items that are queued right now.
            item_list = [q.get_nowait()
                         for _ in range(q.qsize())]  # type: list[QueueItem]
            #logger.debug(f"predicting {len(item_list)} items")
            data = np.array([x.state for x in item_list])
            policy_ary, value_ary = self.api.predict(data)
            # Results are matched to requests by position.
            for p, v, item in zip(policy_ary, value_ary, item_list):
                item.future.set_result((p, v))

    async def predict(self, x):
        """
        Queue ``x`` for prediction.

        :param x: state to predict on; wrapped in a ``QueueItem`` together with a new future
        :return: the future on which the queued item's result will be set
        """
        pending = self.loop.create_future()
        await self.prediction_queue.put(QueueItem(x, pending))
        return pending

    def finish_game(self, z):
        """
        Append the final game result to every move recorded in ``self.moves``.

        :param z: win=1, lose=-1, draw=0
        :return:
        """
        for recorded_move in self.moves:
            recorded_move.append(z)

    def calc_policy(self, board):
        """calc π(a|s0)

        Before ``change_tau_turn`` the policy is the visit counts normalised to sum to one
        (tau = 1); from then on it is one-hot on the most visited action (tau = 0).

        :return: array with one entry per label
        """
        env = ChessEnv().update(board)
        key = self.counter_key(env)
        if env.turn < self.play_config.change_tau_turn:
            # tau = 1; the epsilon guards against dividing by zero.
            total_visits = np.sum(self.var_n[key]) + 1e-8
            return self.var_n[key] / total_visits

        # tau = 0
        one_hot = np.zeros(self.labels_n)
        one_hot[np.argmax(self.var_n[key])] = 1
        return one_hot

    @staticmethod
    def counter_key(env: ChessEnv):
        """Build the ``CounterKey`` for ``env`` from ``env.replace_tags()`` and the side to move."""
        replaced = env.replace_tags()
        side_to_move = env.board.turn
        return CounterKey(replaced, side_to_move)

    def select_action_q_and_u(self, env, is_root_node):
        """
        Pick the legal action with the highest Q + U score at the node of ``env``.

        :param env: environment whose board supplies the legal moves and the side to move
        :param is_root_node: when true, Dirichlet noise is mixed into the prior
        :return: index of the selected action as an int
        """
        key = self.counter_key(env)
        # Original author's note: the next two lines are the bottleneck.
        legal_indices = [self.move_lookup[mov] for mov in env.board.legal_moves]
        legal_mask = np.zeros(len(self.config.labels))
        legal_mask[legal_indices] = 1

        # SQRT of sum(N(s, b); for all b), floored at 1 to avoid u=0 if N is all 0
        sqrt_visits = max(np.sqrt(np.sum(self.var_n[key])), 1)
        prior = self.var_p[key]

        if is_root_node:  # Is it correct?? -> (1-e)p + e*Dir(0.03)
            pc = self.play_config
            noise = np.random.dirichlet([pc.dirichlet_alpha] * self.labels_n)
            prior = (1 - pc.noise_eps) * prior + pc.noise_eps * noise

        exploration = (self.play_config.c_puct * prior * sqrt_visits /
                       (1 + self.var_n[key]))
        # When enemy's selecting action, flip Q-Value.
        if env.board.turn == chess.WHITE:
            q_values = self.var_q[key]
        else:
            q_values = -self.var_q[key]

        # The +1000 offset keeps every legal score above the zeros of the masked-out moves.
        scores = (q_values + exploration + 1000) * legal_mask
        return int(np.argmax(scores))
 def get_pipes(self, num=1):
     """
     Return ``num`` pipes from the prediction API.

     The API object is created from ``self.config`` and ``self`` and started on the first
     call, then reused afterwards.

     :param int num: number of pipes to request
     :return: list with the result of each ``get_pipe()`` call
     """
     api = self.api
     if api is None:
         # Lazily create the API; it is stored on self before being started.
         api = self.api = ChessModelAPI(self.config, self)
         api.start()
     pipes = []
     for _ in range(num):
         pipes.append(api.get_pipe())
     return pipes
class ChessModel:
    def __init__(self, config: Config):
        self.config = config
        self.model = None  # type: Model
        self.digest = None
        self.api = None

    def get_pipes(self, num=1):
        if self.api is None:
            self.api = ChessModelAPI(self.config, self)
            self.api.start()
        return [self.api.get_pipe() for _ in range(num)]

    def build(self):
        """
        Builds the full Keras model and stores it in self.model.

        Layout: an input convolution, ``mc.res_layer_num`` residual blocks, then two heads that
        both read the residual output: a softmax policy head with ``config.n_labels`` outputs
        and a tanh value head with a single output.
        """
        mc = self.config.model
        # in_x = x = Input((18, 8, 8))
        # Input is declared as (14, 10, 9) here instead of the (18, 8, 8) above.
        in_x = x = Input((14, 10, 9))  # change to CC

        # (batch, channels, height, width)
        x = Conv2D(filters=mc.cnn_filter_num,
                   kernel_size=mc.cnn_first_filter_size,
                   padding="same",
                   data_format="channels_first",
                   use_bias=False,
                   kernel_regularizer=l2(mc.l2_reg),
                   name="input_conv-" + str(mc.cnn_first_filter_size) + "-" +
                   str(mc.cnn_filter_num))(x)
        # axis=1 is the channel axis for channels_first data.
        x = BatchNormalization(axis=1, name="input_batchnorm")(x)
        x = Activation("relu", name="input_relu")(x)

        # Residual blocks are numbered from 1 in their layer names.
        for i in range(mc.res_layer_num):
            x = self._build_residual_block(x, i + 1)

        # Both heads branch off this tensor.
        res_out = x

        # for policy output
        x = Conv2D(filters=2,
                   kernel_size=1,
                   data_format="channels_first",
                   use_bias=False,
                   kernel_regularizer=l2(mc.l2_reg),
                   name="policy_conv-1-2")(res_out)
        x = BatchNormalization(axis=1, name="policy_batchnorm")(x)
        x = Activation("relu", name="policy_relu")(x)
        x = Flatten(name="policy_flatten")(x)
        # no output for 'pass'
        policy_out = Dense(self.config.n_labels,
                           kernel_regularizer=l2(mc.l2_reg),
                           activation="softmax",
                           name="policy_out")(x)

        # for value output
        x = Conv2D(filters=4,
                   kernel_size=1,
                   data_format="channels_first",
                   use_bias=False,
                   kernel_regularizer=l2(mc.l2_reg),
                   name="value_conv-1-4")(res_out)
        x = BatchNormalization(axis=1, name="value_batchnorm")(x)
        x = Activation("relu", name="value_relu")(x)
        x = Flatten(name="value_flatten")(x)
        x = Dense(mc.value_fc_size,
                  kernel_regularizer=l2(mc.l2_reg),
                  activation="relu",
                  name="value_dense")(x)
        # tanh keeps the value output in [-1, 1].
        value_out = Dense(1,
                          kernel_regularizer=l2(mc.l2_reg),
                          activation="tanh",
                          name="value_out")(x)

        # Output order is [policy, value].
        self.model = Model(in_x, [policy_out, value_out], name="chess_model")

    def _build_residual_block(self, x, index):
        """
        Append one residual block: conv, batchnorm, relu, conv, batchnorm, skip-add, relu.

        :param x: input tensor of the block; it is also the skip connection
        :param index: block number, used only to build the layer names
        :return: output tensor of the block
        """
        mc = self.config.model
        prefix = "res" + str(index)
        size_tag = str(mc.cnn_filter_size) + "-" + str(mc.cnn_filter_num)

        def conv(tensor, label):
            # Both convolutions share every setting except their layer name.
            layer = Conv2D(filters=mc.cnn_filter_num,
                           kernel_size=mc.cnn_filter_size,
                           padding="same",
                           data_format="channels_first",
                           use_bias=False,
                           kernel_regularizer=l2(mc.l2_reg),
                           name=prefix + label + size_tag)
            return layer(tensor)

        out = conv(x, "_conv1-")
        out = BatchNormalization(axis=1, name=prefix + "_batchnorm1")(out)
        out = Activation("relu", name=prefix + "_relu1")(out)
        out = conv(out, "_conv2-")
        out = BatchNormalization(axis=1, name=prefix + "_batchnorm2")(out)
        out = Add(name=prefix + "_add")([x, out])
        return Activation("relu", name=prefix + "_relu2")(out)

    @staticmethod
    def fetch_digest(weight_path):
        if os.path.exists(weight_path):
            m = hashlib.sha256()
            with open(weight_path, "rb") as f:
                m.update(f.read())
            return m.hexdigest()

    def load(self, config_path, weight_path):
        """
        Load the model architecture and weights from disk.

        When the model is configured as distributed and ``config_path`` is the best-model config
        path, both files are first downloaded from the configured FTP server. The download is
        best effort: a failure is logged and loading continues with whatever is on disk.

        :param str config_path: path to the file containing the entire configuration
        :param str weight_path: path to the file containing the model weights
        :return: True iff the model was loaded
        """
        mc = self.config.model
        resources = self.config.resource
        if mc.distributed and config_path == resources.model_best_config_path:
            try:
                logger.debug("loading model from server")
                # The context managers close the connection and flush/close the local files
                # even when a transfer fails part-way.
                with ftplib.FTP(
                        resources.model_best_distributed_ftp_server,
                        resources.model_best_distributed_ftp_user,
                        resources.model_best_distributed_ftp_password
                ) as ftp_connection:
                    ftp_connection.cwd(
                        resources.model_best_distributed_ftp_remote_path)
                    with open(config_path, 'wb') as fh:
                        ftp_connection.retrbinary(
                            "RETR model_best_config.json", fh.write)
                    with open(weight_path, 'wb') as fh:
                        ftp_connection.retrbinary(
                            "RETR model_best_weight.h5", fh.write)
            except Exception:
                # Still best effort, but no longer silent and no longer swallowing
                # KeyboardInterrupt/SystemExit.
                logger.warning("could not download model from server",
                               exc_info=True)
        if os.path.exists(config_path) and os.path.exists(weight_path):
            logger.debug("loading model from %s" % (config_path))
            with open(config_path, "rt") as f:
                self.model = Model.from_config(json.load(f))
            self.model.load_weights(weight_path)
            self.model._make_predict_function()
            self.digest = self.fetch_digest(weight_path)
            logger.debug("loaded model digest = %s" % (self.digest))
            return True
        else:
            logger.debug("model files does not exist at %s and %s" %
                         (config_path, weight_path))
            return False

    def save(self, config_path, weight_path):
        """
        Write the model configuration and weights to disk and refresh ``self.digest``.

        When the model is configured as distributed and ``config_path`` is the best-model config
        path, both files are then uploaded to the configured FTP server. The upload is best
        effort: a failure is logged and does not propagate.

        :param str config_path: path to save the entire configuration to
        :param str weight_path: path to save the model weights to
        """
        # The leftover print('debug...') tracing was removed; progress goes through the logger.
        logger.debug("saving model to %s" % (config_path))
        with open(config_path, "wt") as f:
            json.dump(self.model.get_config(), f)
            self.model.save_weights(weight_path)
        self.digest = self.fetch_digest(weight_path)
        logger.debug("saved model digest %s" % (self.digest))

        mc = self.config.model
        resources = self.config.resource
        if mc.distributed and config_path == resources.model_best_config_path:
            try:
                logger.debug("saving model to server")
                # The context managers close the connection and the local files even when
                # an upload fails part-way.
                with ftplib.FTP(
                        resources.model_best_distributed_ftp_server,
                        resources.model_best_distributed_ftp_user,
                        resources.model_best_distributed_ftp_password
                ) as ftp_connection:
                    ftp_connection.cwd(
                        resources.model_best_distributed_ftp_remote_path)
                    with open(config_path, 'rb') as fh:
                        ftp_connection.storbinary(
                            'STOR model_best_config.json', fh)
                    with open(weight_path, 'rb') as fh:
                        ftp_connection.storbinary(
                            'STOR model_best_weight.h5', fh)
            except Exception:
                # Still best effort, but no longer silent and no longer swallowing
                # KeyboardInterrupt/SystemExit.
                logger.warning("could not upload model to server",
                               exc_info=True)
class ChessModel:
    def __init__(self, config: Config):
        self.config = config
        self.model = None  # type: Model
        self.graph = None
        self.digest = None
        self.api = None

    def get_pipes(self, num=1):
        if self.api is None:
            self.api = ChessModelAPI(self)
            self.api.start()
        return [self.api.get_pipe() for _ in range(num)]

    def build(self):
        """
        Builds the full Keras model, stores it in self.model and records the default graph.

        Layout: an input convolution, ``mc.res_layer_num`` residual blocks, then two heads that
        both read the residual output: a softmax policy head with ``config.n_labels`` outputs
        and a tanh value head with a single output. Only the two output layers are named.
        """
        mc = self.config.model
        # The number of input planes comes from the config; the board is 8x8.
        in_x = x = Input((mc.input_stack_height, 8, 8))
        # (batch, channels, height, width)
        x = Conv2D(filters=mc.cnn_filter_num,
                   kernel_size=mc.cnn_filter_size,
                   padding="same",
                   data_format="channels_first",
                   use_bias=False,
                   kernel_initializer='glorot_normal',
                   bias_initializer='zeros',
                   kernel_regularizer=l2(mc.l2_reg),
                   input_shape=(mc.input_stack_height, 8, 8))(x)
        # axis=1 is the channel axis for channels_first data.
        x = BatchNormalization(axis=1)(x)
        x = Activation("relu")(x)

        for _ in range(mc.res_layer_num):
            x = self._build_residual_block(x)

        # Both heads branch off this tensor.
        res_out = x
        # for policy output
        x = Conv2D(filters=2,
                   kernel_size=1,
                   data_format="channels_first",
                   use_bias=False,
                   kernel_initializer='glorot_normal',
                   bias_initializer='zeros',
                   kernel_regularizer=l2(mc.l2_reg))(res_out)
        x = BatchNormalization(axis=1)(x)
        x = Activation("relu")(x)
        x = Flatten()(x)
        # no output for 'pass'
        policy_out = Dense(self.config.n_labels,
                           kernel_initializer='glorot_normal',
                           bias_initializer='zeros',
                           kernel_regularizer=l2(mc.l2_reg),
                           activation="softmax",
                           name="policy_out")(x)

        # for value output
        x = Conv2D(filters=1,
                   kernel_size=1,
                   data_format="channels_first",
                   use_bias=False,
                   kernel_initializer='glorot_normal',
                   bias_initializer='zeros',
                   kernel_regularizer=l2(mc.l2_reg))(res_out)
        x = BatchNormalization(axis=1)(x)
        x = Activation("relu")(x)
        x = Flatten()(x)
        x = Dense(mc.value_fc_size,
                  kernel_initializer='glorot_normal',
                  bias_initializer='zeros',
                  kernel_regularizer=l2(mc.l2_reg),
                  activation="relu")(x)
        # tanh keeps the value output in [-1, 1].
        value_out = Dense(1,
                          kernel_initializer='glorot_normal',
                          bias_initializer='zeros',
                          kernel_regularizer=l2(mc.l2_reg),
                          activation="tanh",
                          name="value_out")(x)

        # Output order is [policy, value].
        self.model = Model(in_x, [policy_out, value_out], name="chess_model")
        # Remember the graph that was the default when the model was built.
        self.graph = get_default_graph()

    def _build_residual_block(self, x):
        """
        Append one residual block: conv, batchnorm, relu, conv, batchnorm, skip-add, relu.

        :param x: input tensor of the block; it is also the skip connection
        :return: output tensor of the block
        """
        mc = self.config.model

        def conv_bn(tensor):
            # Both halves of the block use an identically configured conv + batchnorm pair.
            tensor = Conv2D(filters=mc.cnn_filter_num,
                            kernel_size=mc.cnn_filter_size,
                            padding="same",
                            data_format="channels_first",
                            use_bias=False,
                            kernel_initializer='glorot_normal',
                            bias_initializer='zeros',
                            kernel_regularizer=l2(mc.l2_reg))(tensor)
            return BatchNormalization(axis=1)(tensor)

        out = conv_bn(x)
        out = Activation("relu")(out)
        out = conv_bn(out)
        out = Add()([x, out])
        return Activation("relu")(out)

    @staticmethod
    def fetch_digest(weight_path):
        if os.path.exists(weight_path):
            m = hashlib.sha256()
            with open(weight_path, "rb") as f:
                m.update(f.read())
            return m.hexdigest()

    def load(self, config_path, weight_path):
        """
        Load the model architecture and weights from disk.

        On success the default graph and the digest of the weight file are recorded as well.

        :param str config_path: path to the file containing the model configuration
        :param str weight_path: path to the file containing the model weights
        :return: True if both files existed and the model was loaded, False otherwise
        """
        files_present = (os.path.exists(config_path)
                         and os.path.exists(weight_path))
        if not files_present:
            logger.debug(
                f"model files do not exist at {config_path} and {weight_path}")
            return False

        logger.debug(f"loading model from {config_path}")
        with open(config_path, "rt") as config_file:
            model_config = json.load(config_file)
        self.model = Model.from_config(model_config)
        self.model.load_weights(weight_path)
        self.graph = get_default_graph()
        # self.model._make_predict_function()
        self.digest = self.fetch_digest(weight_path)
        logger.debug(f"loaded model digest = {self.digest}")
        return True

    def save(self, config_path, weight_path):
        """
        Write the model configuration and weights to disk and refresh ``self.digest``.

        :param str config_path: path to save the model configuration to (as JSON)
        :param str weight_path: path to save the model weights to
        """
        logger.debug(f"save model to {config_path}")
        model = self.model
        with open(config_path, "wt") as config_file:
            json.dump(model.get_config(), config_file)
            # The weights are written while the config file is still open.
            model.save_weights(weight_path)
        digest = self.fetch_digest(weight_path)
        self.digest = digest
        logger.debug(f"saved model digest {self.digest}")