def load_stock_model(model_dir, max_seq_len):
    """Build the stock BERT graph and map its variables onto a checkpoint.

    Args:
        model_dir: Directory holding ``bert_config.json`` and the
            ``bert_model.ckpt`` checkpoint.
        max_seq_len: Second dimension of the three int32 input placeholders.

    Returns:
        A tuple ``(s_model, pl_input_ids, pl_token_type_ids, pl_mask)``.
        Note the order: the token-type placeholder comes before the mask.
    """
    from tests.ext.modeling import (
        BertConfig,
        BertModel,
        get_assignment_map_from_checkpoint,
    )

    # Reset so variable names are scoped identically for checkpoint loading
    # when this function is executed more than once.
    tf.compat.v1.reset_default_graph()

    config_path = os.path.join(model_dir, "bert_config.json")
    ckpt_path = os.path.join(model_dir, "bert_model.ckpt")

    def _int_placeholder():
        # All model inputs share the same dtype and (1, max_seq_len) shape.
        return tf.compat.v1.placeholder(tf.int32, shape=(1, max_seq_len))

    # Creation order is kept stable: ids, mask, token types.
    pl_input_ids = _int_placeholder()
    pl_mask = _int_placeholder()
    pl_token_type_ids = _int_placeholder()

    stock_model = BertModel(
        config=BertConfig.from_json_file(config_path),
        is_training=False,
        input_ids=pl_input_ids,
        input_mask=pl_mask,
        token_type_ids=pl_token_type_ids,
        use_one_hot_embeddings=False,
    )

    # Only the assignment map is needed; the second returned value (the
    # initialized variable names) is discarded.
    assignment_map, _ = get_assignment_map_from_checkpoint(
        tf.compat.v1.trainable_variables(), ckpt_path)
    tf.compat.v1.train.init_from_checkpoint(ckpt_path, assignment_map)

    return stock_model, pl_input_ids, pl_token_type_ids, pl_mask
Exemplo n.º 2
0
    def predict_on_stock_model(self, input_ids, input_mask, token_type_ids):
        """Run the stock BERT implementation and return its sequence output.

        A fresh default graph is built on every call, the trainable variables
        are mapped onto ``self.bert_ckpt_file``, and the model is evaluated
        once in a new session.

        Args:
            input_ids: Array-like exposing ``.shape``; its last dimension is
                used as the sequence length of all three placeholders.
            input_mask: Values fed to the input-mask placeholder.
            token_type_ids: Values fed to the token-type placeholder.

        Returns:
            The evaluated result of ``s_model.get_sequence_output()``.

        Note:
            The placeholders are declared with shape ``(1, max_seq_len)``, so
            the fed values presumably must hold exactly one sequence --
            TODO confirm against callers.
        """
        from tests.ext.modeling import BertModel, BertConfig, get_assignment_map_from_checkpoint

        # Start from an empty graph so variable names match the checkpoint
        # even when this method is called repeatedly.
        tf.compat.v1.reset_default_graph()

        max_seq_len = input_ids.shape[-1]
        pl_shape = (1, max_seq_len)
        pl_input_ids = tf.compat.v1.placeholder(tf.int32, shape=pl_shape)
        pl_mask = tf.compat.v1.placeholder(tf.int32, shape=pl_shape)
        pl_token_type_ids = tf.compat.v1.placeholder(tf.int32, shape=pl_shape)

        bert_config = BertConfig.from_json_file(self.bert_config_file)

        s_model = BertModel(config=bert_config,
                            is_training=False,
                            input_ids=pl_input_ids,
                            input_mask=pl_mask,
                            token_type_ids=pl_token_type_ids,
                            use_one_hot_embeddings=False)

        # Only the assignment map is used; the list of initialized variable
        # names returned alongside it is not needed here.
        tvars = tf.compat.v1.trainable_variables()
        assignment_map, _ = get_assignment_map_from_checkpoint(
            tvars, self.bert_ckpt_file)
        tf.compat.v1.train.init_from_checkpoint(self.bert_ckpt_file,
                                                assignment_map)

        with tf.compat.v1.Session() as sess:
            sess.run(tf.compat.v1.global_variables_initializer())

            s_res = sess.run(s_model.get_sequence_output(),
                             feed_dict={
                                 pl_input_ids: input_ids,
                                 pl_token_type_ids: token_type_ids,
                                 pl_mask: input_mask,
                             })
        return s_res
Exemplo n.º 3
0
    def create_stock_bert_graph(bert_config_file, max_seq_len):
        """Construct the stock BERT model on three fresh int32 placeholders.

        Args:
            bert_config_file: Path to the JSON file read by
                ``BertConfig.from_json_file``.
            max_seq_len: Second dimension of each ``(1, max_seq_len)``
                placeholder.

        Returns:
            A tuple ``(s_model, pl_input_ids, pl_mask, pl_token_type_ids)``.
        """
        from tests.ext.modeling import BertConfig, BertModel

        input_shape = (1, max_seq_len)
        # The generator yields the placeholders in creation order:
        # input ids, then mask, then token type ids.
        pl_input_ids, pl_mask, pl_token_type_ids = (
            tf.compat.v1.placeholder(tf.int32, shape=input_shape)
            for _ in range(3)
        )

        stock_model = BertModel(
            config=BertConfig.from_json_file(bert_config_file),
            is_training=False,
            input_ids=pl_input_ids,
            input_mask=pl_mask,
            token_type_ids=pl_token_type_ids,
            use_one_hot_embeddings=False,
        )

        return stock_model, pl_input_ids, pl_mask, pl_token_type_ids