def main(args):
    """Load the trained model checkpoint (by default from /checkpoint/run1)
    and generate new text."""
    try:
        config_path = project_path + "/" + args.config
        input_data_path = project_path + "/" + args.input
        output_data_path = project_path + "/" + args.output
        config = load_config(config_path)

        # load data
        df = read_csv(input_data_path)
        lines = list(df['raw_line'])

        random.seed(config['generate']['random_seed'])
        sample_seeds = random.choices(lines, k=config['generate']['num'])

        sess = gpt2.start_tf_sess()
        gpt2.load_gpt2(sess)

        pred = []
        for i in sample_seeds:
            # Note: gpt2.generate() only returns text when
            # config['generate']['generator'] includes return_as_list=True;
            # otherwise it prints to stdout and returns None.
            out = gpt2.generate(sess, prefix=i, **config['generate']['generator'])
            pred.append(out)

        pred_df = pd.DataFrame(pred, columns=['raw_line'])
        save_csv(pred_df, output_data_path)
    except Exception as e:
        logger.error(
            "Unexpected error occurred when generating dialogues with gpt2: "
            + str(e))
async def homepage(request):
    global generate_count
    global sess

    if request.method == 'GET':
        params = request.query_params
    elif request.method == 'POST':
        params = await request.json()
    elif request.method == 'HEAD':
        return JSONResponse({'text': ''}, headers=response_header)

    print('+++++++++++++++')
    print(params)

    text = gpt2.generate(sess,
                         length=100,
                         temperature=float(params.get('temperature', 0.7)),
                         prefix=params.get('prefix', '')[:500],
                         return_as_list=True)[0]

    generate_count += 1
    if generate_count == 8:
        # Reload model to prevent Graph/Session from going OOM
        tf.reset_default_graph()
        sess.close()
        sess = gpt2.start_tf_sess(threads=1)
        gpt2.load_gpt2(sess)
        generate_count = 0

    gc.collect()
    text = re.split('\n', text)
    return JSONResponse({'text': text}, headers=response_header)
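# A minimal client-side sketch for the endpoint above, assuming the Starlette
# app is mounted at "/" and served locally on port 8080 (the host, port, and
# route are assumptions for illustration, not part of the original snippet).
# It exercises the GET query-parameter path of the handler.
import requests

resp = requests.get(
    "http://localhost:8080/",
    params={"prefix": "Once upon a time", "temperature": "0.8"},
)
print(resp.json()["text"])  # list of generated lines, split on '\n'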
def loader(game_name):
    print(Fore.GREEN)
    banner = pyfiglet.figlet_format("Loading...", font="slant")
    print(Style.BRIGHT + banner)
    print(Fore.RESET)

    sess = gpt2.start_tf_sess()
    gpt2.load_gpt2(sess)

    input1 = "I am Leo"
    stories = gpt2.generate(sess,
                            length=250,
                            temperature=0.7,
                            prefix=input1,
                            nsamples=5,
                            batch_size=5,
                            top_k=40,
                            return_as_list=True)
    print(Fore.RESET + Style.RESET_ALL)

    # Keep the fourth sample, dropping the trailing partial sentence.
    story = ""
    temp = stories[3].split(".")
    del temp[-1]
    for i in temp:
        story = story + i + '.'
    return str(story)
def generate_story():
    global sess
    # input text is request.form['input']
    try:
        tf.reset_default_graph()
        sess.close()
        sess = gpt2.start_tf_sess()
        gpt2.load_gpt2(sess, run_name=request.form['genre'])
        print("GENRE")
        print(request.form['genre'])
        generated_text = gpt2.generate(sess,
                                       run_name=request.form['genre'],
                                       length=200,
                                       temperature=0.8,
                                       prefix=str(request.form['input']),
                                       nsamples=1,
                                       batch_size=1,
                                       return_as_list=True)[0]
        return Response(response=generated_text, status=200)
    except Exception:
        traceback.print_exc(file=sys.stdout)
        print('aborting gen text')
        abort(404)
def main():
    args = parse_args()
    if not args.file:
        logger.error("No file entered. Use -f flag.")
        exit()
    filename = Path(args.file).stem

    logger.debug("Download model")
    gpt2.download_gpt2()

    logger.debug("Starting GPT-2 session")
    sess = gpt2.start_tf_sess()

    logger.debug("Finetuning model")
    gpt2.finetune(sess, args.file, steps=args.iteration)

    Path("Exports").mkdir(parents=False, exist_ok=True)

    logger.debug("Generating text")
    while True:
        generated_text = gpt2.generate(sess,
                                       return_as_list=True,
                                       temperature=args.temperature)[0]
        with open(f"Exports/{filename}_{args.temperature}_gpt2simple.txt", "a") as f:
            test_hour = datetime.datetime.now().strftime("%Y/%m/%d %H:%M")
            f.write(f"{test_hour}\n")
            # append the generated sample
            f.write(f"{generated_text}\n")
        logger.info("Runtime : %.2f seconds" % (time.time() - temps_debut))
def train(input_file):
    if os.path.exists('models/temp'):
        shutil.rmtree('models/temp')
    if not os.path.exists('models/124M'):
        download()

    sess = gpt2.start_tf_sess()
    model_name = '124M'
    model_dir = 'models/'
    training_dir = 'src/training_data/'
    file_name = input_file.split('.')[0]

    gpt2.finetune(sess,
                  training_dir + input_file,
                  model_name=model_name,
                  checkpoint_dir=model_dir + 'temp/',
                  run_name='',
                  steps=1)
    gpt2.reset_session(sess)

    if os.path.exists('models/latest'):
        shutil.rmtree('models/latest')
    shutil.copytree('models/temp', 'models/latest')
    # shutil.rmtree('models/temp')
def train_data(inputFile, outputDir):
    sess = gpt2.start_tf_sess()

    # Finetuning on the input file is currently disabled:
    # gpt2.finetune(sess,
    #               "resource/" + inputFile + ".txt",
    #               model_name=model_name,
    #               overwrite=True,
    #               steps=2)  # steps is max number of training steps

    # generate 10 examples, one output file per sample
    for x in range(0, 10):
        gpt2.load_gpt2(sess)
        gpt2.generate_to_file(sess,
                              destination_path="newoutputs/" + outputDir
                              + str(uuid.uuid4()) + ".txt")

    gpt2.reset_session(sess, threads=-1, server=None)
def get(self, context=''):
    run_name = 'run3'
    sess = gpt2.start_tf_sess()
    gpt2.load_gpt2(sess, run_name=run_name)
    results = gpt2.generate(sess,
                            run_name=run_name,
                            prefix=context,
                            nsamples=10,
                            length=200,
                            batch_size=10,
                            temperature=1,
                            top_k=40,
                            include_prefix=True,
                            return_as_list=True)

    all_tweets = []
    for result in results:
        subtweets = result.splitlines()
        all_tweets = list(set(all_tweets + subtweets))

    # Drop any generated tweet that already appears in the training data.
    with io.open('tweets_unseparated.txt', 'r', encoding="utf-8") as tweet_file:
        original_tweets = tweet_file.readlines()
    original_tweets = [x.strip() for x in original_tweets]
    all_tweets = list(set(all_tweets) - set(original_tweets))

    result = {'predicted_text': all_tweets}
    return jsonify(result)
def generator(data_1, data_2, data_3):
    game_name = data_1
    epoch = data_2
    model = data_3
    db = Database.readBlobData(game_name)
    file_name = "/content/app/data/data.txt"

    g = pyfiglet.figlet_format("Generating world...", font="slant")
    print(Fore.BLACK + Style.DIM)
    print(Fore.GREEN)
    print(Style.BRIGHT + g)
    print(Fore.BLACK + Style.DIM)

    sess = gpt2.start_tf_sess()
    sample = gpt2.finetune(sess,
                           dataset=file_name,
                           model_name=model,
                           steps=epoch,
                           restore_from='fresh',
                           run_name="run1",
                           print_every=1,
                           sample_every=epoch,
                           save_every=epoch)
    return sample
def fit(self,
        input_path,
        reset=True,
        overwrite=False,
        num_steps=1000,
        batch_size=1,
        print_every=10,
        sample_every=200,
        save_every=300,
        restore_from='fresh',
        run_name='reddit_comment_generator'):
    if reset:
        tf.reset_default_graph()
        self.tf_sess = gpt2.start_tf_sess()
    if overwrite and restore_from != 'latest':
        restore_from = 'latest'

    # Finetuning the model on new data
    gpt2.finetune(self.tf_sess,
                  dataset=input_path,
                  batch_size=batch_size,
                  model_name=self.model_type,
                  steps=num_steps,
                  restore_from=restore_from,
                  run_name=run_name,
                  print_every=print_every,
                  sample_every=sample_every,
                  save_every=save_every)
def generate_comments(self,
                      user_input,
                      bert_model_prediction,
                      length=200,
                      temperature=0.7,
                      num_samples=2,
                      batch_size=1,
                      top_k=0,
                      top_p=0,
                      run_name='reddit_comment_generator',
                      checkpoint_dir='./GPT2/checkpoint',
                      truncate_string=None):
    if not self.tf_sess:
        self.tf_sess = gpt2.start_tf_sess()

    # Generate samples
    subreddit_id = self.SubredditMapping[bert_model_prediction]
    prefix = '****S ' + subreddit_id + '\n' + user_input + '\n' + '****ES'
    comments = gpt2.generate(self.tf_sess,
                             length=length,
                             temperature=temperature,
                             prefix=prefix,
                             nsamples=num_samples,
                             batch_size=batch_size,
                             run_name=run_name,
                             top_k=top_k,
                             top_p=top_p,
                             return_as_list=True,
                             checkpoint_dir=checkpoint_dir,
                             truncate=truncate_string)

    index = 0
    shuffle(self.Names)
    ans = ''
    for text in comments:
        text = text.split('\n')
        L = len(text)
        i = 0
        # Skip ahead to the first comment marker; the bounds check guards
        # against samples that contain no '****TC' marker at all.
        while i < L and '****TC' not in text[i]:
            text[i] = ''
            i += 1
        start = i
        while i < L and '****S' not in text[i]:
            if '****TC' in text[i]:
                text[i] = '<strong>' + str(self.Names[index]) + '</strong>'
                index += 1
            elif '****ETC' in text[i]:
                text[i] = ''
            i += 1
        text = text[start:i]
        text = '\n'.join(text)
        if not ans:
            ans = text
        else:
            ans = ans + '\n\n' + text
    return ans
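# A hedged usage sketch for the fit()/generate_comments() pair above. The
# class name RedditCommentGenerator and its constructor are assumptions
# inferred from the method bodies (which reference self.model_type,
# self.SubredditMapping, and self.Names); they are not confirmed by the
# original snippet.
generator = RedditCommentGenerator()              # hypothetical constructor
generator.fit('data/comments.txt',                # hypothetical dataset path
              num_steps=100,
              run_name='reddit_comment_generator')
html = generator.generate_comments(
    user_input='What game should I play next?',
    bert_model_prediction=3,                      # key into SubredditMapping
)
print(html)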
def main():
    helix = twitch.Helix('', use_cache=True)
    global lastmsg, msg
    lastmsg = datetime.datetime.now()
    msg = queue.Queue(100)

    sess = gpt2.start_tf_sess()
    gpt2.load_gpt2(sess, run_name=run_name)

    # Pre-fill the message queue before connecting to chat.
    while not msg.full():
        newmsg = genmsg(sess)
        print(newmsg)
        msg.put(newmsg)

    for channel in textchannels:
        chat = twitch.Chat(channel="#" + channel,
                           nickname='WoodenLongboard',
                           oauth="",
                           helix=helix)
        chats[channel] = chat
        chats[channel].subscribe(handle_message)
    print("Finished init")

    # Keep the queue topped up.
    while True:
        if not msg.full():
            msg.put(genmsg(sess))
def prepare_fine_tuning(self, file_name: str):
    """prepare_fine_tuning: Configure the model and fine-tune it on our dataset.

    Args:
        file_name (str): Name of the input file.
    """
    if not os.path.isdir(os.path.join("models", self.model_name)):
        print(f"Downloading {self.model_name} model...")
        gpt2.download_gpt2(
            model_name=self.model_name
        )  # model is saved into current directory under /models/124M/

    sess = gpt2.start_tf_sess()
    gpt2.finetune(
        sess,
        dataset=file_name,
        model_name=self.model_name,
        steps=1000,
        restore_from="fresh",
        run_name=self.run_name,
        print_every=10,
        sample_every=200,
        save_every=500,
    )
def generate():
    first_line = request.args['firstLine']
    first_line = '<|startoftext|> ' + first_line.lower()

    sess = gpt2.start_tf_sess(threads=1)
    gpt2.load_gpt2(sess, run_name="run1", checkpoint_dir="checkpoint")

    # Retry until the model produces more than just the prefix.
    output = ['']
    while len(output[0]) <= len(first_line) + 30:
        output = gpt2.generate(sess,
                               run_name='run1',
                               checkpoint_dir='checkpoint',
                               model_dir='models',
                               sample_dir='samples',
                               return_as_list=True,
                               length=120,
                               temperature=0.7,
                               prefix=first_line,
                               truncate="<|endoftext|>",
                               include_prefix=True)

    tf.reset_default_graph()
    sess.close()
    gc.collect()

    data = output[0].replace('<|startoftext|> ', '')
    return json.dumps({"data": data})
def __init__(self, source, num_words, prompt="DEFAULT", temp=0.7):
    # separating blocks into sentence tokens
    nltk.download('punkt')
    self.tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')

    # where the training data is stored
    self.source = source
    # deviation from original dataset
    self.temperature = temp
    # length of output
    self.num_words = num_words
    # user input
    self.prompt = prompt
    if self.prompt == "DEFAULT":
        self.prompt = "The quick brown fox jumped over the lazy dog."

    files = os.listdir(os.getcwd())
    if "models" not in files:
        # first-time runthrough
        self.setupModel()
        print('Setup Complete')
    if "checkpoint" not in files:
        # story generator, give parameters if necessary
        self.trainGenerator()
        print('Training Complete')

    tf.reset_default_graph()
    self.session = gpt2.start_tf_sess()
    gpt2.load_gpt2(self.session, run_name='run1')
    print('Done')
async def _generate_samples(self, model: typing.Optional[str] = None, max_size=31):
    if model is None:
        return
    cog_data_path = data_manager.cog_data_path(self)
    model_path: pathlib.Path = cog_data_path / "models" / model
    if not model_path.exists():
        log.error(f"Model {model} not found in {str(cog_data_path)}")
        return

    tf_session = gpt_2_simple.start_tf_sess()
    gpt_2_simple.load_gpt2(
        tf_session,
        checkpoint_dir=str(cog_data_path / "checkpoints"),
        model_name=model,
        model_dir=str(model_path.parent),
    )

    while True:
        new_sample = gpt_2_simple.generate(
            tf_session,
            return_as_list=True,
            truncate="<|endoftext|>",
            temperature=1.0,
        )[0]
        async with self.full:
            if len(self.samples) >= max_size:
                log.info("Cache full, waiting for next command")
                await self.full.wait()
            self.samples.append(new_sample)
        async with self.empty:
            if len(self.samples) == 1:
                self.empty.notify()
def on_status(self, tweet):
    print("Received status.")
    if self.is_mentioned(tweet):
        username = tweet.user.screen_name
        text = str(tweet.text)
        # strip our own handle from the text before generating a prediction
        text_without_self_username = text.replace("@sarcastic_trump", "")
        try:
            generated_tweet_from_text_as_prefix = self.generate_gpt2_tweet_using_prefix(
                prefix=text_without_self_username)
        except Exception:
            # The session can go stale; rebuild it and retry once.
            tf.reset_default_graph()
            self.sess = gpt2.start_tf_sess()
            gpt2.load_gpt2(self.sess, run_name='trump_clean_small')
            generated_tweet_from_text_as_prefix = self.generate_gpt2_tweet_using_prefix(
                prefix=text_without_self_username)

        # remove the extra lines without punctuation (may also remove hashtags)
        tweet_without_extra_lines = self.remove_extra_lines(
            generated_tweet_from_text_as_prefix[0])
        print(f"{username}: {text}")

        if len(generated_tweet_from_text_as_prefix[0]) + len(username) + 5 > 240:
            self.api.update_status(
                f"Hey @{username}, {tweet_without_extra_lines[0:240-len(username)-5]}",
                in_reply_to_status_id=tweet.id)
        else:
            self.api.update_status(
                f"Hey @{username}, {tweet_without_extra_lines}",
                in_reply_to_status_id=tweet.id)
def finetune(self, corpus, return_text=True):
    """Fine-tune the model and optionally return a generated text sample.

    Parameters
    ----------
    corpus : object
        Custom dataset text file.
    return_text : bool, default True
        Toggles whether to return custom-generated text in an array
        after fine-tuning.

    Returns
    -------
    Generated string in an array.
    """
    sess = gpt2.start_tf_sess()
    gpt2.finetune(sess,
                  corpus,
                  model_name=self.model_name,
                  steps=1000)  # steps is max number of training steps

    if return_text:
        text = gpt2.generate(sess, return_as_list=True)
        return text
    else:
        gpt2.generate(sess)
async def generate(input: str = "", auth: str = ""):
    global sess, generate_count
    if auth != AUTH_KEY:
        return "Invalid auth token provided"

    result = gpt2.generate(
        sess,
        run_name="run1",
        length=300,
        temperature=0.9,
        prefix=input,
        top_p=100,
        nsamples=1,
        batch_size=1,
        include_prefix=False,
        return_as_list=True,
    )[0]

    generate_count += 1
    if generate_count == 12:
        # Reload model to prevent Graph/Session from going OOM
        tf.reset_default_graph()
        sess.close()
        sess = gpt2.start_tf_sess(threads=8)
        gpt2.load_gpt2(sess, run_name="run1")
        generate_count = 0
    return result
def generate(prefix, input_file, similarity_threshold, nsamples, length,
             temperature, k):
    # load the quotes used for fine-tuning
    with open(input_file, 'r') as f:
        originals = f.readlines()
    original_quotes = [
        originals[i].strip() for i in range(1, len(originals), 3)
    ]

    # generate a batch of quotes
    sess = gpt2s.start_tf_sess()
    gpt2s.load_gpt2(sess)
    samples = gpt2s.generate(sess,
                             nsamples=nsamples,
                             length=length,
                             temperature=temperature,
                             top_k=k,
                             prefix=prefix + '\n',
                             return_as_list=True)

    # filter the samples: keep only quotes that are long enough and novel
    # relative to the fine-tuning data
    quotes = []
    for s in samples:
        title, body = s.split('\n')[:2]
        is_long = len(body.split(' ')) > 3
        is_novel = all(
            similar(body, x) < similarity_threshold for x in original_quotes)
        if is_long and is_novel:
            quotes.append(body)
    return quotes
def gpt2_finetune(hparams):
    info_print("Model finetuning, please wait. (Press Ctrl+C to exit early)")
    sess = gpt2.start_tf_sess()

    # input check
    if not os.path.exists(
            os.path.join(hparams.gpt2_model_dir, hparams.gpt2_model_name)):
        raise FileNotFoundError(
            "The specified gpt2 pretrained model doesn't exist, "
            "please restore the default params.")

    # clear checkpoint dir
    model_path = os.path.join(hparams.finetuned_model_dir,
                              hparams.finetuned_model_name)
    if os.path.exists(model_path):
        shutil.rmtree(model_path)

    gpt2.finetune(sess=sess,
                  dataset=hparams.data_path,
                  model_dir=hparams.gpt2_model_dir,
                  model_name=hparams.gpt2_model_name,
                  checkpoint_dir=hparams.finetuned_model_dir,
                  run_name=hparams.finetuned_model_name,
                  multi_gpu=hparams.multi_gpu,
                  steps=hparams.steps)
def start():
    print("Starting")
    start_time = datetime.datetime.now()

    sess = gpt2.start_tf_sess()
    gpt2.load_gpt2(sess, model_name=model_name)

    text = gpt2.generate(
        sess,
        model_name=model_name,
        prefix="In a shocking finding, scientist discovered a herd of unicorns "
               "living in a remote, previously unexplored valley, in the Andes "
               "Mountains. Even more surprising to the researchers was the fact "
               "that the unicorns spoke perfect English.",
        length=100,
        temperature=0.7,
        top_p=0.9,
        return_as_list=True)

    total_time = datetime.datetime.now() - start_time
    print("Total time required is = ", total_time)
    print(text)
    return " ".join(text)
def finetune(
    model_name: str,
    text_path: str,
    num_steps: int,
    sample_length: int,
    save_every: Optional[int],
) -> None:
    # Download the model if it is not present
    if not os.path.isdir(os.path.join("models", model_name)):
        print(f"Downloading {model_name} model...")
        gpt2.download_gpt2(model_name=model_name)

    sess = gpt2.start_tf_sess()
    if save_every is None:
        save_every = num_steps // 4
    gpt2.finetune(
        sess,
        text_path,
        model_name=model_name,
        steps=num_steps,
        sample_length=sample_length,
        save_every=save_every,
    )  # steps is max number of training steps
    gpt2.generate(sess)
def main():
    if len(sys.argv) < 4:
        print('Usage: python run_generator.py RUN_NAME SUBREDDIT NO_SAMPLES (TEMPERATURE)')
        return
    run_name = sys.argv[1]
    subreddit = sys.argv[2]
    try:
        no_samples = int(sys.argv[3])
    except Exception as e:
        print(e)
        print('Third argument should be an integer')
        return
    temperature = 1
    if len(sys.argv) >= 5:
        try:
            temperature = float(sys.argv[4])
        except Exception as e:
            print(e)
            print('Fourth argument should be a float')
            return

    update_checkpoint(run_name)
    sess = gpt2.start_tf_sess()
    gpt2.load_gpt2(sess, run_name=run_name, checkpoint_dir='generator_models')
    generate_to_file(sess, run_name, subreddit, n=no_samples, temp=temperature)
    print('Done.')
async def homepage(request):
    global generate_count
    global sess

    if request.method == 'GET':
        params = request.query_params
    elif request.method == 'POST':
        params = await request.json()
    elif request.method == 'HEAD':
        return UJSONResponse({'text': ''}, headers=response_header)

    text = gpt2.generate(sess,
                         length=int(params.get('length', 1023)),
                         temperature=float(params.get('temperature', 0.7)),
                         top_k=int(params.get('top_k', 0)),
                         top_p=float(params.get('top_p', 0)),
                         prefix=params.get('prefix', '')[:500],
                         truncate=params.get('truncate', None),
                         include_prefix=str(params.get(
                             'include_prefix', True)).lower() == 'true',
                         return_as_list=True)[0]

    generate_count += 1
    if generate_count == 8:
        # Reload model to prevent Graph/Session from going OOM
        tf.reset_default_graph()
        sess.close()
        sess = gpt2.start_tf_sess(threads=1)
        gpt2.load_gpt2(sess)
        generate_count = 0

    gc.collect()
    return UJSONResponse({'text': text}, headers=response_header)
def generate_ideas():
    prefix = request.args.get("prefix")
    length = int(request.args.get("length", 50))
    samples = int(request.args.get("samples", 1))
    if samples <= 0 or samples > 5:
        # abort() needs a Response to carry a JSON body with a 400 status;
        # requires flask.make_response alongside jsonify and abort.
        abort(make_response(
            jsonify(message="Samples value is invalid, min 1 and max 5 allowed."),
            400,
        ))

    session = gpt2.start_tf_sess()
    gpt2.load_gpt2(session, model_name=model_name)
    # return_as_list=True is required here; otherwise generate() prints to
    # stdout and returns None, and the response body would be null.
    ideas = gpt2.generate(
        session,
        model_name=model_name,
        prefix=prefix,
        length=length,
        nsamples=samples,
        batch_size=samples,
        return_as_list=True,
    )
    session.close()
    return jsonify(ideas=ideas)
def __init__(self, group=None, target=None, name=None,
             args=(), kwargs=None, verbose=None):
    super(GeneratorThread, self).__init__()
    self.target = target
    self.name = name
    self.last_model = ''
    self.sess = gpt2.start_tf_sess()
def load_model(run):
    # Get our path for the checkpoint setup
    checkpoint_dir = Path("checkpoint").absolute()

    # Start tensorflow session & load our model
    sess = gpt2.start_tf_sess()
    gpt2.load_gpt2(sess, checkpoint_dir=checkpoint_dir, run_name=run)
    return sess
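# A minimal usage sketch for load_model() above, assuming a fine-tuned run
# named "run1" already exists under ./checkpoint (the run name and prefix
# text are assumptions for illustration, not part of the original snippet).
sess = load_model("run1")
text = gpt2.generate(sess,
                     run_name="run1",
                     prefix="Hello world",
                     length=50,
                     return_as_list=True)[0]
print(text)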
def loader():
    # A cached variant, kept for reference:
    # if 'sess' not in cache:
    #     cache['sess'] = gpt2.start_tf_sess()
    #     gpt2.load_gpt2(cache['sess'], checkpoint_dir='./gpt_2/checkpoint', run_name='run1')
    # return cache['sess']
    sess = gpt2.start_tf_sess()
    gpt2.load_gpt2(sess,
                   checkpoint_dir='../assets/gpt_2/checkpoint',
                   run_name='run1')
    return sess
def main():
    sess = gpt2.start_tf_sess()
    gpt2.load_gpt2(sess)
    single_text = gpt2.generate(sess, return_as_list=True)[0]
    print(single_text)