    def load_stock_model(model_dir, max_seq_len):
        from tests.ext.modeling import BertModel, BertConfig, get_assignment_map_from_checkpoint

        # Reset the graph so variable names stay stable for checkpoint
        # loading when this helper runs more than once.
        tf.compat.v1.reset_default_graph()

        bert_config_file = os.path.join(model_dir, "bert_config.json")
        bert_ckpt_file = os.path.join(model_dir, "bert_model.ckpt")

        pl_input_ids = tf.compat.v1.placeholder(tf.int32,
                                                shape=(1, max_seq_len))
        pl_mask = tf.compat.v1.placeholder(tf.int32, shape=(1, max_seq_len))
        pl_token_type_ids = tf.compat.v1.placeholder(tf.int32,
                                                     shape=(1, max_seq_len))

        bert_config = BertConfig.from_json_file(bert_config_file)

        s_model = BertModel(config=bert_config,
                            is_training=False,
                            input_ids=pl_input_ids,
                            input_mask=pl_mask,
                            token_type_ids=pl_token_type_ids,
                            use_one_hot_embeddings=False)

        tvars = tf.compat.v1.trainable_variables()
        assignment_map, _ = get_assignment_map_from_checkpoint(
            tvars, bert_ckpt_file)
        tf.compat.v1.train.init_from_checkpoint(bert_ckpt_file, assignment_map)

        return s_model, pl_input_ids, pl_token_type_ids, pl_mask
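
A minimal sketch of driving load_stock_model end to end (here called as a plain function), assuming model_dir points at a standard pre-trained BERT checkpoint directory; the path and the dummy inputs below are hypothetical:

    model, pl_input_ids, pl_token_type_ids, pl_mask = load_stock_model(
        "path/to/uncased_L-12_H-768_A-12", max_seq_len=6)  # hypothetical path

    ids = np.zeros((1, 6), dtype=np.int32)          # dummy token ids
    token_types = np.zeros((1, 6), dtype=np.int32)  # single segment
    mask = np.ones((1, 6), dtype=np.int32)          # no padding

    with tf.compat.v1.Session() as sess:
        # init_from_checkpoint has already registered the assignment map,
        # so the variable initializer picks up the checkpoint values.
        sess.run(tf.compat.v1.global_variables_initializer())
        seq_out = sess.run(model.get_sequence_output(),
                           feed_dict={pl_input_ids: ids,
                                      pl_token_type_ids: token_types,
                                      pl_mask: mask})
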
    def test_direct_keras_to_stock_compare(self):
        from tests.ext.tokenization import FullTokenizer
        from tests.ext.modeling import BertConfig  # only BertConfig is used here

        bert_config = BertConfig.from_json_file(self.bert_config_file)
        tokenizer = FullTokenizer(
            vocab_file=os.path.join(self.bert_ckpt_dir, "vocab.txt"))

        # prepare input
        max_seq_len = 6
        input_str = "Hello, Bert!"
        input_tokens = tokenizer.tokenize(input_str)
        input_tokens = ["[CLS]"] + input_tokens + ["[SEP]"]
        input_ids = tokenizer.convert_tokens_to_ids(input_tokens)
        pad_len = max_seq_len - len(input_tokens)
        input_ids = input_ids + [0] * pad_len
        input_mask = [1] * len(input_tokens) + [0] * pad_len
        token_type_ids = [0] * max_seq_len  # single-segment input

        input_ids = np.array([input_ids], dtype=np.int32)
        input_mask = np.array([input_mask], dtype=np.int32)
        token_type_ids = np.array([token_type_ids], dtype=np.int32)

        print("   tokens:", input_tokens)
        print("input_ids ({} of {} positions used): {}".format(
            len(input_tokens), max_seq_len, input_ids))
        print("shape:", input_ids.shape, "token_type_ids:", token_type_ids)

        s_res = self.predict_on_stock_model(input_ids, input_mask,
                                            token_type_ids)
        k_res = self.predict_on_keras_model(input_ids, input_mask,
                                            token_type_ids)

        np.set_printoptions(precision=9,
                            threshold=20,
                            linewidth=200,
                            sign="+",
                            floatmode="fixed")
        print("s_res", s_res.shape)
        print("k_res", k_res.shape)

        print("s_res:\n {}".format(s_res[0, :2, :10]), s_res.dtype)
        print("k_res:\n {}".format(k_res[0, :2, :10]), k_res.dtype)

        adiff = np.abs(s_res - k_res).flatten()
        print("diff:", np.max(adiff), np.argmax(adiff))
        self.assertTrue(np.allclose(s_res, k_res, atol=1e-6))
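
The tokenize/pad/mask preparation above generalizes into a small helper; a sketch under the same FullTokenizer API (prepare_input is a name introduced here for illustration, not part of the test suite):

    def prepare_input(tokenizer, text, max_seq_len):
        # Wrap the text in BERT's [CLS]/[SEP] markers, then pad to max_seq_len.
        tokens = ["[CLS]"] + tokenizer.tokenize(text) + ["[SEP]"]
        ids = tokenizer.convert_tokens_to_ids(tokens)
        pad_len = max_seq_len - len(tokens)
        input_ids = np.array([ids + [0] * pad_len], dtype=np.int32)
        input_mask = np.array([[1] * len(tokens) + [0] * pad_len],
                              dtype=np.int32)
        token_type_ids = np.zeros((1, max_seq_len), dtype=np.int32)  # one segment
        return input_ids, input_mask, token_type_ids
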
    def create_stock_bert_graph(bert_config_file, max_seq_len):
        from tests.ext.modeling import BertModel, BertConfig

        tf_placeholder = tf.compat.v1.placeholder

        pl_input_ids = tf_placeholder(tf.int32, shape=(1, max_seq_len))
        pl_mask = tf_placeholder(tf.int32, shape=(1, max_seq_len))
        pl_token_type_ids = tf_placeholder(tf.int32, shape=(1, max_seq_len))

        bert_config = BertConfig.from_json_file(bert_config_file)
        s_model = BertModel(config=bert_config,
                            is_training=False,
                            input_ids=pl_input_ids,
                            input_mask=pl_mask,
                            token_type_ids=pl_token_type_ids,
                            use_one_hot_embeddings=False)

        return s_model, pl_input_ids, pl_mask, pl_token_type_ids
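
create_stock_bert_graph builds the same graph as load_stock_model but leaves checkpoint loading to the caller (note it returns the mask before the token-type placeholder, the opposite order from load_stock_model). A sketch of wiring the two steps together, with hypothetical paths:

    model, pl_ids, pl_mask, pl_token_types = create_stock_bert_graph(
        "path/to/bert_config.json", max_seq_len=6)    # hypothetical path
    ckpt = "path/to/bert_model.ckpt"                  # hypothetical path
    assignment_map, _ = get_assignment_map_from_checkpoint(
        tf.compat.v1.trainable_variables(), ckpt)
    tf.compat.v1.train.init_from_checkpoint(ckpt, assignment_map)
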
    def predict_on_stock_model(self, input_ids, input_mask, token_type_ids):
        from tests.ext.modeling import BertModel, BertConfig, get_assignment_map_from_checkpoint

        tf.compat.v1.reset_default_graph()

        tf_placeholder = tf.compat.v1.placeholder

        max_seq_len = input_ids.shape[-1]
        pl_input_ids = tf_placeholder(tf.int32, shape=(1, max_seq_len))
        pl_mask = tf_placeholder(tf.int32, shape=(1, max_seq_len))
        pl_token_type_ids = tf_placeholder(tf.int32, shape=(1, max_seq_len))

        bert_config = BertConfig.from_json_file(self.bert_config_file)

        s_model = BertModel(config=bert_config,
                            is_training=False,
                            input_ids=pl_input_ids,
                            input_mask=pl_mask,
                            token_type_ids=pl_token_type_ids,
                            use_one_hot_embeddings=False)

        tvars = tf.compat.v1.trainable_variables()
        assignment_map, _ = get_assignment_map_from_checkpoint(
            tvars, self.bert_ckpt_file)
        tf.compat.v1.train.init_from_checkpoint(self.bert_ckpt_file,
                                                assignment_map)

        with tf.compat.v1.Session() as sess:
            sess.run(tf.compat.v1.global_variables_initializer())

            s_res = sess.run(s_model.get_sequence_output(),
                             feed_dict={
                                 pl_input_ids: input_ids,
                                 pl_token_type_ids: token_type_ids,
                                 pl_mask: input_mask,
                             })
        return s_res
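
predict_on_keras_model itself is not shown in this listing. A minimal sketch of what the Keras-side counterpart could look like, assuming the bert-for-tf2 package (its params_from_pretrained_ckpt, BertModelLayer.from_params, and load_stock_weights loaders); this is an approximation for illustration, not the test's actual implementation:

    from tensorflow import keras
    import bert  # bert-for-tf2

    def predict_on_keras_model(self, input_ids, input_mask, token_type_ids):
        max_seq_len = input_ids.shape[-1]
        bert_params = bert.params_from_pretrained_ckpt(self.bert_ckpt_dir)
        l_bert = bert.BertModelLayer.from_params(bert_params, name="bert")

        l_input_ids = keras.layers.Input(shape=(max_seq_len,), dtype="int32")
        l_token_type_ids = keras.layers.Input(shape=(max_seq_len,),
                                              dtype="int32")
        output = l_bert([l_input_ids, l_token_type_ids])
        model = keras.Model(inputs=[l_input_ids, l_token_type_ids],
                            outputs=output)
        model.build(input_shape=[(None, max_seq_len), (None, max_seq_len)])

        # Map the stock TF checkpoint weights onto the Keras layer.
        bert.load_stock_weights(l_bert, self.bert_ckpt_file)

        # input_mask is not passed explicitly in this sketch (an assumed
        # simplification; padded positions carry token id 0).
        return model.predict([input_ids, token_type_ids])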