def generate(model, inputs):
    """Generate a text continuation for a prompt using the Grover model."""
    text = inputs['prompt']

    # Tokenize the raw prompt text
    encoded = _tokenize_article_pieces(encoder, text)

    # Use everything except the final end token as the generation context
    context_formatted = []
    context_formatted.extend(encoded[:-1])

    # Token indices we definitely don't want the model to predict
    # (computed here for parity with the batch script, but not fed to the
    # graph in this simplified function)
    ignore_ids_np = np.array(encoder.special_tokens_onehot)
    ignore_ids_np[encoder.endoftext] = 0

    gens = []
    gens_raw = []
    gen_probs = []

    # Run the context through the TensorFlow model once per chunk and
    # collect the generated output
    for chunk_i in range(num_chunks):
        tokens_out, probs_out = sess.run(
            [model['tokens'], model['probs']],
            feed_dict={
                model['initial_context']: [context_formatted] * batch_size_per_chunk,
                model['eos_token']: 60000,
                model['p_for_topp']: top_p[chunk_i],
            })
        for t_i, p_i in zip(tokens_out, probs_out):
            extraction = extract_generated_target(output_tokens=t_i,
                                                  encoder=encoder,
                                                  target=args.target)
            gens.append(extraction['extraction'])

    # Return the first generated sample as plain text
    return gens[0]
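For context, here is a minimal sketch of how generate() might be wired up. The model dictionary keys and the module-level globals it relies on (encoder, sess, num_chunks, batch_size_per_chunk, top_p, and args) are assumptions inferred from how the function uses them, not a fixed Grover API:

# Hypothetical driver code: the model dict simply bundles the placeholders
# and output tensors built in the setup listings further down.
model = {
    'tokens': tokens,                    # sampled-token output tensor
    'probs': probs,                      # per-token probability tensor
    'initial_context': initial_context,  # placeholder for the context tokens
    'eos_token': eos_token,              # placeholder for the stop token id
    'p_for_topp': p_for_topp,            # placeholder for the top-p cutoff
}

generated_text = generate(model, {'prompt': 'Scientists announced today that'})
print(generated_text)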
def generate_article_attribute(sess, encoder, tokens, probs, article, target='article'):
    """
    Given attributes about an article (title, author, etc.), use that context
    to generate a replacement for one of those attributes using the Grover
    model.

    This function is based on the Grover examples distributed with the
    Grover code.
    """
    # Tokenize the raw article text
    article_pieces = _tokenize_article_pieces(encoder, article)

    # Grab the article elements the model cares about - domain, date, title, etc.
    context_formatted = []
    for key in ['domain', 'date', 'authors', 'title', 'article']:
        if key != target:
            context_formatted.extend(article_pieces.pop(key, []))

    # Start formatting the tokens in the way the model expects them, starting with
    # which article attribute we want to generate.
    context_formatted.append(encoder.__dict__['begin_{}'.format(target)])

    # Tell the model which special tokens (such as the end token) aren't part of the text
    ignore_ids_np = np.array(encoder.special_tokens_onehot)
    ignore_ids_np[encoder.__dict__['end_{}'.format(target)]] = 0

    # We are only going to generate one article attribute with a fixed
    # top_ps cut-off of 95%. This simple example isn't processing in batches.
    gens = []
    article['top_ps'] = [0.95]

    # Run the input through the TensorFlow model and grab the generated output
    tokens_out, probs_out = sess.run(
        [tokens, probs],
        feed_dict={
            # Pass real values for the inputs that the
            # model needs to be able to run.
            initial_context: [context_formatted],
            eos_token: encoder.__dict__['end_{}'.format(target)],
            ignore_ids: ignore_ids_np,
            p_for_topp: np.array([0.95]),
        }
    )

    # The model is done! Grab the results it generated and format the
    # results into normal text.
    for t_i, p_i in zip(tokens_out, probs_out):
        extraction = extract_generated_target(output_tokens=t_i,
                                              encoder=encoder,
                                              target=target)
        gens.append(extraction['extraction'])

    # Return the generated text.
    return gens[-1]
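A short usage sketch for generate_article_attribute(), assuming the article is a metadata dictionary in the shape that _tokenize_article_pieces() expects; every field value below is an invented placeholder:

# Hypothetical input article; all values here are made up for illustration.
article = {
    'domain': 'example.com',
    'date': '04-19-2019',
    'authors': 'Jane Doe',
    'title': 'Local Team Wins Championship',
    'article': 'The full body text of the article goes here.',
}

# Ask the model to generate a replacement title from the other attributes.
new_title = generate_article_attribute(sess, encoder, tokens, probs,
                                       article, target='title')
print(new_title)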
initial_context = tf.placeholder(tf.int32, [batch_size_per_chunk, None])
p_for_topp = tf.placeholder(tf.float32, [batch_size_per_chunk])
eos_token = tf.placeholder(tf.int32, [])
tokens, probs = sample(news_config=news_config,
                       initial_context=initial_context,
                       eos_token=eos_token,
                       ignore_ids=None,
                       p_for_topp=p_for_topp,
                       do_topk=False)

saver = tf.train.Saver()
saver.restore(sess, args.model_ckpt)
print('Loaded model.')

text = input()
while text != "":
    for i in range(args.samples):
        print("Sample", i + 1, "of", args.samples)

        # Let's go!
        encoded = _tokenize_article_pieces(encoder, text)
        context_formatted = []
        # for key in ['domain', 'date', 'authors', 'title', 'article']:
        #     if key != args.target:
        #         context_formatted.extend(article_pieces.pop(key, []))
        # context_formatted.append(encoder.__dict__['begin_{}'.format(args.target)])
        context_formatted.extend(encoded[:-1])

        # Format context end

        # Indices we definitely DON'T want to predict
        ignore_ids_np = np.array(encoder.special_tokens_onehot)
        ignore_ids_np[encoder.endoftext] = 0

        gens = []
        gens_raw = []
        gen_probs = []
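The interactive listing above stops just before the sampling step. Here is a hedged sketch of how that loop presumably continues, modeled on generate() above; the use of encoder.endoftext as the stop token and a fixed 0.95 top-p cutoff are assumptions, and in the script this code would sit inside the sampling loop:

# Assumed continuation: run one sampling pass per context and print the result.
tokens_out, probs_out = sess.run(
    [tokens, probs],
    feed_dict={
        initial_context: [context_formatted] * batch_size_per_chunk,
        eos_token: encoder.endoftext,
        p_for_topp: np.array([0.95] * batch_size_per_chunk),
    })
for t_i, p_i in zip(tokens_out, probs_out):
    extraction = extract_generated_target(output_tokens=t_i,
                                          encoder=encoder,
                                          target='article')
    gens.append(extraction['extraction'])
print(gens[-1])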
p_for_topp = tf.placeholder(tf.float32, [batch_size_per_chunk])
eos_token = tf.placeholder(tf.int32, [])
ignore_ids = tf.placeholder(tf.bool, [news_config.vocab_size])
tokens, probs = sample(news_config=news_config,
                       initial_context=initial_context,
                       eos_token=eos_token,
                       ignore_ids=ignore_ids,
                       p_for_topp=p_for_topp,
                       do_topk=False)

saver = tf.train.Saver()
saver.restore(sess, args.model_ckpt)

# Let's go!
for i, article in enumerate(tqdm(articles)):
    article_pieces = _tokenize_article_pieces(encoder, article)
    context_formatted = []
    for key in ['domain', 'date', 'authors', 'title', 'article']:
        if key != args.target:
            context_formatted.extend(article_pieces.pop(key, []))
    context_formatted.append(encoder.__dict__['begin_{}'.format(args.target)])

    # Format context end

    # Indices we definitely DON'T want to predict
    ignore_ids_np = np.array(encoder.special_tokens_onehot)
    ignore_ids_np[encoder.__dict__['end_{}'.format(args.target)]] = 0

    gens = []
    gens_raw = []
    gen_probs = []
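As with the interactive version, this batch listing is truncated before the sampling loop. A hedged sketch of the chunked generation that presumably follows inside the article loop; num_chunks and the per-chunk top_p schedule are assumptions carried over from generate() above:

# Assumed continuation: sample batch_size_per_chunk outputs per chunk and
# collect the extracted text for this article.
for chunk_i in range(num_chunks):
    tokens_out, probs_out = sess.run(
        [tokens, probs],
        feed_dict={
            initial_context: [context_formatted] * batch_size_per_chunk,
            eos_token: encoder.__dict__['end_{}'.format(args.target)],
            ignore_ids: ignore_ids_np,
            p_for_topp: np.array([top_p[chunk_i]] * batch_size_per_chunk),
        })
    for t_i, p_i in zip(tokens_out, probs_out):
        extraction = extract_generated_target(output_tokens=t_i,
                                              encoder=encoder,
                                              target=args.target)
        gens.append(extraction['extraction'])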