def train_gpt2(csv_path='gpt2_train_data/bertffn_crossentropy.csv',
               steps=100000,
               batch_size=1):
    """Fine-tune GPT-2 on the training data found at ``csv_path``.

    TensorFlow v2 behaviour is switched off first. If ``models/117M`` is
    absent, ``gpt2.download_gpt2()`` is invoked with its default arguments
    before a session is started and handed to ``gpt2.finetune``.

    Args:
        csv_path: Path of the training file passed to ``gpt2.finetune``.
        steps: Forwarded to ``gpt2.finetune`` as ``steps``.
        batch_size: Forwarded to ``gpt2.finetune`` as ``batch_size``.
    """
    tf.compat.v1.disable_v2_behavior()

    # Only fetch the base checkpoint when it is not already on disk.
    base_model_dir = 'models/117M'
    model_missing = not os.path.exists(base_model_dir)
    if model_missing:
        gpt2.download_gpt2()

    session = gpt2.start_tf_sess()
    gpt2.finetune(
        session,
        csv_path,
        steps=steps,
        batch_size=batch_size,
    )
def __init__(self,
             pretrained_path='pubmed_pmc_470k/',
             ffn_weight_file=None,
             bert_ffn_weight_file='models/bertffn_crossentropy/bertffn',
             embedding_file='qa_embeddings/bertffn_crossentropy.pkl'):
    """Set up the GPT-2 session, the question embedder and the FAISS lookup.

    Two separate sessions are started through ``gpt2.start_tf_sess()``:
    ``self.sess`` has GPT-2 loaded into it, while ``self.embed_sess`` is the
    default session during construction of the ``QAEmbed`` instance.

    Args:
        pretrained_path: Forwarded to ``QAEmbed`` as ``pretrained_path``.
        ffn_weight_file: Forwarded to ``QAEmbed`` as ``ffn_weight_file``.
        bert_ffn_weight_file: Forwarded to ``QAEmbed`` as
            ``bert_ffn_weight_file``.
        embedding_file: Passed to ``FaissTopK``.
    """
    super(GenerateQADoc, self).__init__()
    tf.compat.v1.disable_eager_execution()

    # Session that holds the loaded GPT-2 model.
    generation_session = gpt2.start_tf_sess()
    self.sess = generation_session
    gpt2.load_gpt2(generation_session)

    # The embedder is constructed with its own session as the default one.
    embedder_options = dict(
        pretrained_path=pretrained_path,
        ffn_weight_file=ffn_weight_file,
        bert_ffn_weight_file=bert_ffn_weight_file,
        with_answer=False,
        load_pretrain=False,
    )
    self.embed_sess = gpt2.start_tf_sess()
    with self.embed_sess.as_default():
        self.qa_embed = QAEmbed(**embedder_options)

    # NOTE(review): the original formatting was lost, so it is unclear whether
    # this assignment sat inside the ``with`` block above - confirm.
    self.faiss_topk = FaissTopK(embedding_file)
# Script: load GPT-2 into a fresh session and generate text for one
# hard-coded question prompt.
import tensorflow.compat.v1 as tf

from gpt_2 import gpt2

# Eager execution is turned off before any session is created.
tf.disable_eager_execution()

# Not referenced by the statements below; kept as a module-level name.
csv_path = 'data/GPT2_data_FFNN.csv'

# Prompt text handed to the generator as its prefix.
QUESTION_PROMPT = "`QUESTION: What is the best treatment for the flu? `ANSWER:"

sess = gpt2.start_tf_sess()
gpt2.load_gpt2(sess)
gpt2.generate(
    sess,
    prefix=QUESTION_PROMPT,
)