class ChessModel:
    def __init__(self, config: Config):
        self.config = config
        self.model = None  # type: Model
        self.digest = None
        self.api = None

    def get_pipes(self, num=1):
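        # lazily start the shared prediction API on first use and hand out one
        # request pipe per caller (ChessModelAPI is assumed to serve batched
        # predictions over these pipes)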
        if self.api is None:
            self.api = ChessModelAPI(self.config, self)
            self.api.start()
        return [self.api.get_pipe() for _ in range(num)]

    def build(self):
        mc = self.config.model
        # in_x = x = Input((18, 8, 8))
        in_x = x = Input((14, 10, 9))  # changed for Chinese chess (CC): 14 feature planes on a 10x9 board

        # (batch, channels, height, width)
        x = Conv2D(filters=mc.cnn_filter_num,
                   kernel_size=mc.cnn_first_filter_size,
                   padding="same",
                   data_format="channels_first",
                   use_bias=False,
                   kernel_regularizer=l2(mc.l2_reg),
                   name="input_conv-" + str(mc.cnn_first_filter_size) + "-" +
                   str(mc.cnn_filter_num))(x)
        x = BatchNormalization(axis=1, name="input_batchnorm")(x)
        x = Activation("relu", name="input_relu")(x)

        for i in range(mc.res_layer_num):
            x = self._build_residual_block(x, i + 1)

        res_out = x

        # for policy output
        x = Conv2D(filters=2,
                   kernel_size=1,
                   data_format="channels_first",
                   use_bias=False,
                   kernel_regularizer=l2(mc.l2_reg),
                   name="policy_conv-1-2")(res_out)
        x = BatchNormalization(axis=1, name="policy_batchnorm")(x)
        x = Activation("relu", name="policy_relu")(x)
        x = Flatten(name="policy_flatten")(x)
        # no output for 'pass'
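        # policy head: a softmax over config.n_labels possible move labels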
        policy_out = Dense(self.config.n_labels,
                           kernel_regularizer=l2(mc.l2_reg),
                           activation="softmax",
                           name="policy_out")(x)

        # for value output
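        # value head: reduces the board features to a single game-outcome
        # estimate in [-1, 1] via the final tanh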
        x = Conv2D(filters=4,
                   kernel_size=1,
                   data_format="channels_first",
                   use_bias=False,
                   kernel_regularizer=l2(mc.l2_reg),
                   name="value_conv-1-4")(res_out)
        x = BatchNormalization(axis=1, name="value_batchnorm")(x)
        x = Activation("relu", name="value_relu")(x)
        x = Flatten(name="value_flatten")(x)
        x = Dense(mc.value_fc_size,
                  kernel_regularizer=l2(mc.l2_reg),
                  activation="relu",
                  name="value_dense")(x)
        value_out = Dense(1,
                          kernel_regularizer=l2(mc.l2_reg),
                          activation="tanh",
                          name="value_out")(x)

        self.model = Model(in_x, [policy_out, value_out], name="chess_model")
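        # A typical AlphaZero-style compilation for this two-headed model
        # (a sketch only; the actual optimizer, loss weights and training loop
        # live in the project's trainer, not in this class):
        #
        #   from keras.optimizers import Adam
        #   self.model.compile(optimizer=Adam(),
        #                      loss=["categorical_crossentropy",
        #                            "mean_squared_error"])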

    def _build_residual_block(self, x, index):
        mc = self.config.model
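        # keep the block input around for the identity skip connection below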
        in_x = x
        res_name = "res" + str(index)
        x = Conv2D(filters=mc.cnn_filter_num,
                   kernel_size=mc.cnn_filter_size,
                   padding="same",
                   data_format="channels_first",
                   use_bias=False,
                   kernel_regularizer=l2(mc.l2_reg),
                   name=res_name + "_conv1-" + str(mc.cnn_filter_size) + "-" +
                   str(mc.cnn_filter_num))(x)
        x = BatchNormalization(axis=1, name=res_name + "_batchnorm1")(x)
        x = Activation("relu", name=res_name + "_relu1")(x)
        x = Conv2D(filters=mc.cnn_filter_num,
                   kernel_size=mc.cnn_filter_size,
                   padding="same",
                   data_format="channels_first",
                   use_bias=False,
                   kernel_regularizer=l2(mc.l2_reg),
                   name=res_name + "_conv2-" + str(mc.cnn_filter_size) + "-" +
                   str(mc.cnn_filter_num))(x)
        x = BatchNormalization(axis=1,
                               name="res" + str(index) + "_batchnorm2")(x)
        x = Add(name=res_name + "_add")([in_x, x])
        x = Activation("relu", name=res_name + "_relu2")(x)
        return x

    @staticmethod
    def fetch_digest(weight_path):
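        # SHA-256 hex digest of the weight file, or None if the file is missing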
        if os.path.exists(weight_path):
            m = hashlib.sha256()
            with open(weight_path, "rb") as f:
                m.update(f.read())
            return m.hexdigest()

    def load(self, config_path, weight_path):
        mc = self.config.model
        resources = self.config.resource
        if mc.distributed and config_path == resources.model_best_config_path:
            try:
                logger.debug("loading model from server")
                ftp_connection = ftplib.FTP(
                    resources.model_best_distributed_ftp_server,
                    resources.model_best_distributed_ftp_user,
                    resources.model_best_distributed_ftp_password)
                ftp_connection.cwd(
                    resources.model_best_distributed_ftp_remote_path)
                with open(config_path, 'wb') as f:
                    ftp_connection.retrbinary("RETR model_best_config.json",
                                              f.write)
                with open(weight_path, 'wb') as f:
                    ftp_connection.retrbinary("RETR model_best_weight.h5",
                                              f.write)
                ftp_connection.quit()
            except Exception as e:
                # fall back to any local copy, but record why the fetch failed
                logger.debug("failed to fetch model from server: %s" % e)
        if os.path.exists(config_path) and os.path.exists(weight_path):
            logger.debug("loading model from %s" % (config_path))
            with open(config_path, "rt") as f:
                self.model = Model.from_config(json.load(f))
            self.model.load_weights(weight_path)
            self.model._make_predict_function()
            self.digest = self.fetch_digest(weight_path)
            logger.debug("loaded model digest = %s" % (self.digest))
            return True
        else:
            logger.debug("model files does not exist at %s and %s" %
                         (config_path, weight_path))
            return False

    def save(self, config_path, weight_path):
        logger.debug("saving model to %s" % (config_path))
        with open(config_path, "wt") as f:
            json.dump(self.model.get_config(), f)
        self.model.save_weights(weight_path)
        self.digest = self.fetch_digest(weight_path)
        logger.debug("saved model digest %s" % (self.digest))

        mc = self.config.model
        resources = self.config.resource
        if mc.distributed and config_path == resources.model_best_config_path:
            try:
                logger.debug("saving model to server")
                ftp_connection = ftplib.FTP(
                    resources.model_best_distributed_ftp_server,
                    resources.model_best_distributed_ftp_user,
                    resources.model_best_distributed_ftp_password)
                ftp_connection.cwd(
                    resources.model_best_distributed_ftp_remote_path)
                with open(config_path, 'rb') as fh:
                    ftp_connection.storbinary('STOR model_best_config.json',
                                              fh)
                with open(weight_path, 'rb') as fh:
                    ftp_connection.storbinary('STOR model_best_weight.h5', fh)
                ftp_connection.quit()
            except Exception as e:
                # the locally saved model is still valid; just log the failure
                logger.debug("failed to upload model to server: %s" % e)
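
# Minimal usage sketch for the class above. `Config` is the project-specific
# configuration object assumed by the constructor; the file names mirror the
# ones used by the distributed FTP code.
#
#   model = ChessModel(Config())
#   model.build()
#   model.save("model_best_config.json", "model_best_weight.h5")
#   pipes = model.get_pipes(num=4)  # e.g. one pipe per self-play worker
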
Example #2
class ChessModel:
    def __init__(self, config: Config):
        self.config = config
        self.model = None  # type: Model
        self.graph = None
        self.digest = None
        self.api = None

    def get_pipes(self, num=1):
        if self.api is None:
            self.api = ChessModelAPI(self)
            self.api.start()
        return [self.api.get_pipe() for _ in range(num)]

    def build(self):
        mc = self.config.model
        in_x = x = Input((mc.input_stack_height, 8, 8))
        # (batch, channels, height, width)
        x = Conv2D(filters=mc.cnn_filter_num,
                   kernel_size=mc.cnn_filter_size,
                   padding="same",
                   data_format="channels_first",
                   use_bias=False,
                   kernel_initializer='glorot_normal',
                   bias_initializer='zeros',
                   kernel_regularizer=l2(mc.l2_reg),
                   input_shape=(mc.input_stack_height, 8, 8))(x)
        x = BatchNormalization(axis=1)(x)
        x = Activation("relu")(x)

        for _ in range(mc.res_layer_num):
            x = self._build_residual_block(x)

        res_out = x
        # for policy output
        x = Conv2D(filters=2,
                   kernel_size=1,
                   data_format="channels_first",
                   use_bias=False,
                   kernel_initializer='glorot_normal',
                   bias_initializer='zeros',
                   kernel_regularizer=l2(mc.l2_reg))(res_out)
        x = BatchNormalization(axis=1)(x)
        x = Activation("relu")(x)
        x = Flatten()(x)
        # no output for 'pass'
        policy_out = Dense(self.config.n_labels,
                           kernel_initializer='glorot_normal',
                           bias_initializer='zeros',
                           kernel_regularizer=l2(mc.l2_reg),
                           activation="softmax",
                           name="policy_out")(x)

        # for value output
        x = Conv2D(filters=1,
                   kernel_size=1,
                   data_format="channels_first",
                   use_bias=False,
                   kernel_initializer='glorot_normal',
                   bias_initializer='zeros',
                   kernel_regularizer=l2(mc.l2_reg))(res_out)
        x = BatchNormalization(axis=1)(x)
        x = Activation("relu")(x)
        x = Flatten()(x)
        x = Dense(mc.value_fc_size,
                  kernel_initializer='glorot_normal',
                  bias_initializer='zeros',
                  kernel_regularizer=l2(mc.l2_reg),
                  activation="relu")(x)
        value_out = Dense(1,
                          kernel_initializer='glorot_normal',
                          bias_initializer='zeros',
                          kernel_regularizer=l2(mc.l2_reg),
                          activation="tanh",
                          name="value_out")(x)

        self.model = Model(in_x, [policy_out, value_out], name="chess_model")
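        # capture the TF1 default graph so worker threads can run predictions
        # later via `with self.graph.as_default(): ...`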
        self.graph = get_default_graph()

    def _build_residual_block(self, x):
        mc = self.config.model
        in_x = x
        x = Conv2D(filters=mc.cnn_filter_num,
                   kernel_size=mc.cnn_filter_size,
                   padding="same",
                   data_format="channels_first",
                   use_bias=False,
                   kernel_initializer='glorot_normal',
                   bias_initializer='zeros',
                   kernel_regularizer=l2(mc.l2_reg))(x)
        x = BatchNormalization(axis=1)(x)
        x = Activation("relu")(x)
        x = Conv2D(filters=mc.cnn_filter_num,
                   kernel_size=mc.cnn_filter_size,
                   padding="same",
                   data_format="channels_first",
                   use_bias=False,
                   kernel_initializer='glorot_normal',
                   bias_initializer='zeros',
                   kernel_regularizer=l2(mc.l2_reg))(x)
        x = BatchNormalization(axis=1)(x)
        x = Add()([in_x, x])
        x = Activation("relu")(x)
        return x

    @staticmethod
    def fetch_digest(weight_path):
        if os.path.exists(weight_path):
            m = hashlib.sha256()
            with open(weight_path, "rb") as f:
                m.update(f.read())
            return m.hexdigest()

    def load(self, config_path, weight_path):
        if os.path.exists(config_path) and os.path.exists(weight_path):
            logger.debug(f"loading model from {config_path}")
            with open(config_path, "rt") as f:
                self.model = Model.from_config(json.load(f))
            self.model.load_weights(weight_path)
            self.graph = get_default_graph()
            # self.model._make_predict_function()
            self.digest = self.fetch_digest(weight_path)
            logger.debug(f"loaded model digest = {self.digest}")
            return True
        else:
            logger.debug(
                f"model files do not exist at {config_path} and {weight_path}")
            return False

    def save(self, config_path, weight_path):
        logger.debug(f"save model to {config_path}")
        with open(config_path, "wt") as f:
            json.dump(self.model.get_config(), f)
            self.model.save_weights(weight_path)
        self.digest = self.fetch_digest(weight_path)
        logger.debug(f"saved model digest {self.digest}")