# NOTE: this excerpt assumes the usual gpt-2-simple / nshepperd-gpt-2 module
# context: `os`, `json`, `shutil`, `time`, `math`, `numpy as np`,
# `tensorflow as tf`, plus the package-local `model`, `sample`, `encoder`,
# `memory_saving_gradients`, `AccumulatingOptimizer`, `load_dataset`,
# `Sampler`, and `get_available_gpus` imports (not shown here).


def load_gpt2(sess,
              run_name="run1",
              checkpoint_dir="checkpoint",
              model_name=None,
              model_dir='models',
              multi_gpu=False,
              scope=None):
    """Loads the model checkpoint or existing model into a TensorFlow session
    for repeated predictions.
    """

    def _build_and_restore():
        if model_name:
            checkpoint_path = os.path.join(model_dir, model_name)
        else:
            checkpoint_path = os.path.join(checkpoint_dir, run_name)

        hparams = model.default_hparams()
        with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
            hparams.override_from_dict(json.load(f))

        context = tf.compat.v1.placeholder(tf.int32, [1, None])

        gpus = []
        if multi_gpu:
            gpus = get_available_gpus()

        output = model.model(hparams=hparams, X=context, gpus=gpus)

        ckpt = tf.train.latest_checkpoint(checkpoint_path)
        saver = tf.compat.v1.train.Saver(allow_empty=True)
        sess.run(tf.compat.v1.global_variables_initializer())

        if model_name:
            print('Loading pretrained model', ckpt)
        else:
            print('Loading checkpoint', ckpt)
        saver.restore(sess, ckpt)

    # Build and restore under the given variable scope, if one is provided.
    # (The two branches previously duplicated the entire body verbatim.)
    if scope is not None:
        with tf.compat.v1.variable_scope(scope):
            _build_and_restore()
    else:
        _build_and_restore()
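def _example_load_session(run_name='run1'):
    """Illustrative usage sketch, not part of the original API: build a fresh
    graph and session, then restore a finetuned checkpoint via load_gpt2().
    Assumes a checkpoint/run1 directory produced by a previous finetune()."""
    graph = tf.Graph()
    sess = tf.compat.v1.Session(graph=graph)
    with graph.as_default():
        load_gpt2(sess, run_name=run_name)
    return sess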
def finetune(sess,
             dataset,
             steps=-1,
             model_name='124M',
             model_dir='models',
             combine=50000,
             batch_size=1,
             learning_rate=0.0001,
             accumulate_gradients=5,
             restore_from='latest',
             run_name='run1',
             checkpoint_dir='checkpoint',
             sample_every=100,
             sample_length=1023,
             sample_num=1,
             multi_gpu=False,
             save_every=1000,
             print_every=1,
             max_checkpoints=1,
             use_memory_saving_gradients=False,
             only_train_transformer_layers=False,
             optimizer='adam',
             overwrite=False):
    """Finetunes the model on the given dataset.

    Adapted from https://github.com/nshepperd/gpt-2/blob/finetuning/train.py.
    See that file for parameter definitions.
    """

    # assert model_name not in ['774M', '1558M'] or multi_gpu, "Currently, a modern single GPU cannot finetune the 774M GPT-2 model or larger."

    SAMPLE_DIR = 'samples'

    checkpoint_path = os.path.join(checkpoint_dir, run_name)

    def maketree(path):
        os.makedirs(path, exist_ok=True)

    maketree(checkpoint_path)
    files = os.listdir(checkpoint_path)
    for file in ['hparams.json', 'encoder.json', 'vocab.bpe']:
        try:
            shutil.copyfile(os.path.join(model_dir, model_name, file),
                            os.path.join(checkpoint_path, file))
        except FileNotFoundError as fnf_error:
            print("You need to download the GPT-2 model first via download_gpt2()")
            raise fnf_error

    enc = encoder.get_encoder(checkpoint_path)
    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if sample_length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" %
                         hparams.n_ctx)

    if model_name not in ['117M', '124M']:
        use_memory_saving_gradients = True
        only_train_transformer_layers = True
        accumulate_gradients = 1

    context = tf.compat.v1.placeholder(tf.int32, [batch_size, None])
    gpus = []

    if multi_gpu:
        gpus = get_available_gpus()

    output = model.model(hparams=hparams, X=context, gpus=gpus)
    loss = tf.reduce_mean(
        input_tensor=tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=context[:, 1:], logits=output['logits'][:, :-1]))

    tf_sample = sample.sample_sequence(hparams=hparams,
                                       length=sample_length,
                                       context=context,
                                       batch_size=batch_size,
                                       temperature=1.0,
                                       top_k=40)

    all_vars = [
        v for v in tf.compat.v1.trainable_variables() if 'model' in v.name
    ]
    train_vars = [v for v in all_vars if '/h' in v.name
                  ] if only_train_transformer_layers else all_vars

    if optimizer == 'adam':
        opt = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
    elif optimizer == 'sgd':
        opt = tf.compat.v1.train.GradientDescentOptimizer(
            learning_rate=learning_rate)
    else:
        # Guard added: without it, an unknown optimizer name surfaced later
        # as a confusing NameError.
        raise ValueError("Unknown optimizer: %s" % optimizer)

    if accumulate_gradients > 1:
        if use_memory_saving_gradients:
            exit("Memory saving gradients are not implemented for gradient accumulation yet.")
        opt = AccumulatingOptimizer(opt=opt, var_list=train_vars)
        opt_reset = opt.reset()
        opt_compute = opt.compute_gradients(loss)
        opt_apply = opt.apply_gradients()
        # AccumulatingOptimizer.apply_gradients() returns the mean
        # accumulated loss, so it doubles as the summary value here.
        summary_loss = tf.compat.v1.summary.scalar('loss', opt_apply)
    else:
        if use_memory_saving_gradients:
            opt_grads = memory_saving_gradients.gradients(loss, train_vars)
        else:
            opt_grads = tf.gradients(ys=loss, xs=train_vars)
        opt_grads = list(zip(opt_grads, train_vars))
        opt_apply = opt.apply_gradients(opt_grads)
        summary_loss = tf.compat.v1.summary.scalar('loss', loss)

    summary_log = tf.compat.v1.summary.FileWriter(checkpoint_path)

    saver = tf.compat.v1.train.Saver(var_list=all_vars,
                                     max_to_keep=max_checkpoints)
    sess.run(tf.compat.v1.global_variables_initializer())

    if restore_from == 'latest':
        ckpt = tf.train.latest_checkpoint(checkpoint_path)
        if ckpt is None:
            # Get fresh GPT weights if new run.
            ckpt = tf.train.latest_checkpoint(
                os.path.join(model_dir, model_name))
    elif restore_from == 'fresh':
        ckpt = tf.train.latest_checkpoint(os.path.join(model_dir, model_name))
    else:
        ckpt = tf.train.latest_checkpoint(restore_from)
    print('Loading checkpoint', ckpt)
    saver.restore(sess, ckpt)

    print('Loading dataset...')
    chunks = load_dataset(enc, dataset, combine)
    data_sampler = Sampler(chunks)
    print('dataset has', data_sampler.total_size, 'tokens')
    print('Training...')

    counter = 1
    counter_path = os.path.join(checkpoint_path, 'counter')
    if os.path.exists(counter_path) and restore_from == 'latest':
        # Load the step number if we're resuming a run
        # Add 1 so we don't immediately try to save again
        with open(counter_path, 'r') as fp:
            counter = int(fp.read()) + 1
    counter_base = counter

    def save():
        maketree(checkpoint_path)
        print('Saving',
              os.path.join(checkpoint_path, 'model-{}').format(counter - 1))
        saver.save(sess,
                   os.path.join(checkpoint_path, 'model'),
                   global_step=counter - 1)
        with open(counter_path, 'w') as fp:
            fp.write(str(counter - 1) + '\n')

    def generate_samples():
        context_tokens = data_sampler.sample(1)
        all_text = []
        index = 0
        while index < sample_num:
            out = sess.run(
                tf_sample,
                feed_dict={context: batch_size * [context_tokens]})
            for i in range(min(sample_num - index, batch_size)):
                text = enc.decode(out[i])
                text = '======== SAMPLE {} ========\n{}\n'.format(
                    index + 1, text)
                all_text.append(text)
                index += 1
        print(text)
        maketree(os.path.join(SAMPLE_DIR, run_name))
        with open(
                os.path.join(SAMPLE_DIR, run_name,
                             'samples-{}').format(counter), 'w') as fp:
            fp.write('\n'.join(all_text))

    def sample_batch():
        return [data_sampler.sample(1024) for _ in range(batch_size)]

    if overwrite and restore_from == 'latest':
        for file in files:
            if file.startswith('model') or file.startswith('events'):
                os.remove(os.path.join(checkpoint_path, file))
        save()

    avg_loss = (0.0, 0.0)
    start_time = time.time()

    if steps:
        steps = int(steps)

    try:
        while True:
            if steps > 0 and counter == (counter_base + steps):
                save()
                return
            if (counter - 1) % save_every == 0 and counter > 1:
                save()
            if (counter - 1) % sample_every == 0 and counter > 1:
                generate_samples()

            if accumulate_gradients > 1:
                sess.run(opt_reset)
                for _ in range(accumulate_gradients):
                    sess.run(opt_compute,
                             feed_dict={context: sample_batch()})
                (v_loss, v_summary) = sess.run((opt_apply, summary_loss))
            else:
                (_, v_loss, v_summary) = sess.run(
                    (opt_apply, loss, summary_loss),
                    feed_dict={context: sample_batch()})

            summary_log.add_summary(v_summary, counter)

            if counter % print_every == 0:
                avg_loss = (avg_loss[0] * 0.99 + v_loss,
                            avg_loss[1] * 0.99 + 1.0)
                print(
                    '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                    .format(counter=counter,
                            time=time.time() - start_time,
                            loss=v_loss,
                            avg=avg_loss[0] / avg_loss[1]))

            counter += 1
    except KeyboardInterrupt:
        print('interrupted')
        save()
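def _example_finetune(dataset_path='dataset.txt'):
    """Illustrative usage sketch, not part of the original API: a typical
    single-GPU finetuning call against finetune() above. `dataset_path` is a
    hypothetical plain-text training file."""
    tf.compat.v1.reset_default_graph()
    sess = tf.compat.v1.Session()
    finetune(sess,
             dataset=dataset_path,
             model_name='124M',
             steps=1000,
             restore_from='fresh',
             run_name='run1',
             sample_every=200,
             save_every=500)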
def one_lr_cycle(sess,
                 dataset,
                 steps=10000,
                 model_name='117M',
                 combine=50000,
                 batch_size=1,
                 initial_lr=1e-10,
                 final_lr=1,
                 accumulate_gradients=5,
                 restore_from='fresh',
                 run_name='run1',
                 max_checkpoints=1,
                 use_memory_saving_gradients=False,
                 only_train_transformer_layers=False,
                 overwrite=False):
    """Does one LR half-cycle from initial_lr to final_lr over `steps`
    iterations, using the CLR algorithm: https://github.com/bckenstler/CLR

    Adapted from https://github.com/nshepperd/gpt-2/blob/finetuning/train.py.
    See that file for parameter definitions.
    """
    CHECKPOINT_DIR = 'checkpoint'

    checkpoint_path = os.path.join(CHECKPOINT_DIR, run_name)

    def maketree(path):
        os.makedirs(path, exist_ok=True)

    maketree(checkpoint_path)
    files = os.listdir(checkpoint_path)
    for file in ['hparams.json', 'encoder.json', 'vocab.bpe']:
        if file not in files:
            try:
                shutil.copyfile(os.path.join('models', model_name, file),
                                os.path.join(checkpoint_path, file))
            except FileNotFoundError as fnf_error:
                print("You need to download the GPT-2 model first via download_gpt2()")
                raise fnf_error

    enc = encoder.get_encoder(checkpoint_path)
    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if model_name != '117M':
        use_memory_saving_gradients = True
        only_train_transformer_layers = True
        accumulate_gradients = 1

    context = tf.placeholder(tf.int32, [batch_size, None])
    output = model.model(hparams=hparams, X=context)
    loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=context[:, 1:], logits=output['logits'][:, :-1]))

    current_iter = 0
    learning_rate = tf.placeholder(tf.float32, shape=[])

    def get_lr():
        # Triangular CLR schedule; for current_iter in [0, steps] this ramps
        # linearly from initial_lr up to final_lr.
        cycle = np.floor(1 + current_iter / (2 * steps))
        x = np.abs(current_iter / steps - 2 * cycle + 1)
        lr = initial_lr + (final_lr - initial_lr) * np.maximum(
            0, (1 - x))  # * scale_fn(x)
        return lr

    all_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
    train_vars = [v for v in all_vars if '/h' in v.name
                  ] if only_train_transformer_layers else all_vars

    if accumulate_gradients > 1:
        if use_memory_saving_gradients:
            exit("Memory saving gradients are not implemented for gradient accumulation yet.")
        opt = AccumulatingOptimizer(
            opt=tf.train.AdamOptimizer(learning_rate=learning_rate),
            var_list=train_vars)
        opt_reset = opt.reset()
        opt_compute = opt.compute_gradients(loss)
        opt_apply = opt.apply_gradients()
        summary_loss = tf.summary.scalar('loss', opt_apply)
    else:
        opt = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
        if use_memory_saving_gradients:
            opt_grads = memory_saving_gradients.gradients(loss, train_vars)
        else:
            opt_grads = tf.gradients(loss, train_vars)
        opt_grads = list(zip(opt_grads, train_vars))
        opt_apply = opt.apply_gradients(opt_grads)
        summary_loss = tf.summary.scalar('loss', loss)

    summary_log = tf.summary.FileWriter(checkpoint_path)

    saver = tf.train.Saver(var_list=all_vars, max_to_keep=max_checkpoints)
    sess.run(tf.global_variables_initializer())

    if restore_from == 'latest':
        ckpt = tf.train.latest_checkpoint(checkpoint_path)
        if ckpt is None:
            # Get fresh GPT weights if new run.
            ckpt = tf.train.latest_checkpoint(
                os.path.join('models', model_name))
    elif restore_from == 'fresh':
        ckpt = tf.train.latest_checkpoint(os.path.join('models', model_name))
    else:
        ckpt = tf.train.latest_checkpoint(restore_from)
    print('Loading checkpoint', ckpt)
    saver.restore(sess, ckpt)

    print('Loading dataset...')
    chunks = load_dataset(enc, dataset, combine)
    data_sampler = Sampler(chunks)
    print('dataset has', data_sampler.total_size, 'tokens')
    print('Training...')

    counter = 1
    counter_path = os.path.join(checkpoint_path, 'counter')
    if os.path.exists(counter_path) and restore_from == 'latest':
        # Load the step number if we're resuming a run
        # Add 1 so we don't immediately try to save again
        with open(counter_path, 'r') as fp:
            counter = int(fp.read()) + 1
    counter_base = counter

    def sample_batch():
        return [data_sampler.sample(1024) for _ in range(batch_size)]

    if overwrite and restore_from == 'latest':
        for file in files:
            if file.startswith('model') or file.startswith('events'):
                os.remove(os.path.join(checkpoint_path, file))

    start_time = time.time()

    try:
        while True:
            if steps > 0 and counter == (counter_base + steps):
                return
            if accumulate_gradients > 1:
                sess.run(opt_reset)
                for _ in range(accumulate_gradients):
                    sess.run(opt_compute,
                             feed_dict={
                                 context: sample_batch(),
                                 learning_rate: get_lr()
                             })
                (v_loss, v_summary) = sess.run((opt_apply, summary_loss))
            else:
                (_, v_loss, v_summary) = sess.run(
                    (opt_apply, loss, summary_loss),
                    feed_dict={
                        context: sample_batch(),
                        learning_rate: get_lr()
                    })

            summary_log.add_summary(v_summary, counter)

            print('[{counter} | {time:2.2f}] loss={loss:3.14f} lr={lr:2.14f}'
                  .format(counter=counter,
                          time=time.time() - start_time,
                          loss=v_loss,
                          lr=get_lr()))

            counter += 1
            current_iter += 1
    except KeyboardInterrupt:
        print('interrupted')
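def _clr_half_cycle_lr(step, steps, initial_lr=1e-10, final_lr=1.0):
    """Standalone restatement of get_lr() above, for clarity only: over one
    half-cycle (step in [0, steps]) the term x falls linearly from 1 to 0,
    so the learning rate ramps linearly from initial_lr to final_lr.
    E.g. step=0 gives initial_lr, step=steps//2 the midpoint, step=steps
    gives final_lr."""
    cycle = np.floor(1 + step / (2 * steps))
    x = np.abs(step / steps - 2 * cycle + 1)
    return initial_lr + (final_lr - initial_lr) * np.maximum(0, 1 - x)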
def model_fn(features, labels, mode, params):
    # `hparams`, `init_checkpoint`, `learning_rate`, `num_train_steps` and
    # `num_warmup_steps` are closed over from the enclosing builder
    # (not shown in this excerpt).
    tf.logging.info('*** Features ***')
    for name in sorted(features.keys()):
        tf.logging.info('  name = %s, shape = %s' %
                        (name, features[name].shape))

    input_ids = features['input_ids']

    is_training = mode == tf.estimator.ModeKeys.TRAIN

    output = model.model(hparams=hparams, X=input_ids)
    loss = tf.reduce_mean(
        input_tensor=tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=input_ids[:, 1:], logits=output['logits'][:, :-1]))

    tvars = tf.trainable_variables()

    initialized_variable_names = {}
    (
        assignment_map,
        initialized_variable_names,
    ) = get_assignment_map_from_checkpoint(tvars, init_checkpoint)

    def tpu_scaffold():
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
        return tf.train.Scaffold()

    scaffold_fn = tpu_scaffold

    tf.logging.info('**** Trainable Variables ****')
    for var in tvars:
        init_string = ''
        if var.name in initialized_variable_names:
            init_string = ', *INIT_FROM_CKPT*'
        tf.logging.info('  name = %s, shape = %s%s', var.name, var.shape,
                        init_string)

    output_spec = None
    if mode == tf.estimator.ModeKeys.TRAIN:
        train_op = optimization.create_optimizer(loss, learning_rate,
                                                 num_train_steps,
                                                 num_warmup_steps, True)

        output_spec = tf.contrib.tpu.TPUEstimatorSpec(
            mode=mode,
            loss=loss,
            train_op=train_op,
            scaffold_fn=scaffold_fn,
        )
    elif mode == tf.estimator.ModeKeys.EVAL:

        def metric_fn(loss, input_ids, logits):
            # Token-level accuracy of next-token predictions against the
            # shifted inputs. (The original metric referenced an undefined
            # `next_sentence_log_probs`, residue of the BERT pretraining
            # code this was adapted from.)
            predictions = tf.argmax(logits[:, :-1],
                                    axis=-1,
                                    output_type=tf.int32)
            labels = input_ids[:, 1:]
            accuracy = tf.metrics.accuracy(labels=labels,
                                           predictions=predictions)
            mean_loss = tf.metrics.mean(values=loss)
            return {
                'accuracy': accuracy,
                'loss': mean_loss,
            }

        eval_metrics = (metric_fn, [loss, input_ids, output['logits']])
        output_spec = tf.contrib.tpu.TPUEstimatorSpec(
            mode=mode,
            loss=loss,
            eval_metrics=eval_metrics,
            scaffold_fn=scaffold_fn,
        )
    else:
        raise ValueError('Only TRAIN and EVAL modes are supported: %s' %
                         (mode))

    return output_spec
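# Hedged wiring sketch for the model_fn above (TF1-era tf.contrib.tpu API;
# `input_fn`, batch sizes and paths here are assumptions, not part of this
# excerpt):
#
#     run_config = tf.contrib.tpu.RunConfig(
#         model_dir='checkpoint/run1',
#         tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=100))
#     estimator = tf.contrib.tpu.TPUEstimator(
#         model_fn=model_fn,
#         config=run_config,
#         use_tpu=True,
#         train_batch_size=8,
#         eval_batch_size=8)
#     estimator.train(input_fn=input_fn, max_steps=num_train_steps)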
def finetune(sess,
             dataset,
             steps=-1,
             model_name='117M',
             combine=50000,
             batch_size=1,
             learning_rate=0.0001,
             accumulate_gradients=5,
             restore_from='latest',
             run_name='run1',
             sample_every=100,
             sample_length=1023,
             sample_num=1,
             save_every=1000,
             print_every=1,
             max_checkpoints=1,
             model_load=False):
    """Finetunes the model on the given dataset.

    Adapted from https://github.com/nshepperd/gpt-2/blob/finetuning/train.py.
    See that file for parameter definitions.
    """
    CHECKPOINT_DIR = 'checkpoint'
    SAMPLE_DIR = 'samples'

    checkpoint_path = os.path.join(CHECKPOINT_DIR, run_name)

    def maketree(path):
        os.makedirs(path, exist_ok=True)

    maketree(checkpoint_path)
    if not model_load:
        for file in ['hparams.json', 'encoder.json', 'vocab.bpe']:
            shutil.copyfile(os.path.join('models', model_name, file),
                            os.path.join(checkpoint_path, file))

    enc = encoder.get_encoder(checkpoint_path)
    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if sample_length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" %
                         hparams.n_ctx)

    context = tf.placeholder(tf.int32, [batch_size, None])
    output = model.model(hparams=hparams, X=context)
    loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=context[:, 1:], logits=output['logits'][:, :-1]))

    tf_sample = sample.sample_sequence(hparams=hparams,
                                       length=sample_length,
                                       context=context,
                                       batch_size=batch_size,
                                       temperature=1.0,
                                       top_k=40)

    train_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
    if accumulate_gradients > 1:
        opt = AccumulatingOptimizer(
            opt=tf.train.AdamOptimizer(learning_rate=learning_rate),
            var_list=train_vars)
        opt_reset = opt.reset()
        opt_compute = opt.compute_gradients(loss)
        opt_apply = opt.apply_gradients()
        summary_loss = tf.summary.scalar('loss', opt_apply)
    else:
        opt_apply = tf.train.AdamOptimizer(
            learning_rate=learning_rate).minimize(loss, var_list=train_vars)
        summary_loss = tf.summary.scalar('loss', loss)

    summary_log = tf.summary.FileWriter(checkpoint_path)

    saver = tf.train.Saver(var_list=train_vars, max_to_keep=max_checkpoints)
    sess.run(tf.global_variables_initializer())

    if restore_from == 'latest':
        ckpt = tf.train.latest_checkpoint(checkpoint_path)
        if ckpt is None:
            # Get fresh GPT weights if new run.
            ckpt = tf.train.latest_checkpoint(
                os.path.join('models', model_name))
    elif restore_from == 'fresh':
        ckpt = tf.train.latest_checkpoint(os.path.join('models', model_name))
    else:
        ckpt = tf.train.latest_checkpoint(restore_from)
    print('Loading checkpoint', ckpt)
    saver.restore(sess, ckpt)

    if model_load:
        return

    print('Loading dataset...')
    chunks = load_dataset(enc, dataset, combine)
    data_sampler = Sampler(chunks)
    print('dataset has', data_sampler.total_size, 'tokens')
    print('Training...')

    counter = 1
    counter_path = os.path.join(checkpoint_path, 'counter')
    if os.path.exists(counter_path):
        # Load the step number if we're resuming a run
        # Add 1 so we don't immediately try to save again
        with open(counter_path, 'r') as fp:
            counter = int(fp.read()) + 1

    def save():
        maketree(checkpoint_path)
        print('Saving',
              os.path.join(checkpoint_path, 'model-{}').format(counter))
        saver.save(sess,
                   os.path.join(checkpoint_path, 'model'),
                   global_step=counter)
        with open(counter_path, 'w') as fp:
            fp.write(str(counter) + '\n')

    def generate_samples():
        context_tokens = data_sampler.sample(1)
        all_text = []
        index = 0
        while index < sample_num:
            out = sess.run(
                tf_sample,
                feed_dict={context: batch_size * [context_tokens]})
            for i in range(min(sample_num - index, batch_size)):
                text = enc.decode(out[i])
                text = '======== SAMPLE {} ========\n{}\n'.format(
                    index + 1, text)
                all_text.append(text)
                index += 1
        print(text)
        maketree(os.path.join(SAMPLE_DIR, run_name))
        with open(
                os.path.join(SAMPLE_DIR, run_name,
                             'samples-{}').format(counter), 'w') as fp:
            fp.write('\n'.join(all_text))

    def sample_batch():
        return [data_sampler.sample(1024) for _ in range(batch_size)]

    avg_loss = (0.0, 0.0)
    start_time = time.time()

    try:
        while True:
            if counter == steps:
                save()
                return
            if counter % save_every == 0:
                save()
            if counter % sample_every == 0:
                generate_samples()

            if accumulate_gradients > 1:
                sess.run(opt_reset)
                for _ in range(accumulate_gradients):
                    sess.run(opt_compute,
                             feed_dict={context: sample_batch()})
                (v_loss, v_summary) = sess.run((opt_apply, summary_loss))
            else:
                (_, v_loss, v_summary) = sess.run(
                    (opt_apply, loss, summary_loss),
                    feed_dict={context: sample_batch()})

            summary_log.add_summary(v_summary, counter)

            if counter % print_every == 0:
                avg_loss = (avg_loss[0] * 0.99 + v_loss,
                            avg_loss[1] * 0.99 + 1.0)
                print(
                    '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                    .format(counter=counter,
                            time=time.time() - start_time,
                            loss=v_loss,
                            avg=avg_loss[0] / avg_loss[1]))

            counter += 1
    except KeyboardInterrupt:
        print('interrupted')
        save()
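# Illustrative only: with model_load=True, finetune() above restores the
# latest checkpoint and returns immediately, so it doubles as a loader
# before sampling (the dataset argument is untouched on that path):
#
#     sess = tf.Session()
#     finetune(sess, dataset=None, model_load=True, run_name='run1')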
def model_fn(features, labels, mode, params):
    # `hparams`, `init_checkpoint`, `use_tpu`, `learning_rate`,
    # `num_train_steps`, `num_warmup_steps`, `optimizer`, `poly_power`,
    # `start_warmup_step` and `use_memory_saving_gradients` are closed over
    # from the enclosing builder (not shown in this excerpt).
    tf.logging.info("*** Features ***")
    for name in sorted(features.keys()):
        tf.logging.info("  name = %s, shape = %s" %
                        (name, features[name].shape))

    input_ids = features["input_ids"]

    output = model.model(hparams=hparams, X=input_ids)
    loss = tf.reduce_mean(
        input_tensor=tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=input_ids[:, 1:], logits=output["logits"][:, :-1]))

    tvars = tf.trainable_variables()

    initialized_variable_names = {}
    scaffold_fn = None
    if init_checkpoint:
        (
            assignment_map,
            initialized_variable_names,
        ) = get_assignment_map_from_checkpoint(tvars, init_checkpoint)
        if use_tpu:

            def tpu_scaffold():
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
                return tf.train.Scaffold()

            scaffold_fn = tpu_scaffold
        else:
            tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

    tf.logging.info("**** Trainable Variables ****")
    for var in tvars:
        init_string = ""
        if var.name in initialized_variable_names:
            init_string = ", *INIT_FROM_CKPT*"
        tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                        init_string)

    output_spec = None
    if mode == tf.estimator.ModeKeys.TRAIN:
        train_op = optimization.create_optimizer(
            loss,
            learning_rate,
            num_train_steps,
            num_warmup_steps,
            use_tpu,
            optimizer,
            poly_power,
            start_warmup_step,
            use_memory_saving_gradients=use_memory_saving_gradients)

        output_spec = tf.contrib.tpu.TPUEstimatorSpec(
            mode=mode,
            loss=loss,
            train_op=train_op,
            scaffold_fn=scaffold_fn,
        )
    elif mode == tf.estimator.ModeKeys.EVAL:

        def metric_fn(loss):
            """Evaluation metric Fn which runs on CPU."""
            perplexity = tf.exp(tf.reduce_mean(loss))
            bpc = tf.reduce_mean(loss) / tf.constant(math.log(2))
            return {
                "perplexity": tf.metrics.mean(perplexity),
                "bpc": tf.metrics.mean(bpc),
            }

        if FLAGS.use_tpu:
            with tf.colocate_with(loss):
                loss = tf.contrib.tpu.cross_replica_sum(loss) \
                    / FLAGS.num_tpu_cores
        # TPU eval metric args need a batch dimension, so the scalar loss
        # is tiled out to eval_batch_size rows.
        metric_loss = tf.tile(tf.reshape(loss, [1, 1]),
                              [FLAGS.eval_batch_size, 1])
        output_spec = tf.contrib.tpu.TPUEstimatorSpec(
            mode=mode,
            loss=loss,
            eval_metrics=(metric_fn, [metric_loss]),
            scaffold_fn=scaffold_fn)
    else:
        raise ValueError("Only TRAIN and EVAL modes are supported: %s" %
                         (mode))

    return output_spec
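def _lm_eval_metrics(mean_loss):
    """Illustrative restatement of metric_fn above (plain Python, not part
    of the original API): for a mean cross-entropy `mean_loss` in nats,
    perplexity is exp(loss), and bits-per-character divides by ln 2 to
    convert nats to bits."""
    return {
        "perplexity": math.exp(mean_loss),
        "bpc": mean_loss / math.log(2),
    }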
def finetune(
    sess,
    dataset,
    steps=-1,
    model_name="124M",
    model_dir="models",
    combine=50000,
    batch_size=1,
    learning_rate=0.0001,
    accumulate_gradients=5,
    restore_from="latest",
    run_name="run1",
    checkpoint_dir="checkpoint",
    sample_every=100,
    sample_length=1023,
    sample_num=1,
    sample_prefix="",
    multi_gpu=False,
    save_every=1000,
    print_every=1,
    sample_dir="samples",
    max_checkpoints=1,
    use_memory_saving_gradients=False,
    only_train_transformer_layers=False,
    optimizer="adafactor",
    overwrite=False,
):
    """Finetunes the model on the given dataset.

    Adapted from https://github.com/nshepperd/gpt-2/blob/finetuning/train.py.
    See that file for parameter definitions.
    """

    checkpoint_path = os.path.join(checkpoint_dir, run_name)

    os.makedirs(checkpoint_path, exist_ok=True)
    files = os.listdir(checkpoint_path)
    for file_name in ["hparams.json", "encoder.json", "vocab.bpe"]:
        try:
            shutil.copyfile(
                os.path.join(model_dir, model_name, file_name),
                os.path.join(checkpoint_path, file_name),
            )
        except FileNotFoundError as fnf_error:
            raise RuntimeError(
                "You need to download the GPT-2 model first via download_gpt2()"
            ) from fnf_error

    enc = encoder.get_encoder(checkpoint_path)
    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, "hparams.json")) as f:
        hparams.override_from_dict(json.load(f))

    if sample_length > hparams.n_ctx:
        raise ValueError(
            "Can't get samples longer than window size: %s" % hparams.n_ctx
        )

    # Fragile lexicographic comparison (model_name > "124M") replaced with an
    # explicit allowlist of the models small enough to train fully.
    if model_name not in ["117M", "124M"]:
        use_memory_saving_gradients = True
        only_train_transformer_layers = True
        accumulate_gradients = 1

    context = tf.compat.v1.placeholder(tf.int32, [batch_size, None])

    if multi_gpu:
        gpus = get_available_gpus()
    else:
        gpus = []

    output = model.model(hparams=hparams, X=context, gpus=gpus)
    loss = tf.reduce_mean(
        input_tensor=tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=context[:, 1:], logits=output["logits"][:, :-1]
        )
    )

    tf_sample = sample.sample_sequence(
        hparams=hparams,
        length=sample_length,
        context=context,
        batch_size=batch_size,
        temperature=1.0,
        top_k=40,
    )

    all_vars = [v for v in tf.compat.v1.trainable_variables() if "model" in v.name]
    train_vars = (
        [v for v in all_vars if "/h" in v.name]
        if only_train_transformer_layers
        else all_vars
    )

    if optimizer == "adam":
        opt = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
    elif optimizer == "sgd":
        opt = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=learning_rate)
    elif optimizer == "adafactor":
        params = {}
        params["decay_type"] = "adam"
        params["beta1"] = 0.0
        params["beta2"] = 0.999

        if params["decay_type"] == "adam":
            decay_rate = adafactor_decay_rate_adam(params["beta2"])
        elif params["decay_type"] == "pow":
            decay_rate = adafactor_decay_rate_pow(params["decay_exponent"])
        else:
            raise ValueError("unknown optimizer_adafactor_decay_type")

        if "weight_decay" not in params:
            opt = AdafactorOptimizer(
                learning_rate=learning_rate,
                decay_rate=decay_rate,
                beta1=params["beta1"],
                name="Adafactor",
            )
        else:
            AdafactorWOptimizer = tf.contrib.opt.extend_with_decoupled_weight_decay(
                AdafactorOptimizer
            )
            opt = AdafactorWOptimizer(
                weight_decay=params["weight_decay"] * learning_rate,
                learning_rate=learning_rate,
                decay_rate=decay_rate,
                beta1=params["beta1"],
                name="AdafactorW",
            )
    else:
        raise ValueError(f"Unknown optimizer {optimizer}")

    if accumulate_gradients > 1:
        if use_memory_saving_gradients:
            exit(
                "Memory saving gradients are not implemented for gradient accumulation yet."
            )
        opt = AccumulatingOptimizer(opt=opt, var_list=train_vars)
        opt_reset = opt.reset()
        opt_compute = opt.compute_gradients(loss)
        opt_apply = opt.apply_gradients()
        # AccumulatingOptimizer.apply_gradients() returns the mean
        # accumulated loss, so it doubles as the summary value here.
        summary_loss = tf.compat.v1.summary.scalar("loss", opt_apply)
    else:
        if use_memory_saving_gradients:
            opt_grads = memory_saving_gradients.gradients(loss, train_vars)
        else:
            opt_grads = tf.gradients(ys=loss, xs=train_vars)
        opt_grads = list(zip(opt_grads, train_vars))
        opt_apply = opt.apply_gradients(opt_grads)
        summary_loss = tf.compat.v1.summary.scalar("loss", loss)

    summary_log = tf.compat.v1.summary.FileWriter(checkpoint_path)

    saver = tf.compat.v1.train.Saver(var_list=all_vars, max_to_keep=max_checkpoints)
    sess.run(tf.compat.v1.global_variables_initializer())

    if restore_from == "latest":
        ckpt = tf.train.latest_checkpoint(checkpoint_path)
        if ckpt is None:
            # Get fresh GPT weights if new run.
            ckpt = tf.train.latest_checkpoint(os.path.join(model_dir, model_name))
    elif restore_from == "fresh":
        ckpt = tf.train.latest_checkpoint(os.path.join(model_dir, model_name))
    else:
        ckpt = tf.train.latest_checkpoint(restore_from)
    print("Loading checkpoint", ckpt)
    saver.restore(sess, ckpt)

    print("Loading dataset...")
    chunks = load_dataset(enc, dataset, combine)
    data_sampler = Sampler(chunks)
    print(f"Dataset has {data_sampler.total_size} tokens.")
    if not data_sampler.total_size:
        raise ValueError("Dataset is empty.")
    print("Training...")

    counter = 1
    counter_path = os.path.join(checkpoint_path, "counter")
    if os.path.exists(counter_path) and restore_from == "latest":
        # Load the step number if we're resuming a run
        # Add 1 so we don't immediately try to save again
        with open(counter_path, "r") as fp:
            counter = int(fp.read()) + 1
    counter_base = counter

    def save():
        os.makedirs(checkpoint_path, exist_ok=True)
        print(
            "Saving",
            os.path.join(checkpoint_path, "model-{}").format(counter - 1),
        )
        saver.save(
            sess, os.path.join(checkpoint_path, "model"), global_step=counter - 1
        )
        with open(counter_path, "w") as fp:
            fp.write(str(counter - 1) + "\n")

    def generate_samples():
        if sample_prefix:
            context_tokens = enc.encode(sample_prefix)
        else:
            context_tokens = data_sampler.sample(1)
        all_text = []
        index = 0
        while index < sample_num:
            out = sess.run(
                tf_sample, feed_dict={context: batch_size * [context_tokens]}
            )
            for i in range(min(sample_num - index, batch_size)):
                text = enc.decode(out[i])
                text = "======== SAMPLE {} ========\n{}\n".format(index + 1, text)
                all_text.append(text)
                index += 1
        print(text)
        os.makedirs(os.path.join(sample_dir, run_name), exist_ok=True)
        with open(
            os.path.join(sample_dir, run_name, "samples-{}").format(counter), "w"
        ) as fp:
            fp.write("\n".join(all_text))

    def sample_batch():
        return [data_sampler.sample(1024) for _ in range(batch_size)]

    if overwrite and restore_from == "latest":
        for file in files:
            if file.startswith("model") or file.startswith("events"):
                os.remove(os.path.join(checkpoint_path, file))
        save()

    avg_loss = (0.0, 0.0)
    start_time = time.time()

    if steps:
        steps = int(steps)

    try:
        while True:
            if steps > 0 and counter == (counter_base + steps):
                save()
                return
            if (counter - 1) % save_every == 0 and counter > 1:
                save()
            if (counter - 1) % sample_every == 0 and counter > 1:
                generate_samples()

            if accumulate_gradients > 1:
                sess.run(opt_reset)
                for _ in range(accumulate_gradients):
                    sess.run(opt_compute, feed_dict={context: sample_batch()})
                (v_loss, v_summary) = sess.run((opt_apply, summary_loss))
            else:
                (_, v_loss, v_summary) = sess.run(
                    (opt_apply, loss, summary_loss),
                    feed_dict={context: sample_batch()},
                )

            summary_log.add_summary(v_summary, counter)

            if counter % print_every == 0:
                avg_loss = (avg_loss[0] * 0.99 + v_loss, avg_loss[1] * 0.99 + 1.0)
                print(
                    "[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}".format(
                        counter=counter,
                        time=time.time() - start_time,
                        loss=v_loss,
                        avg=avg_loss[0] / avg_loss[1],
                    )
                )

            counter += 1
    except KeyboardInterrupt:
        print("interrupted")
        save()
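# Illustrative only: the "adafactor" default above keeps far less optimizer
# state than Adam (factored second moments), which is what lets the larger
# checkpoints fit on a single GPU. A hypothetical call:
#
#     sess = tf.compat.v1.Session()
#     finetune(sess,
#              dataset="dataset.txt",
#              model_name="355M",
#              steps=1000,
#              optimizer="adafactor",
#              sample_prefix="ROMEO:")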