def encode_dataset(file_path, model_dir='models', out_path='text_encoded.npz',
                   model_name="124M", combine=50000):
    """Pre-encodes a text document into tokenized chunks and compresses them,
    saving dataset-loading time when finetuning later.

    Adapted from https://github.com/nshepperd/gpt-2/blob/finetuning/encode.py
    """
    model_path = os.path.join(model_dir, model_name)
    enc = encoder.get_encoder(model_path)
    print('Reading files')
    chunks = load_dataset(enc, file_path, combine)
    print('Writing', out_path)
    np.savez_compressed(out_path, *chunks)
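
# Usage sketch: pre-encode once, then finetune from the compressed tokens.
# Assumptions (not defined in this section): the module is imported as `gpt2`
# and exposes start_tf_sess() as in gpt-2-simple, the "124M" model has already
# been fetched via download_gpt2(), and load_dataset() recognizes the .npz
# extension so the raw text is not re-tokenized on every run.
#
#   gpt2.encode_dataset('input.txt', out_path='text_encoded.npz')
#   sess = gpt2.start_tf_sess()
#   gpt2.finetune(sess, 'text_encoded.npz', model_name='124M')

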
def generate(sess,
             run_name='run1',
             checkpoint_dir='checkpoint',
             model_name=None,
             model_dir='models',
             sample_dir='samples',
             return_as_list=False,
             truncate=None,
             destination_path=None,
             sample_delim='=' * 20 + '\n',
             prefix=None,
             seed=None,
             nsamples=1,
             batch_size=1,
             length=1023,
             temperature=0.7,
             top_k=0,
             top_p=0.0,
             include_prefix=True):
    """Generates text from a model loaded into memory.

    Adapted from https://github.com/openai/gpt-2/blob/master/src/interactive_conditional_samples.py
    """
    if batch_size is None:
        batch_size = 1
    assert nsamples % batch_size == 0

    if nsamples == 1:
        sample_delim = ''

    if prefix == '':
        prefix = None

    if model_name:
        checkpoint_path = os.path.join(model_dir, model_name)
    else:
        checkpoint_path = os.path.join(checkpoint_dir, run_name)

    enc = encoder.get_encoder(checkpoint_path)
    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if prefix:
        context = tf.compat.v1.placeholder(tf.int32, [batch_size, None])
        context_tokens = enc.encode(prefix)

    np.random.seed(seed)
    tf.compat.v1.set_random_seed(seed)

    output = sample.sample_sequence(
        hparams=hparams,
        length=min(length, 1023 - (len(context_tokens) if prefix else 0)),
        start_token=enc.encoder['<|endoftext|>'] if not prefix else None,
        context=context if prefix else None,
        batch_size=batch_size,
        temperature=temperature, top_k=top_k, top_p=top_p
    )[:, 1:]

    if destination_path:
        f = open(destination_path, 'w')

    generated = 0
    gen_texts = []
    while generated < nsamples:
        if not prefix:
            out = sess.run(output)
        else:
            out = sess.run(output, feed_dict={
                context: batch_size * [context_tokens]
            })
        for i in range(batch_size):
            generated += 1
            gen_text = enc.decode(out[i])
            if prefix:
                # sample_sequence drops the first token ([:, 1:] above),
                # so re-attach the first token of the prefix here.
                gen_text = enc.decode(context_tokens[:1]) + gen_text
            if truncate:
                truncate_esc = re.escape(truncate)
                if prefix and not include_prefix:
                    prefix_esc = re.escape(prefix)
                    pattern = '(?:{})(.*?)(?:{})'.format(prefix_esc,
                                                         truncate_esc)
                else:
                    pattern = '(.*?)(?:{})'.format(truncate_esc)

                trunc_text = re.search(pattern, gen_text, re.S)
                if trunc_text:
                    gen_text = trunc_text.group(1)
            gen_text = gen_text.lstrip('\n')
            if destination_path:
                f.write("{}\n{}".format(gen_text, sample_delim))
            if not return_as_list and not destination_path:
                print("{}\n{}".format(gen_text, sample_delim), end='')
            gen_texts.append(gen_text)

    if destination_path:
        f.close()

    if return_as_list:
        return gen_texts
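
# Usage sketch for generate(), assuming start_tf_sess() and load_gpt2() exist
# as in gpt-2-simple and a finetuned checkpoint sits under checkpoint/run1.
# The '<|startoftext|>' prefix below is a hypothetical document marker, not
# anything defined in this file; '<|endoftext|>' is GPT-2's real end token.
#
#   sess = gpt2.start_tf_sess()
#   gpt2.load_gpt2(sess, run_name='run1')
#   texts = gpt2.generate(sess,
#                         prefix='<|startoftext|>',
#                         truncate='<|endoftext|>',
#                         include_prefix=False,
#                         nsamples=5,
#                         batch_size=5,      # must divide nsamples evenly
#                         return_as_list=True)

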
def finetune(sess,
             dataset,
             steps=-1,
             model_name='124M',
             model_dir='models',
             combine=50000,
             batch_size=1,
             learning_rate=0.0001,
             accumulate_gradients=5,
             restore_from='latest',
             run_name='run1',
             checkpoint_dir='checkpoint',
             sample_every=100,
             sample_length=1023,
             sample_num=1,
             multi_gpu=False,
             save_every=1000,
             print_every=1,
             max_checkpoints=1,
             use_memory_saving_gradients=False,
             only_train_transformer_layers=False,
             optimizer='adam',
             overwrite=False,
             val_dataset=None,
             val_batch_size=2,
             val_batch_count=40,
             val_every=0):
    """Finetunes the model on the given dataset.

    Adapted from https://github.com/nshepperd/gpt-2/blob/finetuning/train.py.
    See that file for parameter definitions.
    """

    assert model_name not in ['774M', '1558M'] or multi_gpu, \
        "Currently, a modern single GPU cannot finetune the 774M GPT-2 model or larger."

    SAMPLE_DIR = 'samples'

    checkpoint_path = os.path.join(checkpoint_dir, run_name)

    def maketree(path):
        # Create the directory if it does not already exist.
        os.makedirs(path, exist_ok=True)

    maketree(checkpoint_path)
    files = os.listdir(checkpoint_path)
    for file in ['hparams.json', 'encoder.json', 'vocab.bpe']:
        try:
            shutil.copyfile(os.path.join(model_dir, model_name, file),
                            os.path.join(checkpoint_path, file))
        except FileNotFoundError as fnf_error:
            print("You need to download the GPT-2 model first via download_gpt2()")
            raise fnf_error

    enc = encoder.get_encoder(checkpoint_path)
    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if sample_length > hparams.n_ctx:
        raise ValueError(
            "Can't get samples longer than window size: %s" % hparams.n_ctx)

    if model_name not in ['117M', '124M']:
        # Larger models require memory-saving gradients, which are
        # incompatible with gradient accumulation.
        use_memory_saving_gradients = True
        only_train_transformer_layers = True
        accumulate_gradients = 1

    context = tf.compat.v1.placeholder(tf.int32, [batch_size, None])
    gpus = []

    if multi_gpu:
        gpus = get_available_gpus()

    output = model.model(hparams=hparams, X=context, gpus=gpus)
    loss = tf.reduce_mean(
        input_tensor=tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=context[:, 1:], logits=output['logits'][:, :-1]))

    # The validation graph reuses the model variables via reuse=True.
    if val_every > 0:
        val_context = tf.compat.v1.placeholder(tf.int32, [val_batch_size, None])
        val_output = model.model(hparams=hparams,
                                 X=val_context,
                                 gpus=gpus,
                                 reuse=True)
        val_loss = tf.reduce_mean(
            input_tensor=tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=val_context[:, 1:],
                logits=val_output['logits'][:, :-1]))
        val_loss_summary = tf.compat.v1.summary.scalar('val_loss', val_loss)

    tf_sample = sample.sample_sequence(
        hparams=hparams,
        length=sample_length,
        context=context,
        batch_size=batch_size,
        temperature=1.0,
        top_k=40)

    all_vars = [v for v in tf.compat.v1.trainable_variables()
                if 'model' in v.name]
    train_vars = [v for v in all_vars if '/h' in v.name] \
        if only_train_transformer_layers else all_vars

    if optimizer == 'adam':
        opt = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
    elif optimizer == 'sgd':
        opt = tf.compat.v1.train.GradientDescentOptimizer(
            learning_rate=learning_rate)

    if accumulate_gradients > 1:
        if use_memory_saving_gradients:
            exit("Memory saving gradients are not implemented for gradient accumulation yet.")
        opt = AccumulatingOptimizer(
            opt=opt,
            var_list=train_vars)
        opt_reset = opt.reset()
        opt_compute = opt.compute_gradients(loss)
        opt_apply = opt.apply_gradients()
        # AccumulatingOptimizer.apply_gradients() returns the averaged loss,
        # so it doubles as the summary value here.
        summary_loss = tf.compat.v1.summary.scalar('loss', opt_apply)
    else:
        if use_memory_saving_gradients:
            opt_grads = memory_saving_gradients.gradients(loss, train_vars)
        else:
            opt_grads = tf.gradients(ys=loss, xs=train_vars)
        opt_grads = list(zip(opt_grads, train_vars))
        opt_apply = opt.apply_gradients(opt_grads)
        summary_loss = tf.compat.v1.summary.scalar('loss', loss)

    summary_log = tf.compat.v1.summary.FileWriter(checkpoint_path)

    saver = tf.compat.v1.train.Saver(
        var_list=all_vars,
        max_to_keep=max_checkpoints)
    sess.run(tf.compat.v1.global_variables_initializer())

    if restore_from == 'latest':
        ckpt = tf.train.latest_checkpoint(checkpoint_path)
        if ckpt is None:
            # Get fresh GPT weights if new run.
            ckpt = tf.train.latest_checkpoint(
                os.path.join(model_dir, model_name))
    elif restore_from == 'fresh':
        ckpt = tf.train.latest_checkpoint(
            os.path.join(model_dir, model_name))
    else:
        ckpt = tf.train.latest_checkpoint(restore_from)
    print('Loading checkpoint', ckpt)
    saver.restore(sess, ckpt)

    print('Loading dataset...')
    chunks = load_dataset(enc, dataset, combine)
    data_sampler = Sampler(chunks)
    if val_every > 0:
        if val_dataset:
            val_chunks = load_dataset(enc, val_dataset, combine)
        else:
            val_chunks = chunks
    print('dataset has', data_sampler.total_size, 'tokens')
    print('Training...')

    if val_every > 0:
        # Sample from validation set once with fixed seed to make
        # it deterministic during training as well as across runs.
        # (The fixed seed matches the upstream train.py; without it the
        # comment above would not hold.)
        val_data_sampler = Sampler(val_chunks, seed=1)
        val_batches = [[val_data_sampler.sample(1024)
                        for _ in range(val_batch_size)]
                       for _ in range(val_batch_count)]

    counter = 1
    counter_path = os.path.join(checkpoint_path, 'counter')
    if os.path.exists(counter_path) and restore_from == 'latest':
        # Load the step number if we're resuming a run.
        # Add 1 so we don't immediately try to save again.
        with open(counter_path, 'r') as fp:
            counter = int(fp.read()) + 1
    counter_base = counter

    def save():
        maketree(checkpoint_path)
        print(
            'Saving',
            os.path.join(checkpoint_path,
                         'model-{}').format(counter - 1))
        saver.save(
            sess,
            os.path.join(checkpoint_path, 'model'),
            global_step=counter - 1)
        with open(counter_path, 'w') as fp:
            fp.write(str(counter - 1) + '\n')

    def generate_samples():
        context_tokens = data_sampler.sample(1)
        all_text = []
        index = 0
        while index < sample_num:
            out = sess.run(
                tf_sample,
                feed_dict={context: batch_size * [context_tokens]})
            for i in range(min(sample_num - index, batch_size)):
                text = enc.decode(out[i])
                text = '======== SAMPLE {} ========\n{}\n'.format(
                    index + 1, text)
                all_text.append(text)
                index += 1
        print(text)
        maketree(os.path.join(SAMPLE_DIR, run_name))
        with open(
                os.path.join(SAMPLE_DIR, run_name,
                             'samples-{}').format(counter), 'w') as fp:
            fp.write('\n'.join(all_text))

    def validation():
        print('Calculating validation loss...')
        losses = []
        for batch in tqdm(val_batches):
            losses.append(sess.run(val_loss,
                                   feed_dict={val_context: batch}))
        v_val_loss = np.mean(losses)
        v_summary = sess.run(val_loss_summary,
                             feed_dict={val_loss: v_val_loss})
        summary_log.add_summary(v_summary, counter)
        summary_log.flush()
        print(
            '[{counter} | {time:2.2f}] validation loss = {loss:2.2f}'
            .format(
                counter=counter,
                time=time.time() - start_time,
                loss=v_val_loss))

    def sample_batch():
        return [data_sampler.sample(1024) for _ in range(batch_size)]

    if overwrite and restore_from == 'latest':
        for file in files:
            if file.startswith('model') or file.startswith('events'):
                os.remove(os.path.join(checkpoint_path, file))
        save()

    avg_loss = (0.0, 0.0)
    start_time = time.time()

    if steps:
        steps = int(steps)

    try:
        while True:
            if steps > 0 and counter == (counter_base + steps):
                save()
                return
            if (counter - 1) % save_every == 0 and counter > 1:
                save()
            if (counter - 1) % sample_every == 0 and counter > 1:
                generate_samples()
            if val_every > 0 and (counter % val_every == 0 or counter == 1):
                validation()

            if accumulate_gradients > 1:
                sess.run(opt_reset)
                for _ in range(accumulate_gradients):
                    sess.run(
                        opt_compute, feed_dict={context: sample_batch()})
                (v_loss, v_summary) = sess.run((opt_apply, summary_loss))
            else:
                (_, v_loss, v_summary) = sess.run(
                    (opt_apply, loss, summary_loss),
                    feed_dict={context: sample_batch()})

            summary_log.add_summary(v_summary, counter)

            if counter % print_every == 0:
                # Exponential moving average of the loss for display.
                avg_loss = (avg_loss[0] * 0.99 + v_loss,
                            avg_loss[1] * 0.99 + 1.0)

                print(
                    '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                    .format(
                        counter=counter,
                        time=time.time() - start_time,
                        loss=v_loss,
                        avg=avg_loss[0] / avg_loss[1]))

            counter += 1
    except KeyboardInterrupt:
        print('interrupted')
        save()
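
# Usage sketch for finetune(). download_gpt2() is the downloader referenced by
# the FileNotFoundError message above; start_tf_sess() is assumed to exist as
# in gpt-2-simple.
#
#   gpt2.download_gpt2(model_name='124M')       # fetch the base weights once
#   sess = gpt2.start_tf_sess()
#   gpt2.finetune(sess, 'input.txt',
#                 model_name='124M',
#                 steps=1000,                   # steps=-1 trains until Ctrl-C
#                 sample_every=200,
#                 save_every=500,
#                 run_name='run1')
#
# With restore_from='latest' (the default), a later call with the same
# run_name reads the saved `counter` file and resumes from that step.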