def load_stock_model(model_dir, max_seq_len):
    """Build the stock BERT graph and hook its variables up to a checkpoint.

    Args:
        model_dir: Directory holding ``bert_config.json`` and
            ``bert_model.ckpt``.
        max_seq_len: Sequence length used for the three int32 input
            placeholders, each of shape ``(1, max_seq_len)``.

    Returns:
        A tuple ``(s_model, pl_input_ids, pl_token_type_ids, pl_mask)``.
        Note the order: the token-type placeholder precedes the mask.
    """
    from tests.ext.modeling import BertModel, BertConfig, get_assignment_map_from_checkpoint

    # Start from an empty default graph so variable scope names stay stable
    # for checkpoint loading when this runs more than once in a process.
    tf.compat.v1.reset_default_graph()

    config_path = os.path.join(model_dir, "bert_config.json")
    ckpt_path = os.path.join(model_dir, "bert_model.ckpt")

    # Placeholders are created in the order: input ids, mask, token type ids.
    batch_shape = (1, max_seq_len)
    pl_input_ids, pl_mask, pl_token_type_ids = (
        tf.compat.v1.placeholder(tf.int32, shape=batch_shape) for _ in range(3)
    )

    stock_config = BertConfig.from_json_file(config_path)
    s_model = BertModel(
        config=stock_config,
        is_training=False,
        input_ids=pl_input_ids,
        input_mask=pl_mask,
        token_type_ids=pl_token_type_ids,
        use_one_hot_embeddings=False,
    )

    # Map the graph's trainable variables onto checkpoint entries; the second
    # element of the returned pair is not needed here.
    assignment_map, _ = get_assignment_map_from_checkpoint(
        tf.compat.v1.trainable_variables(), ckpt_path)
    tf.compat.v1.train.init_from_checkpoint(ckpt_path, assignment_map)

    return s_model, pl_input_ids, pl_token_type_ids, pl_mask
def test_direct_keras_to_stock_compare(self):
    """Check that the Keras model and the stock model give matching outputs.

    A fixed sentence is tokenized, wrapped in ``[CLS]``/``[SEP]`` and padded
    with zeros up to ``max_seq_len``. The same ids / mask / token-type arrays
    are passed to ``predict_on_stock_model`` and ``predict_on_keras_model``,
    and the two results must agree within ``atol=1e-6``.
    """
    from tests.ext.tokenization import FullTokenizer

    tokenizer = FullTokenizer(
        vocab_file=os.path.join(self.bert_ckpt_dir, "vocab.txt"))

    # prepare input
    max_seq_len = 6
    input_str = "Hello, Bert!"
    input_tokens = ["[CLS]"] + tokenizer.tokenize(input_str) + ["[SEP]"]

    token_count = len(input_tokens)
    # A non-positive pad_len adds no padding (list * n is empty for n <= 0).
    pad_len = max_seq_len - token_count

    input_ids = tokenizer.convert_tokens_to_ids(input_tokens)
    input_ids = input_ids + [0] * pad_len
    input_mask = [1] * token_count + [0] * pad_len
    # Every position, real or padded, gets token type 0.
    token_type_ids = [0] * token_count + [0] * pad_len

    # Add a leading batch axis of size 1.
    input_ids = np.array([input_ids], dtype=np.int32)
    input_mask = np.array([input_mask], dtype=np.int32)
    token_type_ids = np.array([token_type_ids], dtype=np.int32)

    print(" tokens:", input_tokens)
    print(
        "input_ids:{}/{}:{}".format(len(input_tokens), max_seq_len, input_ids),
        input_ids.shape, token_type_ids)

    s_res = self.predict_on_stock_model(input_ids, input_mask, token_type_ids)
    k_res = self.predict_on_keras_model(input_ids, input_mask, token_type_ids)

    np.set_printoptions(precision=9, threshold=20, linewidth=200,
                        sign="+", floatmode="fixed")
    print("s_res", s_res.shape)
    print("k_res", k_res.shape)

    print("s_res:\n {}".format(s_res[0, :2, :10]), s_res.dtype)
    print("k_res:\n {}".format(k_res[0, :2, :10]), k_res.dtype)

    # Report the largest element-wise deviation and its flat index.
    adiff = np.abs(s_res - k_res).flatten()
    print("diff:", np.max(adiff), np.argmax(adiff))
    self.assertTrue(np.allclose(s_res, k_res, atol=1e-6))
def create_stock_bert_graph(bert_config_file, max_seq_len):
    """Create input placeholders and a stock ``BertModel`` in the current graph.

    Args:
        bert_config_file: Path to the JSON config passed to
            ``BertConfig.from_json_file``.
        max_seq_len: Sequence length of the int32 placeholders, each of shape
            ``(1, max_seq_len)``.

    Returns:
        A tuple ``(s_model, pl_input_ids, pl_mask, pl_token_type_ids)``.
    """
    from tests.ext.modeling import BertModel, BertConfig

    # Placeholders are created in the order: input ids, mask, token type ids.
    batch_shape = (1, max_seq_len)
    pl_input_ids = tf.compat.v1.placeholder(tf.int32, shape=batch_shape)
    pl_mask = tf.compat.v1.placeholder(tf.int32, shape=batch_shape)
    pl_token_type_ids = tf.compat.v1.placeholder(tf.int32, shape=batch_shape)

    stock_config = BertConfig.from_json_file(bert_config_file)
    s_model = BertModel(
        config=stock_config,
        is_training=False,
        input_ids=pl_input_ids,
        input_mask=pl_mask,
        token_type_ids=pl_token_type_ids,
        use_one_hot_embeddings=False,
    )
    return s_model, pl_input_ids, pl_mask, pl_token_type_ids
def predict_on_stock_model(self, input_ids, input_mask, token_type_ids):
    """Run the stock BERT model on one batch and return its sequence output.

    The default graph is reset, the stock model is rebuilt with int32
    placeholders of shape ``(1, input_ids.shape[-1])``, its trainable
    variables are mapped onto ``self.bert_ckpt_file``, and the sequence
    output is evaluated in a fresh session.

    Args:
        input_ids: Array of token ids; its last axis sets the sequence length.
        input_mask: Array fed to the mask placeholder.
        token_type_ids: Array fed to the token-type placeholder.

    Returns:
        The evaluated value of ``s_model.get_sequence_output()``.
    """
    from tests.ext.modeling import BertModel, BertConfig, get_assignment_map_from_checkpoint

    # Fresh graph so variable names match on repeated calls.
    tf.compat.v1.reset_default_graph()

    max_seq_len = input_ids.shape[-1]
    pl_input_ids = tf.compat.v1.placeholder(tf.int32, shape=(1, max_seq_len))
    pl_mask = tf.compat.v1.placeholder(tf.int32, shape=(1, max_seq_len))
    pl_token_type_ids = tf.compat.v1.placeholder(tf.int32, shape=(1, max_seq_len))

    bert_config = BertConfig.from_json_file(self.bert_config_file)

    s_model = BertModel(config=bert_config,
                        is_training=False,
                        input_ids=pl_input_ids,
                        input_mask=pl_mask,
                        token_type_ids=pl_token_type_ids,
                        use_one_hot_embeddings=False)

    # Only the assignment map is needed; the second element of the returned
    # pair is ignored.
    tvars = tf.compat.v1.trainable_variables()
    assignment_map, _ = get_assignment_map_from_checkpoint(
        tvars, self.bert_ckpt_file)
    tf.compat.v1.train.init_from_checkpoint(self.bert_ckpt_file, assignment_map)

    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        s_res = sess.run(
            s_model.get_sequence_output(),
            feed_dict={
                pl_input_ids: input_ids,
                pl_token_type_ids: token_type_ids,
                pl_mask: input_mask,
            })
    return s_res