def main():
    """Fine-tune a GPT-2 model using the options from the module-level ``parser``.

    The function builds the training graph (loss, optional validation loss,
    sampling op and optimizer) inside a ``tflex.Session``, restores a
    checkpoint, loads the training (and optionally validation) data and then
    loops: periodically saving, sampling and validating, and running one
    optimizer step per iteration.  The loop ends when ``tflex.should_quit()``
    is true or on Ctrl-C (unless ``args.debug_on_ctrlc`` is set).

    Several inner functions are registered with ``tflex.register_command`` so
    they can be triggered while training is running.
    """
    args = parser.parse_args()
    enc = encoder.get_encoder(args.model_name)
    hparams = model.default_hparams()
    hparams.res_dropout = args.dropout
    hparams.attn_dropout = args.dropout

    # `epsilon` is forwarded to sample.sample_sequence below; the reduced
    # precision branches use a much smaller magnitude.
    epsilon = -1e10
    if args.dtype == 'float32':
        hparams.dtype = tf.float32
    elif args.dtype == 'float16':
        hparams.dtype = tf.float16
        epsilon = -65500
    elif args.dtype == 'bfloat16':
        hparams.dtype = tf.bfloat16
        epsilon = -65500
    else:
        # Best effort: keep the default dtype and carry on.
        print('Unknown dtype', args.dtype)
    if args.float16:
        hparams.dtype = tf.bfloat16
        epsilon = -65500

    with open(os.path.join('models', args.model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))
    # Non-negative command-line values take precedence over hparams.json.
    if args.n_ctx >= 0:
        hparams.n_ctx = args.n_ctx
    if args.n_embd >= 0:
        hparams.n_embd = args.n_embd
    if args.n_head >= 0:
        hparams.n_head = args.n_head
    if args.n_layer >= 0:
        hparams.n_layer = args.n_layer

    if args.sample_length < 0:
        args.sample_length = hparams.n_ctx - 1
    if args.sample_length > hparams.n_ctx:
        raise ValueError(
            "Can't get samples longer than window size: %s" % hparams.n_ctx)
    if args.sample_ctx < 0:
        args.sample_ctx = hparams.n_ctx

    if args.model_name == '345M':
        args.memory_saving_gradients = True
        if args.optimizer == 'adam':
            args.only_train_transformer_layers = True

    config = tf.ConfigProto()
    if args.allow_growth:
        config.gpu_options.allow_growth = True
    if args.disable_layout_optimizer:
        config.graph_options.rewrite_options.layout_optimizer = rewriter_config_pb2.RewriterConfig.OFF

    with tflex.Session(config=config, init_tpu=args.init_tpu) as sess:
        # ---- Graph construction -------------------------------------------
        context = tf.placeholder(tf.int32, [args.batch_size, None])
        context_in = randomize(context, hparams, args.noise)
        output = model.model(hparams=hparams, X=context_in)
        # Next-token loss: logits at position t are scored against token t+1.
        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=context[:, 1:], logits=output['logits'][:, :-1]))

        if args.val_every > 0:
            val_context = tf.placeholder(tf.int32, [args.val_batch_size, None])
            val_output = model.model(hparams=hparams, X=val_context)
            val_loss = tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=val_context[:, 1:], logits=val_output['logits'][:, :-1]))
            val_loss_summary = tf.summary.scalar('val_loss', val_loss)

        tf_sample = sample.sample_sequence(
            hparams=hparams,
            length=args.sample_length,
            context=context,
            batch_size=args.batch_size,
            temperature=1.0,
            top_k=args.top_k,
            top_p=args.top_p,
            epsilon=epsilon)

        all_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
        train_vars = [v for v in all_vars if '/h' in v.name] if args.only_train_transformer_layers else all_vars

        parameter_count = sum([np.prod(v.shape.as_list()) for v in train_vars])
        print("This model is using %d parameters (%.2fM)" % (parameter_count, parameter_count/(1024.0*1024.0)))

        with tf.variable_scope(tf.get_variable_scope().name, reuse=tf.AUTO_REUSE):
            global_step = tflex.get_variable('global_step') or tf.get_variable('global_step', shape=(), dtype=tf.int32, trainable=False)
            current_step = args.learning_rate_initial_step
            global_step.load(current_step, session=sess)
            if args.learning_rate_cos:
                lr = tflex_sgdr.sgdr_decay_with_warmup(
                    args.learning_rate, global_step,
                    warmup_steps=args.learning_rate_warmup,
                    initial_period_steps=args.learning_rate_period,
                    learning_rate_min=args.learning_rate_min)
            else:
                lr = tflex.get_variable('learn_rate') or tf.get_variable('learn_rate', shape=(), dtype=tf.float32, trainable=False)
                lr.load(args.learning_rate, session=sess)

        def update_lr(rate=None, step=None):
            """Push the current learning rate into the graph and return it.

            With the cosine schedule the rate is a function of global_step, so
            nothing is loaded and the evaluated schedule value is returned.
            """
            if not args.learning_rate_cos:
                if step is None:
                    step = global_step.eval(session=sess)
                if rate is None:
                    rate = args.learning_rate
                if callable(rate):
                    rate = rate(step)
                lr.load(rate, session=sess)
            return lr.eval(session=sess)

        @tflex.register_command
        def set_learning_rate():
            """Prompt on stdin for a new learning rate and store it in args."""
            print("Current learn rate: %0.8f" % update_lr())
            print("New learn rate?")
            rate = input('')
            if not rate:
                print("Empty input; not changing anything.")
                return
            try:
                rate = float(rate)
            except ValueError:
                # Previously execution fell through and crashed while
                # formatting the unparsed string with %0.8f.
                print("Invalid input; must be a float")
                return
            print("Setting learn rate to %0.8f" % rate)
            args.learning_rate = rate

        if args.optimizer == 'adam':
            opt = tf.train.AdamOptimizer(learning_rate=lr)
        elif args.optimizer == 'sgd':
            opt = tf.train.GradientDescentOptimizer(learning_rate=lr)
        elif args.optimizer == 'ada':
            import tensor2tensor.utils.optimize
            from tensor2tensor.utils import hparam
            import tensor2tensor.models.research
            from tensor2tensor.utils import registry
            ada_hparams = registry.hparams('afx_mimic_adam')
            ada_hparams.optimizer_adafactor_beta1 = 0.0
            ada_hparams.optimizer_adafactor_factored = True
            opt = tensor2tensor.utils.optimize.adafactor(learning_rate=lr, hparams=ada_hparams)
        else:
            # exit() accepts a single argument; passing two raised TypeError
            # instead of reporting the bad optimizer name.
            exit('Bad optimizer: %s' % args.optimizer)

        #if tpu_addr:
        #    # https://pulsejet.github.io/blog/posts/tpu-without-estimator/
        #    from tensorflow.contrib.tpu.python.tpu import tpu_function
        #    tpu_function.get_tpu_context().set_number_of_shards(8)
        #    opt = tf.contrib.tpu.CrossShardOptimizer(opt)

        if args.accumulate_gradients > 1:
            if args.memory_saving_gradients:
                exit("Memory saving gradients are not implemented for gradient accumulation yet.")
            opt = AccumulatingOptimizer(
                opt=opt,
                var_list=train_vars)
            opt_reset = opt.reset()
            opt_compute = opt.compute_gradients(loss)
            opt_apply = opt.apply_gradients()
            summary_loss = tf.summary.scalar('loss', opt_apply)
        else:
            if args.memory_saving_gradients:
                opt_grads = memory_saving_gradients.gradients(loss, train_vars)
            else:
                opt_grads = tf.gradients(loss, train_vars)
            opt_grads = list(zip(opt_grads, train_vars))
            opt_apply = opt.apply_gradients(opt_grads)
            summary_loss = tf.summary.scalar('loss', loss)

        summary_lr = tf.summary.scalar('learning_rate', lr)
        summaries = tf.summary.merge([summary_lr, summary_loss])

        summary_log = tf.summary.FileWriter(
            os.path.join(CHECKPOINT_DIR, args.run_name))
        if args.save_graph:
            summary_log.add_graph(tf.get_default_graph())

        saver = tflex.Saver(
            var_list=all_vars,
            max_to_keep=args.max_to_keep,
            keep_checkpoint_every_n_hours=100000,
            reshape=args.truncate_weights)
        sess.run(tf.global_variables_initializer())

        # ---- Checkpoint restore -------------------------------------------
        if args.restore_from == 'latest':
            ckpt = tflex.latest_checkpoint(
                os.path.join(CHECKPOINT_DIR, args.run_name))
            if ckpt is None:
                # Get fresh GPT weights if new run.
                ckpt = tflex.latest_checkpoint(
                    os.path.join('models', args.model_name))
        elif args.restore_from == 'fresh':
            ckpt = tflex.latest_checkpoint(
                os.path.join('models', args.model_name))
        else:
            ckpt = tflex.latest_checkpoint(args.restore_from)
        print('Loading snapshot %s...' % ckpt)
        t0 = time.time()
        if not args.fresh_model:
            saver.restore(sess, ckpt)
        t1 = time.time()
        print('Loaded in %f seconds' % (t1 - t0))

        # ---- Data ---------------------------------------------------------
        def make_sampler(dataset, enc, seed, combine):
            """Return a token sampler for a directory/.npz dataset or a text file."""
            if os.path.isdir(dataset) or dataset.endswith('.npz'):
                chunks = load_dataset(enc, dataset, combine)
                data_sampler = Sampler(chunks, seed=seed)
                print('dataset has', data_sampler.total_size, 'tokens', len(chunks), 'chunks')
            else:
                data_sampler = TextSampler(dataset, enc, seed=seed)
            return data_sampler

        print('Loading dataset...')
        seed = None if args.seed < 0 else args.seed
        data_sampler = make_sampler(dataset=args.dataset, enc=enc, seed=seed, combine=args.combine)
        if args.val_every > 0:
            # Sample from validation set once with fixed seed to make
            # it deterministic during training as well as across runs.
            val_dataset = args.val_dataset if args.val_dataset else args.dataset
            val_data_sampler = make_sampler(dataset=val_dataset, enc=enc, seed=1, combine=args.combine)
            val_batches = [[val_data_sampler.sample(hparams.n_ctx) for _ in range(args.val_batch_size)]
                           for _ in range(args.val_batch_count)]

        print('Training...')
        counter = 1
        counter_path = os.path.join(CHECKPOINT_DIR, args.run_name, 'counter')
        if os.path.exists(counter_path):
            # Load the step number if we're resuming a run
            # Add 1 so we don't immediately try to save again
            with open(counter_path, 'r') as fp:
                counter = int(fp.read()) + 1

        # ---- Interactive commands and helpers -----------------------------
        @tflex.register_command
        def get_tarfile_name(checkpoint_folder):
            """Converts a folder path into a filename for a .tar archive"""
            tarfile_name = checkpoint_folder.replace(os.path.sep, '_') + '.tar'
            return tarfile_name

        def copy_checkpoint_to_gdrive(run_name='run1', copy_folder=False):
            """Copies the checkpoint folder to a mounted Google Drive."""
            #is_mounted()
            checkpoint_folder = os.path.join('checkpoint', run_name)

            if copy_folder:
                shutil.copytree(checkpoint_folder, "/content/drive/My Drive/" + checkpoint_folder)
            else:
                file_path = get_tarfile_name(checkpoint_folder)

                # Reference: https://stackoverflow.com/a/17081026
                with tarfile.open(file_path, 'w') as tar:
                    tar.add(checkpoint_folder)

                shutil.copyfile(file_path, "/content/drive/My Drive/" + file_path)

        @tflex.register_command
        def save():
            """Write a checkpoint for the current counter and record the counter."""
            maketree(os.path.join(CHECKPOINT_DIR, args.run_name))
            print(
                'Saving',
                os.path.join(CHECKPOINT_DIR, args.run_name,
                             'model-{}').format(counter))
            t0 = time.time()
            saver.save(
                sess,
                os.path.join(CHECKPOINT_DIR, args.run_name, 'model'),
                global_step=counter)
            t1 = time.time()
            print('Saved in %f seconds' % (t1 - t0))
            with open(counter_path, 'w') as fp:
                fp.write(str(counter) + '\n')
            #copy_checkpoint_to_gdrive()

        @tflex.register_command
        def generate_samples():
            """Generate args.sample_num samples, print them and save them to SAMPLE_DIR."""
            print('Generating samples...')
            context_tokens = data_sampler.sample(1)
            all_text = []
            index = 0
            while index < args.sample_num:
                out = sess.run(
                    tf_sample,
                    feed_dict={context: args.batch_size * [context_tokens]})
                for i in range(min(args.sample_num - index, args.batch_size)):
                    text = enc.decode(out[i])
                    text = '======== SAMPLE {} ========\n{}\n'.format(
                        index + 1, text)
                    print(text)
                    all_text.append(text)
                    index += 1
            maketree(os.path.join(SAMPLE_DIR, args.run_name))
            with open(
                    os.path.join(SAMPLE_DIR, args.run_name,
                                 'samples-{}').format(counter), 'w') as fp:
                fp.write('\n'.join(all_text))

        @tflex.register_command
        def validation():
            """Compute the mean loss over the fixed validation batches and log it."""
            if args.val_every <= 0:
                return
            print('Calculating validation loss...')
            losses = []
            for batch in tqdm.tqdm(val_batches):
                losses.append(sess.run(val_loss, feed_dict={val_context: batch}))
            v_val_loss = np.mean(losses)
            # Feeding val_loss directly makes the summary report the mean.
            v_summary = sess.run(val_loss_summary, feed_dict={val_loss: v_val_loss})
            summary_log.add_summary(v_summary, counter)
            summary_log.flush()
            print(
                '{stamp} [{counter} | {time:2.4f}] validation loss = {loss:2.4f}'
                .format(
                    stamp=timestamp(),
                    counter=counter,
                    time=time.time() - start_time,
                    loss=v_val_loss))

        start_time = time.time()

        def elapsed():
            """Seconds since training started."""
            return time.time() - start_time

        def say(msg):
            """Print msg prefixed with a timestamp, the counter and elapsed time."""
            print('{stamp} [{counter} | {time:2.4f}] {msg}'.format(counter=counter, time=elapsed(), msg=msg, stamp=timestamp()))

        def sample_batch():
            """Draw args.batch_size training samples of args.sample_ctx tokens."""
            return [data_sampler.sample(args.sample_ctx) for _ in range(args.batch_size)]

        def dump_variables(title, variables):
            """Print name, shape and parameter count of each variable plus a total."""
            print(title)
            print('name/shape/parameter_count')
            param_count = 0
            for x in variables:
                shape = x.shape.as_list()
                count = np.prod(shape)
                print(x.name, shape, count)
                param_count += count
            print('Total parameters:', param_count)

        prev_time = time.time()
        # (decayed sum of losses, decayed count); their ratio is the running average.
        avg_loss = (0.0, 0.0)

        if args.debug_before_training:
            import pdb
            pdb.set_trace()

        last_saved_time = elapsed()
        # ---- Training loop ------------------------------------------------
        while True:
            try:
                now = elapsed()
                if args.save_time > 0 and (((now - last_saved_time) / 60.0) >= args.save_time):
                    save()
                    last_saved_time = now
                elif args.save_every > 0 and (counter % args.save_every == 0):
                    save()
                # Guarded like save_every so that 0 disables sampling instead
                # of raising ZeroDivisionError.
                if args.sample_every > 0 and counter % args.sample_every == 0:
                    generate_samples()
                if args.val_every > 0 and (counter % args.val_every == 0 or counter == 1):
                    validation()

                v_rate = update_lr()

                if args.accumulate_gradients > 1:
                    #say('Running opt_reset...')
                    sess.run(opt_reset)
                    for _ in range(args.accumulate_gradients):
                        batch = sample_batch()
                        say('Running opt_compute...')
                        sess.run(opt_compute, feed_dict={context: batch})
                    say('Running opt_apply...')
                    (v_loss, v_summary) = sess.run((opt_apply, summaries))
                else:
                    batch = sample_batch()
                    say('Running opt_apply...')
                    (_, v_loss, v_summary) = sess.run(
                        (opt_apply, loss, summaries),
                        feed_dict={context: batch})

                if args.float16:
                    v_loss = tf.to_float(v_loss).eval()

                summary_log.add_summary(v_summary, counter)
                summary_log.flush()

                avg_loss = (avg_loss[0] * 0.99 + v_loss,
                            avg_loss[1] * 0.99 + 1.0)

                now = time.time()
                print('{stamp} [{counter} | {time:2.4f} | {delta:2.2f}s | {ops:2.6f}tokens/s] loss={loss:2.4f} avg={avg:2.4f} rate={rate:0.7f} step={step}'
                      .format(
                          stamp=timestamp(),
                          counter=counter,
                          time=now - start_time,
                          delta=now - prev_time,
                          ops=args.sample_ctx * args.batch_size / (now - prev_time),
                          rate=v_rate,
                          loss=v_loss,
                          avg=avg_loss[0] / avg_loss[1],
                          step=current_step,
                      ))

                counter += 1
                current_step += 1
                global_step.load(current_step, session=sess)

                tflex.check_commands_with_args(
                    session=sess,
                    stamp=timestamp(),
                    counter=counter,
                    time=now - start_time,
                    delta=now - prev_time,
                    ops=args.batch_size / (now - prev_time),
                    rate=v_rate,
                    loss=v_loss,
                    avg=avg_loss[0] / avg_loss[1],
                    avg_loss=avg_loss,
                    step=current_step,
                    train_vars=train_vars,
                    all_vars=all_vars,
                    args=args,
                    data_sampler=data_sampler,
                    ckpt=ckpt,
                    saver=saver,
                )
                if tflex.should_quit():
                    break

                prev_time = now
                # The debug flags are one-shot: each is cleared after printing.
                if args.debug_print_all_vars:
                    dump_variables('all variables:', tf.all_variables())
                    args.debug_print_all_vars = False
                if args.debug_print_trainable_vars:
                    dump_variables('trainable variables:', tf.trainable_variables())
                    args.debug_print_trainable_vars = False
            except KeyboardInterrupt:
                print('interrupted')
                if args.save_on_ctrlc:
                    save()
                if args.debug_on_ctrlc:
                    import pdb
                    pdb.set_trace()
                else:
                    break
def interact_model(model_name='117M', restore_from=None, seed=None, nsamples=1, step=1, length=64, prompt="\n", clear=None, maxlen=-1, temperature=1, top_k=0, top_p=0, penalize=0):
    """
    Interactively run the model, streaming generated text to stdout
    :model_name=117M : String, which model to use
    :restore_from=None : Checkpoint directory; defaults to models/<model_name>
    :seed=None : Integer seed for random number generators, fix seed to reproduce
     results
    :nsamples=1 : Number of samples to return total
    :step=1 : Number of tokens to generate at a time
    :length=64 : Window size; use 1024 for maximum size per sample
    :prompt="\\n" : Prompt to start with. If it names a readable file, the file's
     contents are used instead. An empty prompt is replaced by token 50256.
    :clear=None : If this string is encountered, clear the context window.
    :maxlen=-1 : if this many tokens are generated without encountering --clear,
     then print it and clear the context window.
    :temperature=1 : Float value controlling randomness in boltzmann
     distribution. Lower temperature results in less random completions. As the
     temperature approaches zero, the model will become deterministic and
     repetitive. Higher temperature results in more random completions.
    :top_k=0 : Integer value controlling diversity. 1 means only 1 word is
     considered for each step (token), resulting in deterministic completions,
     while 40 means 40 words are considered at each step. 0 (default) is a
     special setting meaning no restrictions. 40 generally is a good value.
    :top_p=0.0 : Float value controlling diversity. Implements nucleus sampling,
     overriding top_k if set to a value > 0. A good setting is 0.9.
    :penalize=0.0 : Float value controlling "used" penalty. Implements repetition
     reduction (similar to CTRL) if set to a value > 0. A decent setting might be
     0.85 with temperature 0.3 and top_k 40.
    """
    batch_size = 1
    assert nsamples % batch_size == 0

    enc = encoder.get_encoder(model_name)
    hparams = model.default_hparams()
    with open(os.path.join('models', model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if length > hparams.n_ctx:
        raise ValueError("Length can't be larger than n_ctx: %s" % hparams.n_ctx)
    if step > length:
        raise ValueError("Can't get samples longer than length: %s" % length)

    # Largest context that still leaves room for `step` new tokens.  Clamped at
    # zero: with step == length the raw value is -1 and the old
    # `while len(tokens) > -1` trimming loop never terminated.
    max_context = max(length - step - 1, 0)

    def trim(tokens):
        """Keep only the newest max_context tokens; return tokens unchanged if short enough."""
        if len(tokens) > max_context:
            return tokens[len(tokens) - max_context:]
        return tokens

    def reset_context():
        """Forget the generated text and restart from the original prompt tokens."""
        tflex.context_text = ""
        tflex.context_count = 0
        tflex.context_tokens = []
        tflex.first = True
        tflex.tokens = tflex.prompt_tokens[:]

    with tflex.Session(graph=tf.Graph()) as sess:
        context = tf.placeholder(tf.int32, [batch_size, None])
        np.random.seed(seed)
        tf.set_random_seed(seed)
        output = sample.sample_sequence(
            hparams=hparams, length=step,
            context=context,
            batch_size=batch_size,
            temperature=temperature, top_k=top_k, top_p=top_p, penalize=penalize)

        saver = tflex.Saver(reshape=True)
        if restore_from is None:
            restore_from = os.path.join('models', model_name)
        ckpt = tflex.latest_checkpoint(restore_from)
        saver.restore(sess, ckpt)

        # Exposed on the tflex module, as before, so other code can call it.
        tflex.reset_context = reset_context

        while True:
            tflex.check_commands()
            if tflex.should_quit():
                break
            # Treat `prompt` as a file name first; fall back to the literal text.
            try:
                with open(prompt) as f:
                    tflex.raw_text = f.read()
                if tflex.raw_text.endswith('\n'):
                    tflex.raw_text = tflex.raw_text[:-1]
                if tflex.raw_text.endswith('\r'):
                    tflex.raw_text = tflex.raw_text[:-1]
            except (OSError, ValueError):
                # OSError: not an openable file; ValueError: invalid path or
                # undecodable contents.  A bare except also hid Ctrl-C.
                tflex.raw_text = prompt
                # Escaped sequences typed on the command line become real characters.
                tflex.raw_text = tflex.raw_text.replace('\\n', '\n')
                tflex.raw_text = tflex.raw_text.replace('\\t', '\t')
            #print(repr(tflex.raw_text))
            tflex.context_tokens = enc.encode(
                tflex.raw_text) if len(tflex.raw_text) > 0 else [50256]
            tflex.context_tokens = trim(tflex.context_tokens)
            tflex.prompt_tokens = tflex.context_tokens[:]
            tflex.first = True
            tflex.backlog = []
            tflex.backlog_count = 0
            tflex.context_text = ""
            tflex.context_count = 0
            while True:
                for tokens in generate_result(
                        context_tokens=tflex.context_tokens, enc=enc,
                        output=output, context=context, nsamples=1,
                        batch_size=batch_size, sess=sess):
                    tflex.tokens = tokens
                    if tflex.first:
                        #clear_output(wait=True)
                        sys.stdout.write(enc.decode(tflex.context_tokens))
                        sys.stdout.flush()
                        tflex.first = False
                    tflex.backlog.extend(tflex.tokens)
                    tflex.backlog_count += 1
                    # Hold tokens back until the newest one decodes to ASCII
                    # (or the backlog exceeds 16 entries), then flush them together.
                    if is_ascii(enc.decode([tflex.backlog[-1]])) or tflex.backlog_count > 16:
                        text = enc.decode(tflex.backlog)
                        result = text
                        if clear is not None:
                            # Only show what precedes the clear marker.
                            result = text.split(clear)[0]
                        sys.stdout.write(result)
                        sys.stdout.flush()
                        tflex.context_text += text
                        tflex.context_count += len(tflex.backlog)
                        if maxlen > 0 and tflex.context_count > maxlen or clear is not None and clear in tflex.context_text:
                            tflex.reset_context()
                        tflex.backlog = []
                        tflex.backlog_count = 0
                    tflex.check_commands()
                    tflex.context_tokens.extend(tflex.tokens)
                    tflex.context_tokens = trim(tflex.context_tokens)
def main():
    """Run the MODEL interactively.

    Reads questions from stdin in an endless loop, generates text for each one
    in windows (so output longer than a single context can be produced) and
    prints the decoded result.  All settings come from module-level constants
    (MODELS_DIR, MODEL_NAME, LENGTH, NSAMPLES, BATCH_SIZE, SEED, TEMPERATURE,
    TOP_K, TOP_P, SPLIT_CONTEXT, TRUNCATE, CHECKPOINT).

    Raises:
        ValueError: if LENGTH exceeds the model's context size.
    """
    print("\nWelcome to COVID-19 chatbot!")
    print("The input prompt will appear shortly\n\n")

    models_dir = os.path.expanduser(os.path.expandvars(MODELS_DIR))
    assert NSAMPLES % BATCH_SIZE == 0

    enc = encoder.get_encoder(MODEL_NAME)
    hparams = model.default_hparams()
    with open(os.path.join(models_dir, MODEL_NAME, "hparams.json")) as file:
        hparams.override_from_dict(json.load(file))

    # Start from the configured value: previously `length` was only bound when
    # LENGTH was None, so any valid explicit LENGTH crashed at first use.
    length = LENGTH
    if length is None:
        length = hparams.n_ctx // 2
    elif length > hparams.n_ctx:
        raise ValueError(
            "Can't get samples longer than window size: {}".format(
                hparams.n_ctx))

    with tf.Session(graph=tf.Graph()) as sess:
        context = tf.placeholder(tf.int32, [BATCH_SIZE, None])
        np.random.seed(SEED)
        tf.set_random_seed(SEED)
        output = sample.sample_sequence(
            hparams=hparams,
            length=length,
            context=context,
            batch_size=BATCH_SIZE,
            temperature=TEMPERATURE,
            top_k=TOP_K,
        )

        saver = tflex.Saver()
        saver.restore(sess, CHECKPOINT)

        while True:
            question = input("COVID-19 CHATBOT> ")
            while not question:
                print("Prompt should not be empty!")
                question = input("COVID-19 CHATBOT> ")
            context_tokens = [enc.encode(question)] * BATCH_SIZE

            # custom for full length text
            total_tokens = len(context_tokens[0])
            generated_once = False
            gen_texts = []
            answers = ""
            split_length = int(1023 * SPLIT_CONTEXT)
            split_output_length = min(length, 1023 - split_length)
            # Built lazily for follow-up windows.  Starts as None so the first
            # follow-up always builds it; before, an unchanged split length
            # meant `split_output` was read before ever being assigned.
            split_output = None

            for _ in range(NSAMPLES // BATCH_SIZE):
                gen_text = [np.array([])] * BATCH_SIZE
                truncated = [False] * BATCH_SIZE
                while False in truncated:
                    num_tokens = 1023 - (len(context_tokens[0]))

                    if generated_once:
                        new_split_output_length = min(length - total_tokens,
                                                      1023 - split_length)
                        if split_output is None or new_split_output_length != split_output_length:
                            split_output = sample.sample_sequence(
                                hparams=hparams,
                                length=new_split_output_length,
                                start_token=enc.encoder['<|endoftext|>'] if not question else None,
                                context=context if question else None,
                                batch_size=BATCH_SIZE,
                                temperature=TEMPERATURE,
                                top_k=TOP_K,
                                top_p=TOP_P)[:, 1:]
                        out = sess.run(split_output,
                                       feed_dict={context: context_tokens})
                    else:
                        out = sess.run(output,
                                       feed_dict={context: context_tokens})

                    total_tokens += num_tokens

                    for i in range(BATCH_SIZE):
                        text = out[i]
                        trunc_text = ""
                        if question:
                            text = np.append(context_tokens[i][:1], text)
                        if TRUNCATE or all(gen_text):
                            # Carry the tail of this window over as the next context.
                            context_tokens[i] = out[i][(1023 - split_length - 1):]
                            if generated_once:
                                text = out[i][split_length:]

                            if TRUNCATE:
                                to_trunc = enc.decode(text)
                                truncate_esc = re.escape(TRUNCATE)
                                # NOTE(review): `include_prefix` is not defined in
                                # this function; presumably a module-level name - confirm.
                                if question and not include_prefix:
                                    prefix_esc = re.escape(question)
                                    pattern = '(?:{})(.*?)(?:{})'.format(
                                        prefix_esc, truncate_esc)
                                else:
                                    pattern = '(.*?)(?:{})'.format(
                                        truncate_esc)

                                trunc_text = re.search(pattern, to_trunc, re.S)
                                if trunc_text:
                                    text = enc.encode(trunc_text.group(1))
                                    # better to re-encode here then decode every generation cycle, I think

                        if not truncated[i]:
                            gen_text[i] = np.concatenate((gen_text[i], text),
                                                         axis=None)
                            if trunc_text or (length is not None and total_tokens >= length - 1):
                                truncated[i] = True
                                gen = enc.decode(gen_text[i]).lstrip('\n')
                                answers += gen
                    generated_once = True

                # The printed answer is rebuilt from the last window's raw output.
                answers = ""
                for idx in range(BATCH_SIZE):
                    answers += enc.decode(out[idx])
                # Process the string (cleanup)
                clean_answers = cleaner.clean_additional(" ".join(
                    cleaner.clean_text(answers)))
                final_answers = cleaner.chunk_into_sentences(clean_answers)
                try:
                    #print(similarity.use_filter(question, answers, 5))
                    print(answers)
                except Exception:
                    print(" ".join(answers))
                    print("WARNING: Model cannot generate an answer using USE")

            print()
            print("=" * 79)
            print()
# Module-level model setup: executed once at import time so that the request
# handler below can use the loaded session.
top_p = 1
enc = encoder.get_encoder(model_name)
hparams = model.default_hparams()
with open(os.path.join('models', model_name, 'hparams.json')) as f:
    hparams.override_from_dict(json.load(f))

# The session is entered manually (no `with`) and is never exited here.
sess = tf.Session(graph=tf.Graph()).__enter__()
# The batch dimension follows `batch_size` so the placeholder agrees with the
# sampler built below (it was hard-coded to 4).
context = tf.placeholder(tf.int32, [batch_size, None])
np.random.seed(None)
tf.set_random_seed(None)
output = sample.sample_sequence(
    hparams=hparams,
    length=length,
    context=context,
    batch_size=batch_size,
    temperature=temperature,
    top_k=top_k,
    top_p=top_p
)

saver = tflex.Saver()
if restore_from is None:
    restore_from = os.path.join('models', model_name)
ckpt = tflex.latest_checkpoint(restore_from)
saver.restore(sess, ckpt)


class Article(BaseModel):
    """Request body for the /text_rewrite endpoint."""

    # The text submitted by the client.
    text: str


@app.post("/text_rewrite")
def text_rewrite(
        article: Article
):
    """Handle POST /text_rewrite; the visible body only reads the submitted text."""
    text = article.text
def interact_model(model_name='117M', restore_from=None, seed=None, nsamples=1, batch_size=1, length=None, temperature=1, top_k=0, top_p=0.0, penalize=0, prompt=None):
    """
    Repeatedly take a prompt and print model completions for it
    :model_name=117M : String, name of the model directory under models/
    :restore_from=None : Checkpoint directory; models/<model_name> when None
    :seed=None : Integer seed for the numpy and TensorFlow random number
     generators; fix it to reproduce results
    :nsamples=1 : Total number of samples printed per prompt
    :batch_size=1 : Samples generated per session run (only affects
     speed/memory). Must divide nsamples.
    :length=None : Number of tokens to generate; None means half of the model's
     context size
    :temperature=1 : Float value controlling randomness in boltzmann
     distribution. Lower values give less random completions, higher values
     give more random ones.
    :top_k=0 : Integer value controlling diversity; number of candidate words
     considered at each step. 0 (default) means no restriction.
    :top_p=0.0 : Float value controlling diversity via nucleus sampling,
     overriding top_k if set to a value > 0.
    :penalize=0.0 : Float value controlling the "used" penalty for repetition
     reduction if set to a value > 0.
    :prompt=None : Fixed prompt text, or the path of a file containing it. When
     None the prompt is read from stdin on every round.
    """
    batch_size = 1 if batch_size is None else batch_size
    assert nsamples % batch_size == 0

    enc = encoder.get_encoder(model_name)
    hparams = model.default_hparams()
    hparams_path = os.path.join('models', model_name, 'hparams.json')
    with open(hparams_path) as hparams_file:
        hparams.override_from_dict(json.load(hparams_file))

    if length is None:
        length = hparams.n_ctx // 2
    elif length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)

    def next_prompt():
        """Return the prompt text for one round (file, fixed string or stdin)."""
        if prompt is None:
            typed = input("Model prompt >>> ")
            # An empty line typed by the user is replaced by a single newline.
            return typed if typed else "\n"
        if os.path.isfile(prompt):
            with open(prompt) as prompt_file:
                return prompt_file.read()
        return prompt

    with tflex.Session(graph=tf.Graph()) as sess:
        context = tf.placeholder(tf.int32, [batch_size, None])
        np.random.seed(seed)
        tf.set_random_seed(seed)
        output = sample.sample_sequence(
            hparams=hparams, length=length,
            context=context,
            batch_size=batch_size,
            temperature=temperature, top_k=top_k, top_p=top_p, penalize=penalize
        )

        saver = tflex.Saver()
        checkpoint_dir = os.path.join('models', model_name) if restore_from is None else restore_from
        saver.restore(sess, tflex.latest_checkpoint(checkpoint_dir))

        rule = "=" * 40
        rounds = nsamples // batch_size
        while True:
            raw_text = next_prompt()
            # Drop one trailing newline, but keep a prompt that is only "\n".
            if len(raw_text) > 1 and raw_text.endswith('\n'):
                raw_text = raw_text[:-1]
            print('Prompt:', repr(raw_text))
            context_tokens = enc.encode(raw_text)
            prompt_len = len(context_tokens)

            sample_no = 0
            for _ in range(rounds):
                feed = {context: [context_tokens for _ in range(batch_size)]}
                # Slice off the prompt so only newly generated tokens remain.
                completions = sess.run(output, feed_dict=feed)[:, prompt_len:]
                for row in range(batch_size):
                    sample_no += 1
                    decoded = enc.decode(completions[row])
                    print(rule + " SAMPLE " + str(sample_no) + " " + rule)
                    sys.stdout.write(raw_text)
                    print(decoded)
                    sys.stdout.flush()
            print("=" * 80)
def main():
    """Run the MODEL interactively.

    Reads questions from stdin in an endless loop, samples completions from
    the model, cleans them up and prints the sentences selected by
    ``similarity.use_filter`` (or all cleaned sentences if that fails).
    Settings come from module-level constants (MODELS_DIR, MODEL_NAME, LENGTH,
    NSAMPLES, BATCH_SIZE, SEED, TEMPERATURE, TOP_K, CHECKPOINT).

    Raises:
        ValueError: if LENGTH exceeds the model's context size.
    """
    print("\nWelcome to COVID-19 chatbot!")
    print("The input prompt will appear shortly\n\n")

    models_dir = os.path.expanduser(os.path.expandvars(MODELS_DIR))
    assert NSAMPLES % BATCH_SIZE == 0

    enc = encoder.get_encoder(MODEL_NAME)
    hparams = model.default_hparams()
    with open(os.path.join(models_dir, MODEL_NAME, "hparams.json")) as file:
        hparams.override_from_dict(json.load(file))

    # Start from the configured value: previously `length` was only bound when
    # LENGTH was None, so any valid explicit LENGTH crashed at first use.
    length = LENGTH
    if length is None:
        length = hparams.n_ctx // 2
    elif length > hparams.n_ctx:
        raise ValueError(
            "Can't get samples longer than window size: {}".format(
                hparams.n_ctx))

    with tf.Session(graph=tf.Graph()) as sess:
        context = tf.placeholder(tf.int32, [BATCH_SIZE, None])
        np.random.seed(SEED)
        tf.set_random_seed(SEED)
        output = sample.sample_sequence(
            hparams=hparams,
            length=length,
            context=context,
            batch_size=BATCH_SIZE,
            temperature=TEMPERATURE,
            top_k=TOP_K,
        )

        saver = tflex.Saver()
        saver.restore(sess, CHECKPOINT)

        while True:
            question = input("COVID-19 CHATBOT> ")
            while not question:
                print("Prompt should not be empty!")
                question = input("COVID-19 CHATBOT> ")
            context_tokens = enc.encode(question)

            # Each batch is post-processed as soon as it is produced, so every
            # batch is reported and `out` is never read without being set.
            for _ in range(NSAMPLES // BATCH_SIZE):
                # Slice off the prompt so only newly generated tokens remain.
                out = sess.run(
                    output,
                    feed_dict={
                        context: [context_tokens for _ in range(BATCH_SIZE)]
                    },
                )[:, len(context_tokens):]

                # Build the answers string
                answers = "".join(
                    enc.decode(out[idx]) for idx in range(BATCH_SIZE))

                # Process the string (cleanup)
                clean_answers = cleaner.clean_additional(" ".join(
                    cleaner.clean_text(answers)))
                final_answers = cleaner.chunk_into_sentences(clean_answers)

                try:
                    print(similarity.use_filter(question, final_answers, 5))
                except Exception:
                    # Deliberate best effort: fall back to the unfiltered sentences.
                    print(" ".join(final_answers))
                    print("WARNING: Model cannot generate an answer using USE")

            print()
            print("=" * 79)
            print()
def text_rewrite(
        input_filepath,
        output_filepath,
        model_name='1558M',
        restore_from=None,
        seed=None,
        batch_size=4,
        length=70,
        temperature=1,
        top_k=0,
        top_p=1,
        init_tpu=False
):
    """Rewrite a text file sentence by sentence with the model.

    The input file is split into sentences with spaCy's sentencizer.  For each
    non-empty sentence the model is prompted with
    ``ORIGINAL_SENT: <sentence> >>>>>`` and ``batch_size`` candidates are
    sampled; ``get_best_candidate`` chooses one and ``add_new_line``
    post-processes it against the original sentence text.  The chosen
    rewrites are joined with spaces and written to ``output_filepath``.
    Progress is printed to stdout.

    :input_filepath : Path of the UTF-8 text file to rewrite
    :output_filepath : Path the rewritten text is written to (UTF-8)
    :model_name=1558M : Name of the model directory under models/
    :restore_from=None : Checkpoint directory; models/<model_name> when None
    :seed=None : Seed for the numpy and TensorFlow random number generators
    :batch_size=4 : Number of candidate rewrites sampled per sentence
    :length=70 : Number of tokens generated per candidate
    :temperature=1, top_k=0, top_p=1 : Sampling settings passed to
     sample.sample_sequence
    :init_tpu=False : Use tflex.Session instead of tf.Session when true
    """
    sentence_splitter = English()
    sentence_splitter.add_pipe(sentence_splitter.create_pipe('sentencizer'))

    enc = encoder.get_encoder(model_name)
    hparams = model.default_hparams()
    hparams_path = os.path.join('models', model_name, 'hparams.json')
    with open(hparams_path) as hparams_file:
        hparams.override_from_dict(json.load(hparams_file))

    open_session = tflex.Session if init_tpu else tf.Session
    with open_session(graph=tf.Graph()) as sess:
        context = tf.placeholder(tf.int32, [batch_size, None])
        np.random.seed(seed)
        tf.set_random_seed(seed)
        output = sample.sample_sequence(
            hparams=hparams, length=length,
            context=context,
            batch_size=batch_size,
            temperature=temperature, top_k=top_k, top_p=top_p
        )

        saver = tflex.Saver()
        checkpoint_dir = os.path.join('models', model_name) if restore_from is None else restore_from
        saver.restore(sess, tflex.latest_checkpoint(checkpoint_dir))

        with open(input_filepath, encoding='utf-8') as source_file:
            document = source_file.read()

        rewritten = []
        for sent in sentence_splitter(document).sents:
            sent_text = sent.text.strip()
            if sent_text:
                model_input = "ORIGINAL_SENT: {} >>>>>".format(sent_text)
                context_tokens = enc.encode(model_input)
                print(model_input)
                print('*' * 80)

                feed = {context: [context_tokens for _ in range(batch_size)]}
                # Slice off the prompt so only newly generated tokens remain.
                generated = sess.run(output, feed_dict=feed)[:, len(context_tokens):]

                candidates = []
                for number in range(batch_size):
                    candidate = enc.decode(generated[number]) + "\n"
                    print("SAMPLE {}".format(number))
                    print(candidate)
                    print('-' * 80)
                    candidates.append(candidate)

                # The un-stripped sentence text is passed to add_new_line.
                best = add_new_line(
                    sent.text, get_best_candidate(sent_text, candidates))
                rewritten.append(best)
                print("REPHRASED: {}".format(best))
                print('=' * 80)

        with open(output_filepath, encoding='utf-8', mode='w') as target_file:
            target_file.write(" ".join(rewritten))
def interact_model(model_name='117M', restore_from=None, seed=None, nsamples=1, step=1, length=64, prompt="\n", clear=None, maxlen=-1, temperature=1, top_k=0, top_p=0, penalize=0):
    """
    Restore a checkpoint and save it again under saved/model-<step>
    Despite the name, this body generates no text: it builds the sampling
    graph, restores the latest checkpoint with a reshaping tflex.Saver and
    re-saves the session with a plain tf.train.Saver.
    :model_name=117M : String, name of the model directory under models/
    :restore_from=None : Checkpoint directory; models/<model_name> when None
    :seed=None : Integer seed for the numpy and TensorFlow random number
     generators
    :nsamples=1 : Only checked against the fixed batch size of 1
    :step=1 : Number of tokens the sampling graph generates at a time; must not
     exceed length
    :length=64 : Window size; must not exceed the model's n_ctx
    :prompt="\\n", clear=None, maxlen=-1 : Not used by this body
    :temperature=1, top_k=0, top_p=0, penalize=0 : Sampling settings passed to
     sample.sample_sequence
    """
    batch_size = 1
    assert nsamples % batch_size == 0

    enc = encoder.get_encoder(model_name)
    hparams = model.default_hparams()
    hparams_path = os.path.join('models', model_name, 'hparams.json')
    with open(hparams_path) as hparams_file:
        hparams.override_from_dict(json.load(hparams_file))

    if length > hparams.n_ctx:
        raise ValueError("Length can't be largeer than n_ctx: %s" % hparams.n_ctx)
    if step > length:
        raise ValueError("Can't get samples longer than length: %s" % length)

    with tflex.Session(graph=tf.Graph()) as sess:
        context = tf.placeholder(tf.int32, [batch_size, None])
        np.random.seed(seed)
        tf.set_random_seed(seed)
        # Built for the variables it creates; the op itself is never run here.
        output = sample.sample_sequence(
            hparams=hparams, length=step,
            context=context,
            batch_size=batch_size,
            temperature=temperature, top_k=top_k, top_p=top_p, penalize=penalize
        )

        loader = tflex.Saver(reshape=True)
        if restore_from is None:
            restore_from = os.path.join('models', model_name)
        ckpt = tflex.latest_checkpoint(restore_from)
        loader.restore(sess, ckpt)

        writer = tf.train.Saver()
        # The step number is the text after the last '-' in the checkpoint
        # path, up to the first '.'.
        step_suffix = ckpt.split('-')[-1]
        counter = int(step_suffix.split('.')[0])
        target_prefix = os.path.join('saved', 'model')
        writer.save(sess, target_prefix, global_step=counter)
def chatbot_response(question: str) -> str:
    """Respond to a question.

    Encodes `question`, samples NSAMPLES continuations from the model restored
    from CHECKPOINT (in batches of BATCH_SIZE), cleans the generated text and
    returns the concatenated result. All configuration comes from module-level
    constants (MODELS_DIR, MODEL_NAME, LENGTH, SEED, TEMPERATURE, TOP_K, ...).

    Raises ValueError if LENGTH is larger than the model's n_ctx.
    """
    models_dir = os.path.expanduser(os.path.expandvars(MODELS_DIR))
    assert NSAMPLES % BATCH_SIZE == 0

    enc = encoder.get_encoder(MODEL_NAME, dirback=True)
    hparams = model.default_hparams()
    with open(os.path.join(models_dir, MODEL_NAME, "hparams.json")) as file:
        hparams.override_from_dict(json.load(file))

    if LENGTH is None:
        length = hparams.n_ctx // 2
    elif LENGTH > hparams.n_ctx:
        raise ValueError(
            "Can't get samples longer than window size: {}".format(
                hparams.n_ctx))
    else:
        # A configured, valid LENGTH previously left `length` unassigned.
        length = LENGTH

    with tf.Session(graph=tf.Graph()) as sess:
        context = tf.placeholder(tf.int32, [BATCH_SIZE, None])
        np.random.seed(SEED)
        tf.set_random_seed(SEED)
        output = sample.sample_sequence(
            hparams=hparams,
            length=length,
            context=context,
            batch_size=BATCH_SIZE,
            temperature=TEMPERATURE,
            top_k=TOP_K,
        )

        saver = tflex.Saver()
        saver.restore(sess, CHECKPOINT)

        context_tokens = enc.encode(question)
        response_parts = []
        for _ in range(NSAMPLES // BATCH_SIZE):
            # Feed the same prompt to every batch row and keep only the
            # tokens that follow the prompt in the output.
            out = sess.run(
                output,
                feed_dict={context: [context_tokens] * BATCH_SIZE},
            )[:, len(context_tokens):]

            # Build the answers string
            answers = "".join(
                enc.decode(out[idx]) for idx in range(BATCH_SIZE))

            # Process the string (cleanup)
            clean_answers = cleaner.clean_additional(" ".join(
                cleaner.clean_text(answers)))
            final_answers = cleaner.chunk_into_sentences(clean_answers)

            try:
                response_parts.append(
                    similarity.use_filter(question, final_answers, 5))
            except Exception:
                # Best-effort: if filtering fails for any reason, fall back to
                # the unfiltered sentences instead of failing the request.
                response_parts.append(" ".join(final_answers))

    return "".join(response_parts)
def sample_model(model_name='117M', restore_from=None, seed=None, nsamples=0,
                 batch_size=1, length=None, temperature=1, top_k=0, top_p=0.0,
                 penalize=0):
    """
    Run the sample_model: generate unconditional samples and print them.

    :model_name=117M : String, which model to use
    :restore_from=None : Directory searched for the latest checkpoint. If None,
     models/<model_name> is used.
    :seed=None : Integer seed for random number generators, fix seed to
     reproduce results
    :nsamples=0 : Number of samples to return, if 0, continues to
     generate samples indefinitely. Exactly this many samples are printed, even
     when it is not a multiple of batch_size.
    :batch_size=1 : Number of batches (only affects speed/memory).
    :length=None : Number of tokens in generated text, if None (default), is
     determined by model hyperparameters
    :temperature=1 : Float value controlling randomness in boltzmann
     distribution. Lower temperature results in less random completions. As the
     temperature approaches zero, the model will become deterministic and
     repetitive. Higher temperature results in more random completions.
    :top_k=0 : Integer value controlling diversity. 1 means only 1 word is
     considered for each step (token), resulting in deterministic completions,
     while 40 means 40 words are considered at each step. 0 (default) is a
     special setting meaning no restrictions. 40 generally is a good value.
    :top_p=0.0 : Float value controlling diversity. Implements nucleus sampling,
     overriding top_k if set to a value > 0. A good setting is 0.9.
    :penalize=0.0 : Float value controlling "used" penalty. Implements repetition
     reduction (similar to CTRL) if set to a value > 0. A decent setting might be
     0.85 with temperature 0.3 and top_k 40.

    Raises ValueError if `length` is larger than the model's n_ctx.
    """
    enc = encoder.get_encoder(model_name)
    hparams = model.default_hparams()
    with open(os.path.join('models', model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if length is None:
        length = hparams.n_ctx
    elif length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)

    with tflex.Session(graph=tf.Graph()) as sess:
        np.random.seed(seed)
        tf.set_random_seed(seed)

        # Generation is seeded with <|endoftext|>; the [:, 1:] slice drops that
        # start token from every returned row.
        output = sample.sample_sequence(
            hparams=hparams, length=length,
            start_token=enc.encoder['<|endoftext|>'],
            batch_size=batch_size,
            temperature=temperature, top_k=top_k, top_p=top_p,
            penalize=penalize
        )[:, 1:]

        saver = tflex.Saver()
        if restore_from is None:
            restore_from = os.path.join('models', model_name)
        ckpt = tflex.latest_checkpoint(restore_from)
        saver.restore(sess, ckpt)

        generated = 0
        while nsamples == 0 or generated < nsamples:
            out = sess.run(output)
            for i in range(batch_size):
                # Stop mid-batch once the requested count is reached, so a
                # batch_size that does not divide nsamples cannot overshoot.
                if nsamples != 0 and generated >= nsamples:
                    break
                generated += 1
                text = enc.decode(out[i])
                print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
                print(text)
def interact_model(model_name='117M', asker=None, responder=None,
                   restore_from=None, seed=None, length=None, temperature=1,
                   top_k=0, top_p=0.0, penalize=0, prompt=None):
    """
    Interactively chat with the model

    Runs an endless read/generate/print loop on stdin/stdout. Every turn is
    appended to a running transcript of the form
    "(<timestamp>) <name>: <message>", which is fed back to the model as
    context for the next reply.

    :model_name=117M : String, which model to use
    :asker=None : Name the human user chats as. Required.
    :responder=None : Name the model chats as. Required.
    :restore_from=None : Directory searched for the latest checkpoint. If None,
     models/<model_name> is used.
    :seed=None : Integer seed for random number generators, fix seed to
     reproduce results
    :length=None : Number of tokens in generated text, if None (default), is
     determined by model hyperparameters
    :temperature=1 : Float value controlling randomness in boltzmann
     distribution. Lower temperature results in less random completions. As the
     temperature approaches zero, the model will become deterministic and
     repetitive. Higher temperature results in more random completions.
    :top_k=0 : Integer value controlling diversity. 1 means only 1 word is
     considered for each step (token), resulting in deterministic completions,
     while 40 means 40 words are considered at each step. 0 (default) is a
     special setting meaning no restrictions. 40 generally is a good value.
    :top_p=0.0 : Float value controlling diversity. Implements nucleus sampling,
     overriding top_k if set to a value > 0. A good setting is 0.9.
    :penalize=0.0 : Float value controlling "used" penalty. Implements repetition
     reduction (similar to CTRL) if set to a value > 0. A decent setting might be
     0.85 with temperature 0.3 and top_k 40.
    :prompt=None : Currently unused by this function.

    Raises ValueError if `asker` or `responder` is missing, or if `length` is
    larger than the model's n_ctx.
    """
    if asker is None:
        raise ValueError(
            "Add a name present in the training dataset that you will be chatting as"
        )
    if responder is None:
        raise ValueError(
            "Add a name present in the training dataset that gpt will be chatting as"
        )

    enc = encoder.get_encoder(model_name)
    hparams = model.default_hparams()
    with open(os.path.join('models', model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if length is None:
        length = hparams.n_ctx // 2
    elif length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)

    with tflex.Session(graph=tf.Graph()) as sess:
        context = tf.placeholder(tf.int32, [1, None])
        np.random.seed(seed)
        tf.set_random_seed(seed)
        output = sample.sample_sequence(
            hparams=hparams, length=length,
            context=context,
            batch_size=1,
            temperature=temperature, top_k=top_k, top_p=top_p,
            penalize=penalize)

        saver = tflex.Saver()
        if restore_from is None:
            restore_from = os.path.join('models', model_name)
        ckpt = tflex.latest_checkpoint(restore_from)
        saver.restore(sess, ckpt)

        # The transcript grows every turn, so only the most recent tokens are
        # fed to the model: prompt plus `length` new tokens must fit in the
        # n_ctx window. At least one token is always kept.
        max_context = max(1, hparams.n_ctx - length)

        history = ''
        # Synthetic starting timestamp, advanced via increase_time() before
        # each line. NOTE(review): presumably mirrors the timestamp format of
        # the training chat logs -- confirm.
        timestamp = 1924862493344
        while True:
            timestamp = increase_time(timestamp)
            history += f'({timestamp}) {asker}: ' + input(f"{asker}: ")

            timestamp = increase_time(timestamp)
            reply_prefix = f'\n ({timestamp}) {responder}: '
            history += reply_prefix

            context_tokens = enc.encode(history)[-max_context:]
            out = sess.run(
                output,
                feed_dict={context: [context_tokens]}
            )[:, len(context_tokens):]

            # Keep only what precedes the asker's next line, then drop the
            # trailing "(<timestamp>" fragment that belonged to that line.
            text = enc.decode(out[0]).split(f') {asker}', 1)[0]
            reply = text.rsplit('(', 1)[0]
            print(reply_prefix + reply)

            # Store the same trimmed reply that was shown, so the next turn's
            # "(<timestamp>) <asker>: " prefix is not glued onto a dangling
            # timestamp fragment.
            history += reply
            sys.stdout.flush()