def train_data(inputFile, outputDir):
    """Generate ten samples from the default checkpoint into ``newoutputs/``.

    Fine-tuning on ``resource/<inputFile>.txt`` is currently disabled (it used
    to live in a dead string literal), so ``inputFile`` is unused and only kept
    for interface compatibility.

    Each sample is written to ``newoutputs/<outputDir><uuid4>.txt``. Note that
    ``outputDir`` is concatenated as a file-name prefix, not joined as a
    directory component.

    Args:
        inputFile: basename (without ``.txt``) of the training text; unused
            while fine-tuning is disabled.
        outputDir: prefix prepended to every generated file name.
    """
    sess = gpt2.start_tf_sess()

    # Fine-tuning is disabled; previously it was:
    # gpt2.finetune(sess, "resource/" + inputFile + ".txt",
    #               model_name=model_name, overwrite=True, steps=2)

    try:
        # Load the checkpoint once, outside the loop. Re-loading on the same
        # session every iteration re-creates the model variables (the old
        # commented-out reuse_variables workarounds suggest that is what failed).
        gpt2.load_gpt2(sess)
        for _ in range(10):
            destination = "newoutputs/" + outputDir + str(uuid.uuid4()) + ".txt"
            gpt2.generate_to_file(sess, destination_path=destination)
    finally:
        # Always release the session/graph, even if loading or generation fails.
        gpt2.reset_session(sess, threads=-1, server=None)
    '''
Ejemplo n.º 2
0
 def talk(self, prompt, temp=0.7, words=50):
     """Continue ``prompt`` using the session stored on ``self``.

     The generated reply is written to ``bot_says.txt``; the prompt itself is
     not echoed into the file.

     Args:
         prompt: text for the model to continue.
         temp: sampling temperature.
         words: forwarded as ``length`` (presumably counted in tokens by
             gpt-2-simple rather than words -- confirm).
     """
     options = dict(
         destination_path='bot_says.txt',
         prefix=prompt,
         include_prefix=False,
         temperature=temp,
         length=words,
     )
     gpt2.generate_to_file(self.session, **options)
Ejemplo n.º 3
0
 def loadRun(self):
     """Restore checkpoint ``run1`` into a new session and generate a reply.

     Uses ``self.prompt``, ``self.num_words`` and ``self.temperature``; output
     goes to ``bot_says.txt`` with ``truncate='.'`` and without the prompt.
     The newly created session is not stored on ``self``.
     """
     sess = gpt2.start_tf_sess()
     gpt2.load_gpt2(sess, run_name='run1')
     generation_kwargs = {
         'destination_path': 'bot_says.txt',
         'include_prefix': False,
         'truncate': ".",
         'length': self.num_words,
         'temperature': self.temperature,
         'prefix': self.prompt,
     }
     gpt2.generate_to_file(sess, **generation_kwargs)
Ejemplo n.º 4
0
def generate_generic_tweet(session):
    """Generate text with run ``run2`` and return one tweet picked from it.

    The output file is read back, split on the ``' || '`` separator, and the
    pieces are handed to ``choose_and_clean_tweet``, whose result is returned.
    """
    out_path = '../data/generated_tweets.txt'

    print('generating generic tweets...')
    gpt2.generate_to_file(session, destination_path=out_path, run_name='run2')
    print("tweet generation complete...")

    with open(out_path, 'r', encoding='utf-8') as handle:
        candidates = handle.read().split(' || ')

    print('tweets split, beginning filter...')
    return choose_and_clean_tweet(candidates)
Ejemplo n.º 5
0
def main():
    """Optionally fine-tune GPT-2 on ``champ.txt``, then write generated text to a file.

    The 355M base model is downloaded into ``./models/355M`` if missing. The
    user is asked whether to train. An empty answer, ``y`` or ``yes`` (any case,
    surrounding whitespace ignored) fine-tunes for 100 steps. Any other answer
    restores the latest fine-tuned checkpoint instead. One sample of 10,000
    tokens at temperature 0.7 is then written to
    ``gpt2_gentext_<UTC timestamp>.txt``.

    Exits with status 1 if ``champ.txt`` does not exist.
    """
    import sys

    # Available model sizes: "124M", "355M", "774M", "1558M".
    model_name = "355M"
    file_name = "champ.txt"

    if not os.path.isdir(os.path.join("models", model_name)):
        print(f"Downloading {model_name} model...")
        # the model is saved into the current directory under ./models/<model_name>/
        gpt2.download_gpt2(model_name=model_name)

    if not os.path.isfile(file_name):
        print("please provide a filename..")
        sys.exit(1)  # non-zero status: this is an error, not a normal exit

    # GPU config: allocate memory on demand, cap it at 77% of the device, and
    # turn off grappler's layout optimizer (presumably to work around an
    # optimizer issue on this setup -- confirm).
    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True
    config.gpu_options.per_process_gpu_memory_fraction = 0.77
    config.graph_options.rewrite_options.layout_optimizer = rewriter_config_pb2.RewriterConfig.OFF
    sess = tf.compat.v1.Session(config=config)

    # sess = gpt2.start_tf_sess()  # old, CPU-only session

    print('\n+++ Train model (y)? +++')
    answer = input().strip().lower()
    if answer in ("", "y", "yes"):
        print('---> training model...\n')
        gpt2.finetune(
            sess, file_name, model_name=model_name,
            steps=100)  # steps is max number of training steps - default: 1000
    else:
        print('---> not training model...\n')
        # Nothing has been restored into this fresh session yet, so load the
        # latest fine-tuned checkpoint (gpt-2-simple's default run) before
        # generating; otherwise the model weights are uninitialized.
        gpt2.load_gpt2(sess)

    # generate text to a timestamped file
    gen_file = 'gpt2_gentext_{:%Y%m%d_%H%M%S}.txt'.format(
        datetime.datetime.now(datetime.timezone.utc))
    gpt2.generate_to_file(sess,
                          destination_path=gen_file,
                          length=10000,
                          temperature=0.7,
                          nsamples=1,
                          batch_size=1)
Ejemplo n.º 6
0
def predict_gan():
    """Render five GAN image grids, each with GPT-2 text, and show them in Streamlit.

    Generator weights are resumed from ``./pre-trained_weight``. Each image grid
    and its generated text file are saved in a new result subdirectory created
    by ``misc.create_result_subdir('pre-trained_result', ...)``.

    Assumes a module-level TensorFlow session ``sess`` that can generate from
    run ``tree_run1`` (it is not created here -- confirm at the call site).

    Raises:
        ValueError: if no weight directory is configured, or the grid type is unknown.
    """
    drange_viz = [-1, 1]
    image_grid_size = (1, 1)
    image_grid_type = 'default'
    # './' anchors the pre-trained weight folder at the project root
    resume_network = './pre-trained_weight'

    np.random.seed(config.random_seed)

    # Without weights G would be undefined below; fail loudly instead.
    if not resume_network:
        raise ValueError('resume_network must point to pre-trained generator weights')
    print("Resuming weight from:" + resume_network)
    G = Generator(num_channels=3, resolution=128, label_size=0, **config.G)
    G = load_G_weights(G, resume_network, True)

    # summary() prints itself; wrapping it in print() only adds a stray "None".
    G.summary()

    # Misc init.
    if image_grid_type != 'default':
        raise ValueError('Invalid image_grid_type', image_grid_type)
    if image_grid_size is None:
        # Derive a grid that roughly fills a 1920x1080 screen.
        w, h = G.output_shape[1], G.output_shape[2]
        print("w:%d,h:%d" % (w, h))
        image_grid_size = (np.clip(int(1920 // w), 3, 16).astype('int'),
                           np.clip(1080 / h, 2, 16).astype('int'))
    print("image_grid_size:", image_grid_size)

    result_subdir = misc.create_result_subdir('pre-trained_result', config.run_desc)

    for i in range(1, 6):
        image_path = os.path.join(result_subdir, 'pre-trained_%03d.png' % i)
        snapshot_fake_latents = random_latents(np.prod(image_grid_size), G.input_shape)
        snapshot_fake_images = G.predict_on_batch(snapshot_fake_latents)
        misc.save_image_grid(snapshot_fake_images, image_path,
                             drange=drange_viz, grid_size=image_grid_size)

        # use streamlit to show the generated image
        st.header('IG Post #' + str(i))
        # Close the file handle after resizing. Image.ANTIALIAS was removed in
        # Pillow 10; LANCZOS is the same filter under its current name.
        with Image.open(image_path) as im:
            st.image(im.resize((512, 512), Image.LANCZOS))

        # call gpt2-simple on a pre-trained run and show the generated text
        gen_file = os.path.join(
            result_subdir,
            'gpt2_gentext_{:%Y%m%d_%H%M%S}.txt'.format(datetime.utcnow()))
        gpt2.generate_to_file(sess, destination_path=gen_file, run_name='tree_run1')
        with open(gen_file, 'r') as content:
            st.write(content.read())
Ejemplo n.º 7
0
def generate_trending_tweet(session):
    """Generate tweets about a random trending topic and return the first one.

    Five samples prefixed with the topic are generated from run ``run2``. The
    output file is split on ``' || '``, and the first piece, capped at 280
    characters, is returned.
    """
    topic = choice(get_trending())
    print("generating topical tweets on subject: " + topic)

    out_path = '../data/generated_tweets.txt'
    gpt2.generate_to_file(session,
                          destination_path=out_path,
                          nsamples=5,
                          run_name='run2',
                          prefix=topic)
    print("topical tweet generation complete...")

    with open(out_path, 'r', encoding='utf-8') as src:
        first_piece = src.read().split(' || ')[0]
    # slicing is a no-op for strings already within the limit
    return first_piece[:280]
Ejemplo n.º 8
0
def post_tweets():
    """Forever: generate tweets from run ``run1`` and post one every 12 hours.

    Each round regenerates ``generated_tweets.txt`` and splits it on blank
    lines. It then posts randomly chosen pieces (with words rejected by
    ``excluded`` removed) that are at least 15 characters long, sleeping
    12 hours after each post. The last remaining piece of each round is never
    posted; presumably this is because it is the trailing, possibly
    incomplete, sample -- confirm.
    """
    session = gpt2.start_tf_sess()
    # start_tf_sess() returns an empty session: restore the trained run so
    # generation samples from real weights instead of uninitialized variables.
    gpt2.load_gpt2(session, run_name="run1")

    while True:
        gpt2.generate_to_file(session,
                              destination_path="generated_tweets.txt",
                              run_name="run1")

        with open("generated_tweets.txt", "r", encoding="utf-8") as in_file:
            tweets = in_file.read().split("\n\n")

        while len(tweets) > 1:
            idx = random.randrange(0, len(tweets))
            words = tweets.pop(idx).split(" ")
            tweet = " ".join(word for word in words if not excluded(word))
            if len(tweet) < 15:
                continue
            print(tweet)
            api.update_status(tweet)
            time.sleep(12 * 60 * 60)
Ejemplo n.º 9
0
def generate_to_file(sess, run_name, subreddit, n=1000, temp=0.7, length=50):
    """Generate ``n`` subreddit-style submissions into a timestamped file.

    The file is ``<OUTPUT_DIR>/<subreddit>_gentext_<UTC timestamp>.txt``.
    Samples are drawn from ``generator_models/<run_name>`` with top-k/top-p
    sampling and are delimited by ``<|startoftext|>``/``<|endoftext|>``.

    Args:
        sess: TensorFlow session holding the loaded model.
        run_name: checkpoint run to generate from.
        subreddit: used as the output file-name prefix.
        n: number of samples to generate.
        temp: sampling temperature.
        length: tokens per sample (previously hard-coded to 50).

    Returns:
        The path of the output file. If an encoding error stops generation
        early, the file presumably contains only the samples written before
        the failure.
    """
    import math

    file_name = subreddit + '_gentext_{:%Y%m%d_%H%M%S}.txt'.format(
        datetime.utcnow())
    path = os.path.join(OUTPUT_DIR, file_name)

    print('Generating', n, 'submissions to file', file_name)

    # gpt-2-simple requires nsamples to be a multiple of batch_size; use the
    # largest batch (up to 20) that divides n. This is unchanged for the
    # default n=1000.
    batch_size = math.gcd(n, 20)

    try:
        gpt2.generate_to_file(
            sess,
            destination_path=path,
            length=length,
            temperature=temp,
            nsamples=n,
            batch_size=batch_size,
            top_k=50,  # test
            top_p=0.95,  # test
            checkpoint_dir='generator_models/',
            prefix="<|startoftext|>",
            truncate="<|endoftext|>",
            include_prefix=True,
            run_name=run_name)
    except UnicodeEncodeError:
        # Best effort: keep whatever was written before the failure.
        print('Stopping early due to an encoding error.')
    return path
Ejemplo n.º 10
0
              prefix="LORD",
              nsamples=5,
              batch_size=5
              )

"""For bulk generation, you can generate a large amount of text to a file and sort through the samples locally on your computer. The next cell generates a text file named with a unique timestamp.

You can rerun the cells as many times as you want to get even more generated text!
"""

# Bulk generation: 100 samples of 500 tokens each, in batches of 20, written
# to a file named after the current UTC time.
gen_time = datetime.utcnow()
gen_file = f'gpt2_gentext_{gen_time:%Y%m%d_%H%M%S}.txt'

bulk_options = dict(length=500, temperature=0.7, nsamples=100, batch_size=20)
gpt2.generate_to_file(sess, destination_path=gen_file, **bulk_options)

# Colab downloads can be flaky: you may have to run this twice.
files.download(gen_file)


"""# Etcetera

If the notebook has errors (e.g. GPU Sync Fail), force-kill the Colaboratory virtual machine and restart it with the command below:
"""

!kill -9 -1
!pip install gpt_2_simple==0.5.4 -t . --no-deps
!pip install toposort
import gpt_2_simple as gpt2
from google.colab import drive

drive.mount('/content/drive')
root_dir = 'drive/MyDrive/Echidna'

!cp /content/drive/MyDrive/Echidna/gpt_2_simple/gpt_2.py /content/gpt_2_simple/gpt_2.py

!ls /content/gpt_2_simple

model_name = "355M"
gpt2.download_gpt2(model_name=model_name)

sess = gpt2.start_tf_sess()
gpt2.copy_checkpoint_from_gdrive(run_name='echidna')
gpt2.load_gpt2(sess, run_name='echidna')

import datetime
gen_file = 'gpt2_gentext_{:%Y%m%d_%H%M%S}.txt'.format(datetime.datetime.utcnow())
name = "Echidna"
brace = '{'
nl = '\n'
gpt2.generate_to_file(sess, destination_path=gen_file, run_name="echidna", prefix=f'<-|start|->{nl}{brace}{nl}    "monster_name": "{name}",{nl}', truncate="<-|end|->", length=10240, temperature=0.9, split_context=0.65)



from google.colab import files
# may have to run twice to get file to download
files.download(gen_file)
Ejemplo n.º 12
0
# to bulk generate text locally using a trained model

import gpt_2_simple as gpt2
from datetime import datetime

RUN_NAME = 'r1'
num_files = 10

sess = gpt2.start_tf_sess()
gpt2.load_gpt2(sess, run_name=RUN_NAME)

for _ in range(num_files):
    gen_file = 'gpt2_gentext_{:%Y%m%d_%H%M%S}.txt'.format(datetime.utcnow())

    # generate_to_file defaults to run_name='run1' and reads the encoder and
    # hparams from that run's directory, so it must name the run loaded above.
    gpt2.generate_to_file(sess,
                          run_name=RUN_NAME,
                          destination_path=gen_file,
                          length=200,
                          temperature=0.85,
                          top_p=0.9,
                          prefix='<|startoftext|>',
                          truncate='<|endoftext|>',
                          include_prefix=False,
                          nsamples=1000,
                          batch_size=20)
Ejemplo n.º 13
0
#!/usr/bin/python3
# Fine-tune the 124M GPT-2 model on ./src/alex_jones.txt as run "run2",
# then generate a sample from that run.
import gpt_2_simple as gpt2
import os

model_name = "124M"
if not os.path.isdir(os.path.join("models", model_name)):
    # model is saved into the current directory under ./models/124M/
    gpt2.download_gpt2(model_name=model_name)

file_name = "./src/alex_jones.txt"
run_name = 'run2'

sess = gpt2.start_tf_sess()
gpt2.finetune(sess,
              file_name,
              model_name=model_name,
              run_name=run_name,
              steps=1000,
              save_every=50,
              print_every=5,
              sample_every=10,
              learning_rate=0.0001)

# generate_to_file defaults to run_name='run1'; point it at the run trained above.
gpt2.generate_to_file(sess, run_name=run_name)
Ejemplo n.º 14
0
import gpt_2_simple as gpt2

MONTAG_RUN = "montagmodel"

# Restore the trained run, then write 100 samples at temperature 0.5.
sess = gpt2.start_tf_sess()
gpt2.load_gpt2(sess, run_name=MONTAG_RUN)

gpt2.generate_to_file(sess,
                      run_name=MONTAG_RUN,
                      temperature=0.5,
                      nsamples=100,
                      destination_path='montagarticlernn-temp0-5-100samples.txt')
Ejemplo n.º 15
0
# Needs tensorflow 1.15 and gpt-2 simple

import gpt_2_simple as gpt2

# Trained and generated on Google Colaboratory
# Make sure the file is in the root directory of Google Drive

import os

# tqdm was used without being imported (NameError). It only draws a progress
# bar, so fall back to plain iteration when it is unavailable.
try:
    from tqdm import tqdm
except ImportError:
    def tqdm(iterable):
        return iterable

# Pull checkpoint/run1 from Google Drive and restore it into a new session.
gpt2.copy_checkpoint_from_gdrive(run_name='run1')
sess = gpt2.start_tf_sess()
gpt2.load_gpt2(sess, run_name='run1')

# generate_to_file does not create missing directories.
os.makedirs('scripts', exist_ok=True)

for i in tqdm(range(100)):
    file_name = 'scripts/script_' + str(i) + ".txt"
    gpt2.generate_to_file(
        sess,
        run_name='run1',
        destination_path=file_name,
        length=1000,
        # The higher the temperature, the crazier the text.
        temperature=0.8)
Ejemplo n.º 16
0
import gpt_2_simple as gpt2

model_name = "355M"  # only used by the disabled download / fine-tune steps below

# Download the model (disabled):
# gpt2.download_gpt2(model_name=model_name)

sess = gpt2.start_tf_sess()
gpt2.load_gpt2(sess)

# Fine-tuning step, kept for reference (disabled):
# gpt2.finetune(sess,
#               'data/1k_teigen_tweets.csv',
#               model_name=model_name,
#               steps=100)

# 5000 short samples, stripped of the start/end markers.
fake_tweet_settings = dict(
    length=50,
    temperature=1.0,
    nsamples=5000,
    batch_size=20,
    prefix='<|startoftext|>',
    truncate='<|endoftext|>',
    include_prefix=False,
)
gpt2.generate_to_file(sess,
                      destination_path='data/5k_fake_teigen_tweets.txt',
                      **fake_tweet_settings)
Ejemplo n.º 17
0
gen_file = "output.txt"
sess = gpt2.start_tf_sess()

# Either restore an existing run or fine-tune it from scratch.
if not TRAIN:
    gpt2.load_gpt2(sess, run_name=run_name)
else:
    gpt2.finetune(sess,
                  dataset=file_name,
                  model_name=model,
                  steps=500,
                  restore_from='fresh',
                  run_name=run_name,
                  print_every=10,
                  sample_every=100)

# 100 marker-delimited samples of up to 100 tokens, concatenated without a delimiter.
marker_options = {
    'prefix': "<|startoftext|>",
    'truncate': "<|endoftext|>",
    'include_prefix': False,
    'sample_delim': '',
}
gpt2.generate_to_file(sess,
                      run_name=run_name,
                      destination_path=gen_file,
                      length=100,
                      temperature=1.0,
                      nsamples=100,
                      batch_size=20,
                      **marker_options)
Ejemplo n.º 18
0
import gpt_2_simple as gpt2

MONTAG_RUN = "montagmodel"

# Restore the trained run and write a single sample at temperature 0.5.
sess = gpt2.start_tf_sess()
gpt2.load_gpt2(sess, run_name=MONTAG_RUN)
gpt2.generate_to_file(sess, run_name=MONTAG_RUN, temperature=0.5,
                      destination_path='article.txt')
Ejemplo n.º 19
0
import gpt_2_simple as gpt2
from datetime import datetime

import os

TWEETS = "squatchssb_tweets.csv"
OUTPUT = "deep_squatch.txt"
MODEL_NAME = "355M"
RUN_NAME = "deep-squatch"

# Download the 355M model unless it is already present under ./models/.
if not os.path.isdir(os.path.join("models", MODEL_NAME)):
    gpt2.download_gpt2(model_name=MODEL_NAME)

sess = gpt2.start_tf_sess()
gpt2.finetune(sess,
              dataset=TWEETS,
              model_name=MODEL_NAME,
              steps=2000,
              restore_from='fresh',
              run_name=RUN_NAME,
              print_every=50,
              sample_every=500,
              save_every=500)

# generate_to_file defaults to run_name='run1' and reads the encoder and
# hparams from that run's directory, so it must name the run trained above.
gpt2.generate_to_file(sess,
                      run_name=RUN_NAME,
                      length=140,
                      temperature=1.21,
                      prefix='<|startoftext|>',
                      truncate='<|endoftext|>',
                      include_prefix=False,
                      nsamples=2000,
                      batch_size=20,
                      destination_path=OUTPUT)
Ejemplo n.º 20
0
    gpt2.load_gpt2(sess, checkpoint_dir=checkpoint_path)

    # Ask the user for the prompt the generator should continue.
    prompt_text = st.text_input(label="Enter your prompt text...",
                                value="This generator predict about")

    with st.spinner("AI is at Work........"):
        # One 500-token sample from the same checkpoint, written to a file
        # named after the current UTC time.
        gen_file = 'gpt2_gentext_{:%Y%m%d_%H%M%S}.txt'.format(
            datetime.utcnow())

        gpt2.generate_to_file(sess,
                              checkpoint_dir=checkpoint_path,
                              prefix=prompt_text,
                              destination_path=gen_file,
                              length=500,
                              temperature=1,
                              nsamples=1,
                              batch_size=1)
    st.success("AI Successfully generated the below text ")
    st.balloons()
    # Read the generated text back; `with` closes the handle (it used to leak).
    with open(gen_file) as file:
        text = file.read()
    print('Corpus length in characters=', len(text))
    print((text).encode('utf8'))
    st.markdown(text)

elif ((dataset_name == "Robert Jordan") and (trained_model == "GPT2")):
    st.title("Robert Jordan")
    checkpoint_path = "/content/drive/MyDrive/Project folder/Author wise text generation using GPT/Robert Jordan/checkpoint/"
    checkpoint_dir = os.path.dirname(checkpoint_path)
Ejemplo n.º 21
0
import sys
import gpt_2_simple as gpt2

# Usage: <script> "<prompt>" -- writes 5 samples continuing the prompt to
# gpt-2-simple's default output file.
if len(sys.argv) <= 1:
    # sys.exit (unlike the site-provided exit()) is always available; with a
    # string it prints the message to stderr and exits with status 1.
    sys.exit("Need prompt")

prompt = sys.argv[1]

sess = gpt2.start_tf_sess()
gpt2.load_gpt2(sess)

gpt2.generate_to_file(sess,
                      run_name='run1',
                      length=500,
                      temperature=0.7,
                      nsamples=5,
                      prefix=prompt)
Ejemplo n.º 22
0
import sys
import gpt_2_simple as gpt2

# Usage: <script> "<prompt>" -- writes one sample continuing the prompt to
# genText_NOT_OUTPUT.txt.
if len(sys.argv) <= 1:
    # sys.exit (unlike the site-provided exit()) is always available; with a
    # string it prints the message to stderr and exits with status 1.
    sys.exit("Need prompt")

prompt = sys.argv[1]

sess = gpt2.start_tf_sess()
gpt2.load_gpt2(sess)

gpt2.generate_to_file(sess,
                      run_name='run1',
                      length=540,
                      temperature=0.9,
                      nsamples=1,
                      prefix=prompt,
                      destination_path="genText_NOT_OUTPUT.txt")
Ejemplo n.º 23
0
#!/usr/bin/python3
import gpt_2_simple as gpt2
import os

# os.listdir order is arbitrary; sort so the chosen model is deterministic,
# and fail with a clear message instead of an IndexError when nothing exists.
models = sorted(os.listdir("/pfs/train"))
if not models:
    raise SystemExit("no trained model found in /pfs/train")

model_dir = os.path.join("/pfs/train", models[0])
# gpt2 reads ./checkpoint relative to the working directory, so chdir into the model dir
os.chdir(model_dir)

sess = gpt2.start_tf_sess()
gpt2.load_gpt2(sess)

# 200 marker-delimited samples of up to 280 tokens, written to /pfs/out/<model>.
out = os.path.join("/pfs/out", models[0])
gpt2.generate_to_file(sess, destination_path=out, prefix="<|startoftext|>",
                      truncate="<|endoftext|>", include_prefix=False,
                      length=280, nsamples=200, temperature=1.0)
Ejemplo n.º 24
0
def main():
    """Fine-tune GPT-2 (124M) on a task, sample from it, and append BLEU to results.csv.

    Command-line arguments are parsed here:
    - ``app``: one of ``obama``, ``haiku`` or ``synth``.
    - ``n``: the number of training steps.
    - ``-numeat``/``-numfeed``/``-numsent``: synthetic-data sizes.

    For ``synth`` the training data is generated first, and the proportion of
    generated phrases accepted by ``is_valid_phrase`` is also recorded.

    Exits with status 1 if the task's training file is missing.
    """

    def parse_arguments():
        """Parse the command line; return ``(task, num_steps, synth_gen_params)``."""
        parser = argparse.ArgumentParser(description='Program for running several SeqGan applications.')
        parser.add_argument('app', metavar='application', type=str, choices=['obama', 'haiku', 'synth'],
                            help="Enter 'obama', 'haiku' or 'synth'")
        parser.add_argument('n', metavar='num_steps', type=int, help="The number of training steps")
        parser.add_argument('-numeat', metavar="num_eat", type=int, default=500,
                            help="For synthetic data generation. Determines number of eaters in vocab.")
        parser.add_argument('-numfeed', metavar="num_feed", type=int, default=500,
                            help="For synthetic data generation. Determines number of feeders in vocab.")
        parser.add_argument('-numsent', metavar="num_sent", type=int, default=10000,
                            help="For synthetic data generation. Determines number of sentences generated.")
        args = parser.parse_args()

        synth_gen_params = ("NA", "NA", "NA")
        if args.app == "synth":
            synth_gen_params = (args.numsent, args.numfeed, args.numeat)
            generate_random_sents("../data/synth/input.txt", args.numsent, args.numfeed, args.numeat)

        task = load_task(args.app)
        return task, args.n, synth_gen_params

    def read_tokenized(path, vocab):
        """Split each line of *path* on whitespace and decode the result with *vocab*."""
        with open(path) as handle:
            return vocab.decode([line.strip().split() for line in handle])

    task, num_steps, SYNTH_GEN_PARAMS = parse_arguments()

    model_name = "124M"
    if not os.path.isdir(os.path.join("models", model_name)):
        print(f"Downloading {model_name} model...")
        gpt2.download_gpt2(model_name=model_name)  # saved under ./models/124M/

    file_name = task.train_file
    if not os.path.isfile(file_name):
        print("Training data not present for", task.name)
        sys.exit(1)  # missing data is an error, so exit non-zero

    sess = gpt2.start_tf_sess()
    gpt2.finetune(sess,
                  file_name,
                  model_name=model_name,
                  steps=num_steps, restore_from="fresh")  # steps is max number of training steps

    gpt2.generate_to_file(sess, destination_path=task.eval_file, nsamples=10)

    generated = read_tokenized(task.eval_file, task.vocab)
    references = read_tokenized(task.test_file, task.vocab)

    # Writing results to CSV. Append mode creates the file when needed, so the
    # old os.mknod call (non-portable, privileged on some platforms) is not
    # required; newline='' is what the csv module expects.
    with open("./results.csv", 'a', newline='') as csvfile:
        fieldnames = ["name", "task_name", "num_gen", "num_disc", "num_adv",
                      "num_sents", "num_feeders", "num_eaters", "BLEU", "prop_valid"]
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        csvfile.seek(0, os.SEEK_END)  # go to end of file
        if not csvfile.tell():  # empty file: write the header first
            writer.writeheader()

        bleu = corpus_bleu([references] * len(generated), generated)
        # NOTE(review): gen_n, disc_n, adv_n and MODEL_STRING are not defined in
        # this function; they must be module-level globals (SeqGAN leftovers) -- confirm.
        print("Run with args {} {} {}: BLEUscore = {}\n".format(gen_n, disc_n, adv_n, bleu))

        prop = "NA"
        if task.name == "synth" and generated:
            total_correct = sum(1 for sentence in generated if is_valid_phrase(sentence))
            prop = total_correct / len(generated)

        writer.writerow({"name": MODEL_STRING, "task_name": task.name, "num_gen": gen_n,
                         "num_disc": disc_n, "num_adv": adv_n, "num_sents": SYNTH_GEN_PARAMS[0],
                         "num_feeders": SYNTH_GEN_PARAMS[1], "num_eaters": SYNTH_GEN_PARAMS[2],
                         "BLEU": bleu, "prop_valid": prop})
Ejemplo n.º 25
0
def generate_trending_tweet():
    """Fine-tune a per-topic run on fresh tweets and return one generated tweet.

    Steps:
    1. Pick a topic ("Biden" or "Trump") at random.
    2. Fetch up to 5000 tweets on it and save them to ``../data/<topic>.txt``,
       joined with ``' || '``.
    3. Train run ``<topic>``: from scratch (200 steps) if ``checkpoint/<topic>``
       does not exist, otherwise continue it (100 steps).
    4. Generate ten samples prefixed with the topic.
    5. From the ``' || '``-separated pieces that contain the topic, strip
       link-like words and keep pieces longer than ``len(topic) + 4`` and at
       most 280 characters.

    Returns:
        One of the kept pieces, chosen at random.

    Raises:
        IndexError: if no generated piece passes the filter (from ``choice``).
    """
    topics = ["Biden", "Trump"]
    topic = choice(topics)
    # this is just for testing repeat topics- remove before deployment
    print("generating topical tweets on subject: " + topic)

    # update the text file with current tweets; utf-8 so emoji etc. survive
    # regardless of the platform's default encoding
    file_name = '../data/' + topic + '.txt'
    topical_tweets = get_topic_tweets(topic, 5000)
    t_tweet_string = " || ".join(topical_tweets)

    with open(file_name, 'w', encoding='utf-8') as f:
        f.write(t_tweet_string)

    # train model
    print("training new model on scraped text for topic : " + topic)
    sess = gpt2.start_tf_sess()

    if not os.path.exists('checkpoint/' + topic):
        # train fresh model
        print("training fresh model- none found...")
        gpt2.finetune(sess,
                      dataset=file_name,
                      model_name=model_name,
                      steps=200,
                      restore_from='fresh',
                      run_name=topic,
                      print_every=1)
    else:
        # update existing model
        print("updating existing model with short run on new tweets...")
        gpt2.finetune(sess,
                      dataset=file_name,
                      model_name=model_name,
                      run_name=topic,
                      steps=100,
                      restore_from='latest',
                      print_every=1)

    print("beginning to generate tweets...")
    gpt2.generate_to_file(sess,
                          length=400,
                          destination_path='../data/generated_tweets.txt',
                          nsamples=10,
                          run_name=topic,
                          prefix=topic)
    print('done generating tweets... ')
    # reset the session to prevent errors on loop
    gpt2.reset_session(sess=sess)

    with open('../data/generated_tweets.txt', 'r') as f:
        texts = f.read().split('====================')

    tweets = []
    for text in texts:
        for piece in text.split(' || '):
            if topic not in piece:  # ensure it contains the topic string
                continue
            # remove links
            cleaned = " ".join(word for word in piece.split(" ") if not has_prefix(word))
            # Not just the topic word, and within the 280-character limit. This
            # used to read `a > b & c <= 280`, which Python parses as a bitwise
            # AND inside a chained comparison.
            if len(topic) + 4 < len(cleaned) <= 280:
                tweets.append(cleaned)

    # every kept piece is already <= 280 characters
    return choice(tweets)
Ejemplo n.º 26
0
def generate_trending_tweet():
    """Fine-tune a per-topic run on tweets about a trending topic; return one generated tweet.

    Steps:
    1. Pick a trending topic at random.
    2. Save up to 1000 recent tweets on it to ``../data/<topic>.txt``, joined
       with ``' || '``.
    3. Train run ``<topic>`` briefly: 2 steps from scratch, or 1 step
       continuing an existing ``checkpoint/<topic>``.
    4. Generate five samples prefixed with the topic.
    5. From the ``' || '``-separated pieces that contain the topic, drop words
       flagged by ``filter_links`` and keep pieces longer than
       ``len(topic) + 4``.

    Returns:
        One kept piece chosen at random, capped at 280 characters.
    """
    topic = choice(get_trending())
    print("generating tweets on topic: " + topic)

    # Fetch first, so a failed fetch does not truncate the existing dataset file.
    dataset_path = "../data/" + topic + ".txt"
    corpus = " || ".join(get_topic_tweet(topic, 1000))
    with open(dataset_path, "w") as out:
        out.write(corpus)

    sess = gpt2.start_tf_sess()
    is_new_run = not os.path.exists("checkpoint/" + topic)
    gpt2.finetune(
        sess,
        dataset=dataset_path,
        model_name=model_name,
        steps=2 if is_new_run else 1,
        restore_from="fresh" if is_new_run else "latest",
        run_name=topic,
        print_every=1,
    )

    generated_path = "../data/generated_tweets.txt"
    gpt2.generate_to_file(
        sess,
        length=400,
        destination_path=generated_path,
        nsamples=5,
        run_name=topic,
        prefix=topic,
    )
    gpt2.reset_session(sess)

    with open(generated_path, "r") as src:
        samples = src.read().split("====================")

    candidates = []
    for sample in samples:
        for piece in sample.split(" || "):
            if topic not in piece:
                continue
            cleaned = " ".join(w for w in piece.split(" ") if not filter_links(w))
            if len(cleaned) > len(topic) + 4:
                candidates.append(cleaned)

    picked = choice(candidates)
    return picked[:280]