Ejemplo n.º 1
0
def train_gpt2(csv_path='gpt2_train_data/bertffn_crossentropy.csv',
               steps=100000,
               batch_size=1):
    """Fine-tune GPT-2 on a CSV dataset, downloading the base weights if absent.

    Args:
        csv_path: training data path handed to ``gpt2.finetune``.
        steps: number of fine-tuning steps forwarded to ``gpt2.finetune``.
        batch_size: batch size forwarded to ``gpt2.finetune``.
    """
    # Put TensorFlow into 1.x behaviour before the session below is created.
    tf.compat.v1.disable_v2_behavior()

    # NOTE(review): the presence check looks for models/117M, but download_gpt2()
    # is called with its default model -- confirm that default is '117M' in the
    # vendored gpt2 module, otherwise the download is retried on every call.
    weights_present = os.path.exists('models/117M')
    if not weights_present:
        gpt2.download_gpt2()

    session = gpt2.start_tf_sess()
    gpt2.finetune(session, csv_path, steps=steps, batch_size=batch_size)
Ejemplo n.º 2
0
    def __init__(self,
                 pretrained_path='pubmed_pmc_470k/',
                 ffn_weight_file=None,
                 bert_ffn_weight_file='models/bertffn_crossentropy/bertffn',
                 embedding_file='qa_embeddings/bertffn_crossentropy.pkl'):
        """Set up the GPT-2 session, the question embedder and the FAISS lookup.

        Args:
            pretrained_path: forwarded to ``QAEmbed`` as ``pretrained_path``.
            ffn_weight_file: forwarded to ``QAEmbed``; ``None`` by default.
            bert_ffn_weight_file: forwarded to ``QAEmbed``.
            embedding_file: passed to ``FaissTopK``; presumably a pickle of
                precomputed QA embeddings (TODO confirm).
        """
        super().__init__()
        # Graph mode is enabled before either session is created.
        tf.compat.v1.disable_eager_execution()

        # First session: holds the GPT-2 weights loaded by gpt2.load_gpt2.
        self.sess = gpt2.start_tf_sess()
        gpt2.load_gpt2(self.sess)

        # Second session: QAEmbed is constructed while this one is the default.
        # NOTE(review): as_default() changes only the default session, not the
        # default graph -- confirm both models are meant to share one graph.
        self.embed_sess = gpt2.start_tf_sess()
        embed_options = dict(
            pretrained_path=pretrained_path,
            ffn_weight_file=ffn_weight_file,
            bert_ffn_weight_file=bert_ffn_weight_file,
            with_answer=False,
            load_pretrain=False,
        )
        with self.embed_sess.as_default():
            self.qa_embed = QAEmbed(**embed_options)

        self.faiss_topk = FaissTopK(embedding_file)
Ejemplo n.º 3
0
import tensorflow.compat.v1 as tf
from gpt_2 import gpt2

# Switch TensorFlow to graph mode before the session below is created.
tf.disable_eager_execution()

# NOTE(review): csv_path is assigned but never read in this script; it is kept
# so the module-level name stays available to anything importing it.
csv_path = 'data/GPT2_data_FFNN.csv'

# Load the GPT-2 checkpoint into a fresh session, then ask gpt2.generate for
# text conditioned on a "`QUESTION: ... `ANSWER:" style prefix.
sess = gpt2.start_tf_sess()
gpt2.load_gpt2(sess)
gpt2.generate(sess,
              prefix="`QUESTION: What is the best treatment for the flu? `ANSWER:")