def load_gpt2(sess, run_name="run1", checkpoint_dir="checkpoint",
              model_name=None, model_dir='models'):
    """Loads the model checkpoint or existing model into a TensorFlow session
    for repeated predictions.
    """
    if model_name:
        checkpoint_path = os.path.join(model_dir, model_name)
    else:
        checkpoint_path = os.path.join(checkpoint_dir, run_name)

    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    context = tf.compat.v1.placeholder(tf.int32, [1, None])
    output = model.model(hparams=hparams, X=context)

    ckpt = tf.train.latest_checkpoint(checkpoint_path)
    saver = tf.compat.v1.train.Saver(allow_empty=True)
    sess.run(tf.compat.v1.global_variables_initializer())

    if model_name:
        print('Loading pretrained model', ckpt)
    else:
        print('Loading checkpoint', ckpt)
    saver.restore(sess, ckpt)
def load_gpt2(sess, models_dir='models/', run_name="117M", checkpoint_name=None):
    """Loads the model checkpoint into a TensorFlow session
    for repeated predictions.
    """
    CHECKPOINT_DIR = models_dir
    checkpoint_path = os.path.join(CHECKPOINT_DIR, run_name)

    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    context = tf.placeholder(tf.int32, [1, None])
    output = model.model(hparams=hparams, X=context)

    if checkpoint_name:
        ckpt = os.path.join(checkpoint_path, checkpoint_name)
    else:
        ckpt = tf.train.latest_checkpoint(checkpoint_path)
    saver = tf.train.Saver(allow_empty=True)
    sess.run(tf.global_variables_initializer())

    print('Loading checkpoint', ckpt)
    saver.restore(sess, ckpt)
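# Usage sketch (not part of the library; assumes a start_tf_sess() helper
# like the one in gpt-2-simple is available in this module, and that a
# trained checkpoint exists under checkpoint/run1):
#
#     sess = start_tf_sess()
#     load_gpt2(sess, run_name="run1")
#
# The same `sess` can then be reused for repeated generate()/predict()
# calls without rebuilding the graph.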
def get_logitsTF(sess, checkpoint='latest', run_name="run1",
                 checkpoint_dir="checkpoint", model_name=None,
                 model_dir='models', multi_gpu=False):
    """Builds the model graph for a checkpoint or pretrained model and
    evaluates its logits tensor in the given session.
    """
    if model_name:
        checkpoint_path = os.path.join(model_dir, model_name)
    else:
        checkpoint_path = os.path.join(checkpoint_dir, run_name)

    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    context = tf.compat.v1.placeholder(tf.int32, [1, None])
    gpus = []

    if multi_gpu:
        gpus = get_available_gpus()

    output = model.model(hparams=hparams, X=context, gpus=gpus)
    # Note: the `context` placeholder is never fed here, so eval() raises
    # unless the caller evaluates the tensor with a feed_dict instead.
    return output['logits'].eval(session=sess)
def embed(sess,
          run_name='run1',
          destination_path="X.p",
          sentences=None,
          batch_size=1,
          layer_type="h",
          save=False,
          multiple_lists=False):
    # Normalize the input: a bare string becomes a single-item list.
    # (The original assigned the unused name `prefix` here, which left
    # `prefixes` holding the raw string for non-list input.)
    if type(sentences) != list:
        sentences = [sentences]
    if multiple_lists:
        prefixes = sentences
    else:
        prefixes = [sentences]

    CHECKPOINT_DIR = 'checkpoint'
    SAMPLE_DIR = 'samples'

    checkpoint_path = os.path.join(CHECKPOINT_DIR, run_name)

    enc = encoder.get_encoder(checkpoint_path)
    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    embeddings = []
    context = tf.placeholder(tf.int32, [batch_size, None])
    lm_output = model.model(hparams=hparams,
                            X=context,
                            past=None,
                            reuse=tf.AUTO_REUSE,
                            emb=True)

    with tqdm(total=len(prefixes)) as pbar:
        for prefix in prefixes:
            pref_emb = []
            for p in prefix:
                context_tokens = enc.encode(p)
                e = sess.run(lm_output[layer_type],
                             feed_dict={context: batch_size * [context_tokens]})
                pref_emb.append(e[0])
            embeddings.append(pref_emb)
            pbar.update(1)

    if save:
        f = open(destination_path, 'wb')
        pickle.dump(embeddings, f)
        f.close()

    if multiple_lists:
        return embeddings
    else:
        return embeddings[0]
def get_text_rankings(sess, string):
    '''Takes a string and returns pairs of (string, int), where the string is
    a piece of the input and the int is its GPT-2 ranking, with 0 being the
    most likely token.'''
    prefix = string
    batch_size = 1
    checkpoint_path = os.path.join('checkpoint', 'run1')
    enc = encoder.get_encoder(checkpoint_path)
    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    context = tf.compat.v1.placeholder(tf.int32, [batch_size, None])
    context_tokens = [50256] + enc.encode(prefix)  # 50256 == <|endoftext|>

    np.random.seed(42)
    tf.compat.v1.set_random_seed(42)

    past, prev, output, all_logits = sample_sequence(
        hparams=hparams,
        length=len(context_tokens) - 1,
        start_token=None,
        batch_size=batch_size,
        context=context,
        temperature=1,
        top_k=0,
        top_p=0.0)
    output = output[:, 1:]

    pas, out, alt = sess.run([past, output, all_logits],
                             feed_dict={context: batch_size * [context_tokens]})
    text = enc.decode(out[0])
    print(text)
    print('generated tokens shape: ', out[0].shape)
    print('pas shape: ', pas.shape)

    def find_ranking(arr, index):
        "Finds the position the element at index would have if the array were sorted in decreasing order."
        return (arr > arr[index]).sum()

    string_rank_pairs = []
    for i, token in enumerate(context_tokens[1:]):
        logs = alt[0][i]
        token_string = enc.decode([token])
        token_ranking = find_ranking(logs, token)
        string_rank_pairs.append((token_string, token_ranking))
    return string_rank_pairs
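# Example call (hypothetical output; actual ranks depend on the checkpoint):
#
#     pairs = get_text_rankings(sess, "The cat sat on the mat")
#     # -> [('The', 74), (' cat', 512), (' sat', 3), ...]
#
# A rank of 0 means the model considered that token the single most likely
# continuation of its left context.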
def get_logits(sess, run_name='run1', checkpoint_dir='checkpoint',
               model_name=None, model_dir='models',
               prefix="<|endoftext|>", all=False):
    batch_size = 1
    if model_name:
        checkpoint_path = os.path.join(model_dir, model_name)
    else:
        checkpoint_path = os.path.join(checkpoint_dir, run_name)

    enc = encoder.get_encoder(checkpoint_path)
    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if prefix:
        context = tf.compat.v1.placeholder(tf.int32, [batch_size, None])
        context_tokens = enc.encode(prefix)
        context_tokens = context_tokens[-1023:]

    def step(hparams, tokens, past=None):
        lm_output = model.model(hparams=hparams, X=tokens, past=past,
                                reuse=tf.compat.v1.AUTO_REUSE)
        logits = lm_output['logits'][:, :, :hparams.n_vocab]
        presents = lm_output['present']
        presents.set_shape(
            model.past_shape(hparams=hparams, batch_size=batch_size))
        return {
            'logits': logits,
            'presents': presents,
        }

    output = step(hparams, context)
    out = sess.run(output, feed_dict={context: batch_size * [context_tokens]})
    if all:
        # All logits starting from the second token: n logits for n tokens.
        return out['logits'][0, :, :]
    return out['logits'][0, -1, :]  # logits for the next token
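# Worked example (a sketch, not library code): get_logits() returns raw,
# unnormalized scores, so turning them into next-token probabilities takes
# a softmax. `sess` and `enc` are assumed to come from prior load_gpt2() /
# encoder.get_encoder() calls; the helper name is illustrative.
def show_next_token_probs(sess, enc, prefix="The quick brown fox", n=5):
    logits = get_logits(sess, prefix=prefix)   # shape: (n_vocab,)
    exp = np.exp(logits - logits.max())        # subtract max for stability
    probs = exp / exp.sum()
    for idx in np.argsort(probs)[::-1][:n]:    # the n most likely tokens
        print(repr(enc.decode([int(idx)])), float(probs[idx]))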
def predict(sess, text, run_name='run1', checkpoint_dir='checkpoint',
            model_name=None, model_dir='models', sample_dir='samples',
            seed=None, temperature=1, top_k=2):
    if model_name:
        checkpoint_path = os.path.join(model_dir, model_name)
    else:
        checkpoint_path = os.path.join(checkpoint_dir, run_name)

    enc = encoder.get_encoder(checkpoint_path)
    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    context = tf.compat.v1.placeholder(tf.int32, [1, None])
    context_tokens = enc.encode(text)

    np.random.seed(seed)
    tf.compat.v1.set_random_seed(seed)

    proba_t = sample.sample_sequence(hparams=hparams, context=context,
                                     temperature=temperature)
    proba = sess.run(proba_t, feed_dict={context: [context_tokens]})

    top_k_idxs = np.flip(np.argsort(proba))[:top_k]
    # Decode each candidate token individually; decoding the ids as a single
    # string and splitting on whitespace breaks tokens that contain none.
    top_k_tokens = [enc.decode([int(idx)]) for idx in top_k_idxs]
    top_k_proba = proba[top_k_idxs]
    return top_k_tokens, top_k_proba
def load_gpt2(
    sess,
    checkpoint="latest",
    run_name="run1",
    checkpoint_dir="checkpoint",
    model_name=None,
    model_dir="models",
    multi_gpu=False,
):
    """Loads the model checkpoint or existing model into a TensorFlow session
    for repeated predictions.
    """
    if model_name:
        checkpoint_path = os.path.join(model_dir, model_name)
    else:
        checkpoint_path = os.path.join(checkpoint_dir, run_name)

    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, "hparams.json")) as f:
        hparams.override_from_dict(json.load(f))

    context = tf.compat.v1.placeholder(tf.int32, [1, None])
    gpus = []

    if multi_gpu:
        gpus = get_available_gpus()

    _ = model.model(hparams=hparams, X=context, gpus=gpus)

    if checkpoint == "latest":
        ckpt = tf.train.latest_checkpoint(checkpoint_path)
    else:
        ckpt = os.path.join(checkpoint_path, checkpoint)

    saver = tf.compat.v1.train.Saver(allow_empty=True)
    sess.run(tf.compat.v1.global_variables_initializer())

    if model_name:
        print("Loading pretrained model", ckpt)
    else:
        print("Loading checkpoint", ckpt)
    saver.restore(sess, ckpt)
def my_load_gpt2(_sess, _checkpoint='latest', _run_name="run1",
                 _checkpoint_dir="checkpoint", _model_name=None,
                 _model_dir='models', _multi_gpu=False):
    """
    Loads the model checkpoint or existing model into a TensorFlow session
    for repeated predictions.
    """
    if _model_name:
        checkpoint_path = os.path.join(_model_dir, _model_name)
    else:
        checkpoint_path = os.path.join(_checkpoint_dir, _run_name)
    print(f"\ncheckpoint_path = {checkpoint_path}\n")

    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    context = tf.compat.v1.placeholder(tf.int32, [1, None])
    gpus = []

    if _multi_gpu:
        gpus = get_available_gpus()

    output = model.model(hparams=hparams, X=context, gpus=gpus)

    if _checkpoint == 'latest':
        ckpt = tf.train.latest_checkpoint(checkpoint_path)
    else:
        ckpt = os.path.join(checkpoint_path, _checkpoint)

    saver = tf.compat.v1.train.Saver(allow_empty=True)
    _sess.run(tf.compat.v1.global_variables_initializer())

    if _model_name:
        print(f"\nLoading pretrained model :: {ckpt}\n")
    else:
        print(f"\nLoading checkpoint :: {ckpt}\n")
    saver.restore(_sess, ckpt)
def generate(sess,
             return_as_list=False,
             truncate=None,
             destination_path=None,
             sample_delim='=' * 20 + '\n',
             prefix=None,
             model_name='117M',
             seed=None,
             nsamples=1,
             batch_size=1,
             length=1023,
             temperature=0.7,
             top_k=0,
             run_name='run1'):
    """Generates text from a model loaded into memory.

    Adapted from https://github.com/openai/gpt-2/blob/master/src/interactive_conditional_samples.py
    """
    if batch_size is None:
        batch_size = 1
    assert nsamples % batch_size == 0

    if nsamples == 1:
        sample_delim = ''

    if prefix:
        context = tf.placeholder(tf.int32, [batch_size, None])

    CHECKPOINT_DIR = 'checkpoint'
    SAMPLE_DIR = 'samples'

    checkpoint_path = os.path.join(CHECKPOINT_DIR, run_name)

    enc = encoder.get_encoder(checkpoint_path)
    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    np.random.seed(seed)
    tf.set_random_seed(seed)

    output = sample.sample_sequence(
        hparams=hparams,
        length=length,
        start_token=enc.encoder['<|endoftext|>'] if not prefix else None,
        context=context if prefix else None,
        batch_size=batch_size,
        temperature=temperature,
        top_k=top_k)[:, 1:]

    if destination_path:
        f = open(destination_path, 'w')
    if prefix:
        context_tokens = enc.encode(prefix)
    generated = 0
    gen_texts = []
    while generated < nsamples:
        if not prefix:
            out = sess.run(output)
        else:
            out = sess.run(output,
                           feed_dict={context: batch_size * [context_tokens]})
        for i in range(batch_size):
            generated += 1
            gen_text = enc.decode(out[i])
            if prefix:
                # Re-attach the leading character dropped with the first
                # token by the [:, 1:] slice above.
                gen_text = prefix[0] + gen_text
            if truncate:
                trunc_text = re.search(r'(.*?)(?:{})'.format(truncate),
                                       gen_text, re.S)
                if trunc_text:
                    gen_text = trunc_text.group(1)
            if destination_path:
                f.write("{}\n{}".format(gen_text, sample_delim))
            if not return_as_list and not destination_path:
                print("{}\n{}".format(gen_text, sample_delim))
            gen_texts.append(gen_text)

    if destination_path:
        f.close()

    if return_as_list:
        return gen_texts
def generate(sess,
             run_name='run1',
             return_as_list=False,
             truncate=None,
             destination_path=None,
             sample_delim='=' * 20 + '\n',
             prefix=None,
             seed=None,
             nsamples=1,
             batch_size=1,
             length=1023,
             temperature=0.7,
             top_k=0,
             top_p=0.0,
             include_prefix=True,
             split_context=0.5):
    """Generates text from a model loaded into memory.

    Adapted from https://github.com/openai/gpt-2/blob/master/src/interactive_conditional_samples.py
    """
    if batch_size is None:
        batch_size = 1
    assert nsamples % batch_size == 0

    if nsamples == 1:
        sample_delim = ''

    if prefix == '':
        prefix = None

    if not length:
        assert truncate is not None, "If generating a non-fixed length \
sample, must have a truncation term."

    CHECKPOINT_DIR = 'checkpoint'
    SAMPLE_DIR = 'samples'

    checkpoint_path = os.path.join(CHECKPOINT_DIR, run_name)

    enc = encoder.get_encoder(checkpoint_path)
    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    context = tf.placeholder(tf.int32, [batch_size, None])

    if prefix:
        context_tokens = [enc.encode(prefix)] * batch_size
    else:
        context_tokens = [[enc.encoder['<|endoftext|>']]
                          for _ in range(batch_size)]

    np.random.seed(seed)
    tf.set_random_seed(seed)

    if destination_path:
        f = open(destination_path, 'w')

    generated = 0
    gen_texts = []
    while generated < nsamples:
        gen_text = [''] * batch_size
        truncated = [False] * batch_size
        total_tokens = 0

        while False in truncated:
            num_tokens = 1023 - (len(context_tokens[0]))
            output = sample.sample_sequence(
                hparams=hparams,
                length=min(length if length else 1023, num_tokens),
                context=context,
                batch_size=batch_size,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p)[:, 1:]

            out = sess.run(output, feed_dict={context: context_tokens})

            total_tokens += num_tokens

            for i in range(batch_size):
                text = enc.decode(out[i])
                # Initialize so the flag check below never sees an
                # unassigned name when truncate is None.
                trunc_text = ''
                if prefix:
                    text = enc.decode(context_tokens[i][:1]) + text
                if truncate or all(gen_text):
                    context_tokens[i] = out[i][int(len(out[i]) *
                                                   (1 - split_context)):]
                    if gen_text[i] != '':
                        split = re.split('[.!?]', gen_text[i])
                        text = text.partition(list(filter(None, split))[-1])[-1]
                if truncate:
                    truncate_esc = re.escape(truncate)
                    if prefix and not include_prefix:
                        prefix_esc = re.escape(prefix)
                        pattern = '(?:{})(.*?)(?:{})'.format(prefix_esc,
                                                             truncate_esc)
                    else:
                        pattern = '(.*?)(?:{})'.format(truncate_esc)

                    trunc_text = re.search(pattern, text, re.S)
                    if trunc_text:
                        text = trunc_text.group(1)
                if not truncated[i]:
                    gen_text[i] += text.lstrip('\n')
                    if trunc_text or (length is not None and
                                      total_tokens >= length - 1):
                        truncated[i] = True

        for gen in gen_text:
            if destination_path:
                f.write("{}\n{}".format(gen, sample_delim))
            if not return_as_list and not destination_path:
                print("{}\n{}".format(gen, sample_delim), end='')
            gen_texts.append(gen)
        generated += batch_size

    if destination_path:
        f.close()

    if return_as_list:
        return gen_texts
def get_hparams(checkpoint_path):
    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, "hparams.json")) as f:
        hparams.override_from_dict(json.load(f))
    return hparams
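# Minimal usage sketch (paths are illustrative): read the hyperparameters of
# a downloaded model to find its context window before picking sample_length.
#
#     hp = get_hparams(os.path.join("models", "124M"))
#     print(hp.n_ctx)  # 1024 for the released GPT-2 checkpoints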
def finetune(
    sess,
    dataset,
    steps=-1,
    model_name="124M",
    model_dir="models",
    combine=50000,
    batch_size=1,
    learning_rate=0.0001,
    accumulate_gradients=5,
    restore_from="latest",
    run_name="run1",
    checkpoint_dir="checkpoint",
    sample_every=100,
    sample_length=1023,
    sample_num=1,
    multi_gpu=False,
    save_every=1000,
    print_every=1,
    max_checkpoints=1,
    use_memory_saving_gradients=False,
    only_train_transformer_layers=False,
    optimizer="adam",
    overwrite=False,
):
    """Finetunes the model on the given dataset.

    Adapted from https://github.com/nshepperd/gpt-2/blob/finetuning/train.py.
    See that file for parameter definitions.
    """

    # assert model_name not in ['774M', '1558M'] or multi_gpu, "Currently, a modern single GPU cannot finetune the 774M GPT-2 model or larger."

    SAMPLE_DIR = "samples"

    checkpoint_path = os.path.join(checkpoint_dir, run_name)

    def maketree(path):
        try:
            os.makedirs(path)
        except FileExistsError:
            pass

    maketree(checkpoint_path)
    files = [f for f in os.listdir(checkpoint_path)]
    for file in ["hparams.json", "encoder.json", "vocab.bpe"]:
        try:
            shutil.copyfile(
                os.path.join(model_dir, model_name, file),
                os.path.join(checkpoint_path, file),
            )
        except FileNotFoundError as fnf_error:
            print("You need to download the GPT-2 model first via download_gpt2()")
            raise (fnf_error)

    enc = encoder.get_encoder(checkpoint_path)
    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, "hparams.json")) as f:
        hparams.override_from_dict(json.load(f))

    if sample_length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" %
                         hparams.n_ctx)

    if model_name not in ["117M", "124M"]:
        use_memory_saving_gradients = True
        only_train_transformer_layers = True
        accumulate_gradients = 1

    context = tf.compat.v1.placeholder(tf.int32, [batch_size, None])
    gpus = []

    if multi_gpu:
        gpus = get_available_gpus()

    output = model.model(hparams=hparams, X=context, gpus=gpus)
    loss = tf.reduce_mean(
        input_tensor=tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=context[:, 1:], logits=output["logits"][:, :-1]))

    tf_sample = sample.sample_sequence(
        hparams=hparams,
        length=sample_length,
        context=context,
        batch_size=batch_size,
        temperature=1.0,
        top_k=40,
    )

    all_vars = [v for v in tf.compat.v1.trainable_variables() if "model" in v.name]
    train_vars = ([v for v in all_vars if "/h" in v.name]
                  if only_train_transformer_layers else all_vars)

    if optimizer == "adam":
        opt = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
    elif optimizer == "sgd":
        opt = tf.compat.v1.train.GradientDescentOptimizer(
            learning_rate=learning_rate)

    if accumulate_gradients > 1:
        if use_memory_saving_gradients:
            exit("Memory saving gradients are not implemented for gradient accumulation yet.")
        opt = AccumulatingOptimizer(opt=opt, var_list=train_vars)
        opt_reset = opt.reset()
        opt_compute = opt.compute_gradients(loss)
        opt_apply = opt.apply_gradients()
        summary_loss = tf.compat.v1.summary.scalar("loss", opt_apply)
    else:
        if use_memory_saving_gradients:
            opt_grads = memory_saving_gradients.gradients(loss, train_vars)
        else:
            opt_grads = tf.gradients(ys=loss, xs=train_vars)
        opt_grads = list(zip(opt_grads, train_vars))
        opt_apply = opt.apply_gradients(opt_grads)
        summary_loss = tf.compat.v1.summary.scalar("loss", loss)

    summary_log = tf.compat.v1.summary.FileWriter(checkpoint_path)

    saver = tf.compat.v1.train.Saver(var_list=all_vars, max_to_keep=max_checkpoints)
    sess.run(tf.compat.v1.global_variables_initializer())

    if restore_from == "latest":
        ckpt = tf.train.latest_checkpoint(checkpoint_path)
        if ckpt is None:
            # Get fresh GPT weights if new run.
            ckpt = tf.train.latest_checkpoint(os.path.join(model_dir, model_name))
    elif restore_from == "fresh":
        ckpt = tf.train.latest_checkpoint(os.path.join(model_dir, model_name))
    else:
        ckpt = tf.train.latest_checkpoint(restore_from)
    print("Loading checkpoint", ckpt)
    saver.restore(sess, ckpt)

    print("Loading dataset...")
    chunks = load_dataset(enc, dataset, combine)
    data_sampler = Sampler(chunks)
    print("dataset has", data_sampler.total_size, "tokens")
    print("Training...")

    counter = 1
    counter_path = os.path.join(checkpoint_path, "counter")
    if os.path.exists(counter_path) and restore_from == "latest":
        # Load the step number if we're resuming a run
        # Add 1 so we don't immediately try to save again
        with open(counter_path, "r") as fp:
            counter = int(fp.read()) + 1
    counter_base = counter

    def save():
        maketree(checkpoint_path)
        print("Saving",
              os.path.join(checkpoint_path, "model-{}").format(counter - 1))
        saver.save(sess,
                   os.path.join(checkpoint_path, "model"),
                   global_step=counter - 1)
        with open(counter_path, "w") as fp:
            fp.write(str(counter - 1) + "\n")

    def generate_samples():
        context_tokens = data_sampler.sample(1)
        all_text = []
        index = 0
        while index < sample_num:
            out = sess.run(tf_sample,
                           feed_dict={context: batch_size * [context_tokens]})
            for i in range(min(sample_num - index, batch_size)):
                text = enc.decode(out[i])
                text = "======== SAMPLE {} ========\n{}\n".format(index + 1, text)
                all_text.append(text)
                index += 1
        print(text)
        maketree(os.path.join(SAMPLE_DIR, run_name))
        with open(
                os.path.join(SAMPLE_DIR, run_name,
                             "samples-{}").format(counter), "w") as fp:
            fp.write("\n".join(all_text))

    def sample_batch():
        return [data_sampler.sample(1024) for _ in range(batch_size)]

    if overwrite and restore_from == "latest":
        for file in files:
            if file.startswith("model") or file.startswith("events"):
                os.remove(os.path.join(checkpoint_path, file))
        save()

    avg_loss = (0.0, 0.0)
    start_time = time.time()

    if steps:
        steps = int(steps)

    try:
        while True:
            if steps > 0 and counter == (counter_base + steps):
                save()
                return
            if (counter - 1) % save_every == 0 and counter > 1:
                save()
            if (counter - 1) % sample_every == 0 and counter > 1:
                generate_samples()

            if accumulate_gradients > 1:
                sess.run(opt_reset)
                for _ in range(accumulate_gradients):
                    sess.run(opt_compute, feed_dict={context: sample_batch()})
                (v_loss, v_summary) = sess.run((opt_apply, summary_loss))
            else:
                (_, v_loss, v_summary) = sess.run(
                    (opt_apply, loss, summary_loss),
                    feed_dict={context: sample_batch()})

            summary_log.add_summary(v_summary, counter)

            if counter % print_every == 0:
                avg_loss = (avg_loss[0] * 0.99 + v_loss,
                            avg_loss[1] * 0.99 + 1.0)
                print(
                    "[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}"
                    .format(
                        counter=counter,
                        time=time.time() - start_time,
                        loss=v_loss,
                        avg=avg_loss[0] / avg_loss[1],
                    ))

            counter += 1
    except KeyboardInterrupt:
        print("interrupted")
        save()
def finetune(sess,
             dataset,
             steps=-1,
             model_name='124M',
             model_dir='models',
             combine=50000,
             batch_size=1,
             learning_rate=0.0001,
             accumulate_gradients=5,
             restore_from='latest',
             run_name='run1',
             checkpoint_dir='checkpoint',
             sample_every=100,
             sample_length=1023,
             sample_num=1,
             save_every=1000,
             print_every=1,
             max_checkpoints=1,
             use_memory_saving_gradients=False,
             only_train_transformer_layers=False,
             optimizer='adam',
             overwrite=False,
             mixed_precision=False):
    """Finetunes the model on the given dataset.

    Adapted from https://github.com/nshepperd/gpt-2/blob/finetuning/train.py.
    See that file for parameter definitions.
    """

    SAMPLE_DIR = 'samples'

    checkpoint_path = os.path.join(checkpoint_dir, run_name)

    def maketree(path):
        try:
            os.makedirs(path)
        except FileExistsError:
            pass

    maketree(checkpoint_path)
    files = [f for f in os.listdir(checkpoint_path)]
    for file in ['hparams.json', 'encoder.json', 'vocab.bpe']:
        try:
            shutil.copyfile(os.path.join(model_dir, model_name, file),
                            os.path.join(checkpoint_path, file))
        except FileNotFoundError as fnf_error:
            print(
                "You need to download the GPT-2 model first via download_gpt2()"
            )
            raise (fnf_error)

    enc = encoder.get_encoder(checkpoint_path)
    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if sample_length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" %
                         hparams.n_ctx)

    if model_name not in ['117M', '124M']:
        use_memory_saving_gradients = True
        only_train_transformer_layers = True
        accumulate_gradients = 1

    context = tf.compat.v1.placeholder(tf.int32, [batch_size, None])
    output = model.model(hparams=hparams, X=context)
    loss = tf.reduce_mean(
        input_tensor=tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=context[:, 1:], logits=output['logits'][:, :-1]))

    tf_sample = sample.sample_sequence(hparams=hparams,
                                       length=sample_length,
                                       context=context,
                                       batch_size=batch_size,
                                       temperature=1.0,
                                       top_k=40)

    all_vars = [
        v for v in tf.compat.v1.trainable_variables() if 'model' in v.name
    ]
    train_vars = [v for v in all_vars if '/h' in v.name
                  ] if only_train_transformer_layers else all_vars

    if optimizer == 'adam':
        opt = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
    elif optimizer == 'sgd':
        opt = tf.compat.v1.train.GradientDescentOptimizer(
            learning_rate=learning_rate)
    elif optimizer == 'sm3':
        opt = SM3Optimizer(learning_rate=learning_rate, momentum=0.9)
    elif optimizer == 'adafactor':
        opt = AdafactorOptimizer(learning_rate=learning_rate)

    def mp_check_tf_version():
        # check TensorFlow >= 1.14
        # (note: returns None, i.e. falsy, under TensorFlow 2.x)
        tf_version_list = tf.__version__.split(".")
        if int(tf_version_list[0]) < 2:
            return int(tf_version_list[1]) >= 14

    def mp_check_tensor_core_gpu_present():
        from tensorflow.python.client import device_lib

        # check Compute Capability >= 7.0
        local_device_protos = device_lib.list_local_devices()
        for line in local_device_protos:
            if "compute capability" in str(line):
                compute_capability = float(
                    line.physical_device_desc.split("compute capability: ")
                    [-1])
                if compute_capability >= 7.0:
                    return True

    if mixed_precision and mp_check_tf_version(
    ) and mp_check_tensor_core_gpu_present():
        if isinstance(opt, tf.keras.optimizers.Optimizer) or isinstance(
                opt, tf.compat.v1.train.Optimizer):
            opt = tf.compat.v1.train.experimental.enable_mixed_precision_graph_rewrite(
                opt)
        else:
            mixed_precision = False
    else:
        mixed_precision = False
    print('Mixed precision enabled: %s' % mixed_precision)

    # Fix warning:
    # https://github.com/tensorflow/tensorflow/blob/233d3d/tensorflow/python/training/experimental/mixed_precision.py#L355-L357
    if sess is None:
        sess = start_tf_sess()

    if accumulate_gradients > 1:
        if use_memory_saving_gradients:
            exit(
                "Memory saving gradients are not implemented for gradient accumulation yet."
            )
        opt = AccumulatingOptimizer(opt=opt, var_list=train_vars)
        opt_reset = opt.reset()
        opt_compute = opt.compute_gradients(loss)
        opt_apply = opt.apply_gradients()
        summary_loss = tf.compat.v1.summary.scalar('loss', opt_apply)
    else:
        if use_memory_saving_gradients:
            opt_grads = memory_saving_gradients.gradients(loss, train_vars)
        else:
            opt_grads = tf.gradients(ys=loss, xs=train_vars)
        opt_grads = list(zip(opt_grads, train_vars))
        opt_apply = opt.apply_gradients(opt_grads)
        summary_loss = tf.compat.v1.summary.scalar('loss', loss)

    summary_log = tf.compat.v1.summary.FileWriter(checkpoint_path)

    saver = tf.compat.v1.train.Saver(var_list=all_vars,
                                     max_to_keep=max_checkpoints)
    sess.run(tf.compat.v1.global_variables_initializer())

    if restore_from == 'latest':
        ckpt = tf.train.latest_checkpoint(checkpoint_path)
        if ckpt is None:
            # Get fresh GPT weights if new run.
            ckpt = tf.train.latest_checkpoint(
                os.path.join(model_dir, model_name))
    elif restore_from == 'fresh':
        ckpt = tf.train.latest_checkpoint(os.path.join(model_dir, model_name))
    else:
        ckpt = tf.train.latest_checkpoint(restore_from)
    print('Loading checkpoint', ckpt)
    saver.restore(sess, ckpt)

    print('Loading dataset...')
    chunks = load_dataset(enc, dataset, combine)
    data_sampler = Sampler(chunks)
    print('dataset has', data_sampler.total_size, 'tokens')
    print('Training...')

    counter = 1
    counter_path = os.path.join(checkpoint_path, 'counter')
    if os.path.exists(counter_path) and restore_from == 'latest':
        # Load the step number if we're resuming a run
        # Add 1 so we don't immediately try to save again
        with open(counter_path, 'r') as fp:
            counter = int(fp.read()) + 1
    counter_base = counter

    def save():
        maketree(checkpoint_path)
        print('Saving',
              os.path.join(checkpoint_path, 'model-{}').format(counter - 1))
        saver.save(sess,
                   os.path.join(checkpoint_path, 'model'),
                   global_step=counter - 1)
        with open(counter_path, 'w') as fp:
            fp.write(str(counter - 1) + '\n')

    def generate_samples():
        context_tokens = data_sampler.sample(1)
        all_text = []
        index = 0
        while index < sample_num:
            out = sess.run(tf_sample,
                           feed_dict={context: batch_size * [context_tokens]})
            for i in range(min(sample_num - index, batch_size)):
                text = enc.decode(out[i])
                text = '======== SAMPLE {} ========\n{}\n'.format(
                    index + 1, text)
                all_text.append(text)
                index += 1
        print(text)
        maketree(os.path.join(SAMPLE_DIR, run_name))
        with open(
                os.path.join(SAMPLE_DIR, run_name,
                             'samples-{}').format(counter), 'w') as fp:
            fp.write('\n'.join(all_text))

    def sample_batch():
        return [data_sampler.sample(1024) for _ in range(batch_size)]

    if overwrite and restore_from == 'latest':
        for file in files:
            if file.startswith('model') or file.startswith('events'):
                os.remove(os.path.join(checkpoint_path, file))
        save()

    avg_loss = (0.0, 0.0)
    start_time = time.time()

    if steps:
        steps = int(steps)

    try:
        while True:
            if steps > 0 and counter == (counter_base + steps):
                save()
                return
            if (counter - 1) % save_every == 0 and counter > 1:
                save()
            if (counter - 1) % sample_every == 0 and counter > 1:
                generate_samples()

            if accumulate_gradients > 1:
                sess.run(opt_reset)
                for _ in range(accumulate_gradients):
                    sess.run(opt_compute, feed_dict={context: sample_batch()})
                (v_loss, v_summary) = sess.run((opt_apply, summary_loss))
            else:
                (_, v_loss, v_summary) = sess.run(
                    (opt_apply, loss, summary_loss),
                    feed_dict={context: sample_batch()})

            summary_log.add_summary(v_summary, counter)

            if counter % print_every == 0:
                avg_loss = (avg_loss[0] * 0.99 + v_loss,
                            avg_loss[1] * 0.99 + 1.0)
                print(
                    '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                    .format(counter=counter,
                            time=time.time() - start_time,
                            loss=v_loss,
                            avg=avg_loss[0] / avg_loss[1]))

            counter += 1
    except KeyboardInterrupt:
        print('interrupted')
        save()
def generate(sess,
             run_name='run1',
             checkpoint_dir='checkpoint',
             model_name=None,
             model_dir='models',
             sample_dir='samples',
             return_as_list=False,
             truncate=None,
             destination_path=None,
             sample_delim='=' * 20 + '\n',
             prefix=None,
             seed=None,
             nsamples=1,
             batch_size=1,
             length=1023,
             temperature=0.7,
             top_k=0,
             top_p=0.0,
             include_prefix=True):
    """Generates text from a model loaded into memory.

    Adapted from https://github.com/openai/gpt-2/blob/master/src/interactive_conditional_samples.py
    """
    if batch_size is None:
        batch_size = 1
    assert nsamples % batch_size == 0

    if nsamples == 1:
        sample_delim = ''

    if prefix == '':
        prefix = None

    if model_name:
        checkpoint_path = os.path.join(model_dir, model_name)
    else:
        checkpoint_path = os.path.join(checkpoint_dir, run_name)

    enc = encoder.get_encoder(checkpoint_path)
    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if prefix:
        context = tf.compat.v1.placeholder(tf.int32, [batch_size, None])
        context_tokens = enc.encode(prefix)

    np.random.seed(seed)
    tf.compat.v1.set_random_seed(seed)

    output = sample.sample_sequence(
        hparams=hparams,
        length=min(length, 1023 - (len(context_tokens) if prefix else 0)),
        start_token=enc.encoder['<|endoftext|>'] if not prefix else None,
        context=context if prefix else None,
        batch_size=batch_size,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p)[:, 1:]

    if destination_path:
        f = open(destination_path, 'w')
    generated = 0
    gen_texts = []
    while generated < nsamples:
        if not prefix:
            out = sess.run(output)
        else:
            out = sess.run(output,
                           feed_dict={context: batch_size * [context_tokens]})
        for i in range(batch_size):
            generated += 1
            gen_text = enc.decode(out[i])
            if prefix:
                gen_text = enc.decode(context_tokens[:1]) + gen_text
            if truncate:
                truncate_esc = re.escape(truncate)
                if prefix and not include_prefix:
                    prefix_esc = re.escape(prefix)
                    pattern = '(?:{})(.*?)(?:{})'.format(
                        prefix_esc, truncate_esc)
                else:
                    pattern = '(.*?)(?:{})'.format(truncate_esc)

                trunc_text = re.search(pattern, gen_text, re.S)
                if trunc_text:
                    gen_text = trunc_text.group(1)
            gen_text = gen_text.lstrip('\n')
            if destination_path:
                f.write("{}\n{}".format(gen_text, sample_delim))
            if not return_as_list and not destination_path:
                print("{}\n{}".format(gen_text, sample_delim), end='')
            gen_texts.append(gen_text)

    if destination_path:
        f.close()

    if return_as_list:
        return gen_texts
def finetune(sess,
             dataset,
             steps=-1,
             model_name='124M',
             model_dir='models',
             combine=50000,
             batch_size=1,
             learning_rate=0.0001,
             accumulate_gradients=5,
             restore_from='latest',
             run_name='run1',
             checkpoint_dir='checkpoint',
             sample_every=100,
             sample_length=1023,
             sample_num=1,
             multi_gpu=False,
             save_every=1000,
             print_every=1,
             max_checkpoints=1,
             use_memory_saving_gradients=False,
             only_train_transformer_layers=False,
             optimizer='adam',
             overwrite=False,
             val_dataset=None,
             val_batch_size=2,
             val_batch_count=40,
             val_every=0):
    """Finetunes the model on the given dataset.

    Adapted from https://github.com/nshepperd/gpt-2/blob/finetuning/train.py.
    See that file for parameter definitions.
    """

    SAMPLE_DIR = 'samples'

    checkpoint_path = os.path.join(checkpoint_dir, run_name)

    def maketree(path):
        try:
            os.makedirs(path)
        except FileExistsError:
            pass

    maketree(checkpoint_path)
    files = [f for f in os.listdir(checkpoint_path)]
    for file in ['hparams.json', 'encoder.json', 'vocab.bpe']:
        try:
            shutil.copyfile(os.path.join(model_dir, model_name, file),
                            os.path.join(checkpoint_path, file))
        except FileNotFoundError as fnf_error:
            print("You need to download the GPT-2 model first via download_gpt2()")
            raise (fnf_error)

    enc = encoder.get_encoder(checkpoint_path)
    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if sample_length > hparams.n_ctx:
        raise ValueError(
            "Can't get samples longer than window size: %s" % hparams.n_ctx)

    context = tf.compat.v1.placeholder(tf.int32, [batch_size, None])
    gpus = []

    if multi_gpu:
        gpus = get_available_gpus()

    output = model.model(hparams=hparams, X=context, gpus=gpus)
    loss = tf.reduce_mean(
        input_tensor=tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=context[:, 1:], logits=output['logits'][:, :-1]))

    # validation code
    if val_every > 0:
        val_context = tf.placeholder(tf.int32, [val_batch_size, None])
        val_output = model.model(hparams=hparams, X=val_context,
                                 reuse=True)  # added reuse=True
        val_loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=val_context[:, 1:],
                logits=val_output['logits'][:, :-1]))
        val_loss_summary = tf.summary.scalar('val_loss', val_loss)

    tf_sample = sample.sample_sequence(
        hparams=hparams,
        length=sample_length,
        context=context,
        batch_size=batch_size,
        temperature=1.0,
        top_k=40)

    all_vars = [v for v in tf.compat.v1.trainable_variables()
                if 'model' in v.name]
    train_vars = [v for v in all_vars if '/h' in v.name] if only_train_transformer_layers else all_vars

    if optimizer == 'adam':
        opt = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
    elif optimizer == 'sgd':
        opt = tf.compat.v1.train.GradientDescentOptimizer(
            learning_rate=learning_rate)

    if accumulate_gradients > 1:
        if use_memory_saving_gradients:
            exit("Memory saving gradients are not implemented for gradient accumulation yet.")
        opt = AccumulatingOptimizer(
            opt=opt,
            var_list=train_vars)
        opt_reset = opt.reset()
        opt_compute = opt.compute_gradients(loss)
        opt_apply = opt.apply_gradients()
        summary_loss = tf.compat.v1.summary.scalar('loss', opt_apply)
    else:
        if use_memory_saving_gradients:
            opt_grads = memory_saving_gradients.gradients(loss, train_vars)
        else:
            opt_grads = tf.gradients(ys=loss, xs=train_vars)
        opt_grads = list(zip(opt_grads, train_vars))
        opt_apply = opt.apply_gradients(opt_grads)
        summary_loss = tf.compat.v1.summary.scalar('loss', loss)

    summary_log = tf.compat.v1.summary.FileWriter(checkpoint_path)

    saver = tf.compat.v1.train.Saver(
        var_list=all_vars,
        max_to_keep=max_checkpoints)
    sess.run(tf.compat.v1.global_variables_initializer())

    if restore_from == 'latest':
        ckpt = tf.train.latest_checkpoint(checkpoint_path)
        if ckpt is None:
            # Get fresh GPT weights if new run.
            ckpt = tf.train.latest_checkpoint(
                os.path.join(model_dir, model_name))
    elif restore_from == 'fresh':
        ckpt = tf.train.latest_checkpoint(
            os.path.join(model_dir, model_name))
    else:
        ckpt = tf.train.latest_checkpoint(restore_from)
    print('Loading checkpoint', ckpt)
    saver.restore(sess, ckpt)

    print('Loading dataset...')
    chunks = load_dataset(enc, dataset, combine)
    data_sampler = Sampler(chunks)

    # validation code
    if val_every > 0:
        if val_dataset:
            val_chunks = load_dataset(enc, val_dataset, combine)
        else:
            val_chunks = chunks

    print('dataset has', data_sampler.total_size, 'tokens')
    print('Training...')

    # validation code
    if val_every > 0:
        # Sample from validation set once with fixed seed to make
        # it deterministic during training as well as across runs.
        val_data_sampler = Sampler(val_chunks, seed=1)
        val_batches = [[val_data_sampler.sample(1024)
                        for _ in range(val_batch_size)]
                       for _ in range(val_batch_count)]

    counter = 1
    counter_path = os.path.join(checkpoint_path, 'counter')
    if os.path.exists(counter_path) and restore_from == 'latest':
        # Load the step number if we're resuming a run
        # Add 1 so we don't immediately try to save again
        with open(counter_path, 'r') as fp:
            counter = int(fp.read()) + 1
    counter_base = counter

    def save():
        maketree(checkpoint_path)
        print(
            'Saving',
            os.path.join(checkpoint_path, 'model-{}').format(counter - 1))
        saver.save(
            sess,
            os.path.join(checkpoint_path, 'model'),
            global_step=counter - 1)
        with open(counter_path, 'w') as fp:
            fp.write(str(counter - 1) + '\n')

    def generate_samples():
        context_tokens = data_sampler.sample(1)
        all_text = []
        index = 0
        while index < sample_num:
            out = sess.run(
                tf_sample,
                feed_dict={context: batch_size * [context_tokens]})
            for i in range(min(sample_num - index, batch_size)):
                text = enc.decode(out[i])
                text = '======== SAMPLE {} ========\n{}\n'.format(
                    index + 1, text)
                all_text.append(text)
                index += 1
        print(text)
        maketree(os.path.join(SAMPLE_DIR, run_name))
        with open(
                os.path.join(SAMPLE_DIR, run_name,
                             'samples-{}').format(counter), 'w') as fp:
            fp.write('\n'.join(all_text))

    # validation code
    def validation():
        print('Calculating validation loss...')
        losses = []
        for batch in tqdm(val_batches):
            losses.append(sess.run(val_loss, feed_dict={val_context: batch}))
        v_val_loss = np.mean(losses)
        v_summary = sess.run(val_loss_summary,
                             feed_dict={val_loss: v_val_loss})
        summary_log.add_summary(v_summary, counter)
        summary_log.flush()
        print(
            '[{counter} | {time:2.2f}] validation loss = {loss:2.2f}'
            .format(
                counter=counter,
                time=time.time() - start_time,
                loss=v_val_loss))

    def sample_batch():
        return [data_sampler.sample(1024) for _ in range(batch_size)]

    if overwrite and restore_from == 'latest':
        for file in files:
            if file.startswith('model') or file.startswith('events'):
                os.remove(os.path.join(checkpoint_path, file))
        save()

    avg_loss = (0.0, 0.0)
    start_time = time.time()

    if steps:
        steps = int(steps)

    try:
        while True:
            if steps > 0 and counter == (counter_base + steps):
                save()
                return
            if (counter - 1) % save_every == 0 and counter > 1:
                save()
            if (counter - 1) % sample_every == 0 and counter > 1:
                generate_samples()
            # validation code
            if val_every > 0 and (counter % val_every == 0 or counter == 1):
                validation()

            if accumulate_gradients > 1:
                sess.run(opt_reset)
                for _ in range(accumulate_gradients):
                    sess.run(
                        opt_compute, feed_dict={context: sample_batch()})
                (v_loss, v_summary) = sess.run((opt_apply, summary_loss))
            else:
                (_, v_loss, v_summary) = sess.run(
                    (opt_apply, loss, summary_loss),
                    feed_dict={context: sample_batch()})

            summary_log.add_summary(v_summary, counter)

            if (counter % print_every == 0) or counter == 1:
                avg_loss = (avg_loss[0] * 0.99 + v_loss,
                            avg_loss[1] * 0.99 + 1.0)
                print(
                    '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                    .format(
                        counter=counter,
                        time=time.time() - start_time,
                        loss=v_loss,
                        avg=avg_loss[0] / avg_loss[1]))

            counter += 1
    except KeyboardInterrupt:
        print('interrupted')
        save()
def generate(sess,
             run_name='run1',
             checkpoint_dir='checkpoint',
             model_name=None,
             model_dir='models',
             sample_dir='samples',
             return_as_list=False,
             truncate=None,
             destination_path=None,
             sample_delim='=' * 20 + '\n',
             prefix=None,
             seed=None,
             nsamples=1,
             batch_size=1,
             length=1023,
             temperature=0.7,
             top_k=0,
             top_p=0.0,
             include_prefix=True,
             split_context=0.5):
    """Generates text from a model loaded into memory.

    Adapted from https://github.com/openai/gpt-2/blob/master/src/interactive_conditional_samples.py
    """
    if batch_size is None:
        batch_size = 1
    assert nsamples % batch_size == 0

    if nsamples == 1:
        sample_delim = ''

    if prefix == '':
        prefix = None

    if not length:
        assert truncate is not None, "If generating a non-fixed length \
sample, must have a truncation term."
    assert 0 < split_context < 1

    if model_name:
        checkpoint_path = os.path.join(model_dir, model_name)
    else:
        checkpoint_path = os.path.join(checkpoint_dir, run_name)

    enc = encoder.get_encoder(checkpoint_path)
    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    context = tf.compat.v1.placeholder(tf.int32, [batch_size, None])

    if prefix:
        prefix_enc = enc.encode(prefix)

    np.random.seed(seed)
    tf.compat.v1.set_random_seed(seed)

    output = sample.sample_sequence(
        hparams=hparams,
        length=min(length, 1023 - (len(prefix_enc) if prefix else 0)),
        start_token=enc.encoder['<|endoftext|>'] if not prefix else None,
        context=context if prefix else None,
        batch_size=batch_size,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p)[:, 1:]

    split_length = int(1023 * split_context)
    split_output_length = min(length, 1023 - split_length)
    split_output = sample.sample_sequence(
        hparams=hparams,
        length=split_output_length,
        start_token=enc.encoder['<|endoftext|>'] if not prefix else None,
        context=context if prefix else None,
        batch_size=batch_size,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p)[:, 1:]

    if destination_path:
        f = open(destination_path, 'w')

    generated = 0
    gen_texts = []
    while generated < nsamples:
        gen_text = [np.array([])] * batch_size
        truncated = [False] * batch_size
        if prefix:
            context_tokens = [prefix_enc] * batch_size
        else:
            context_tokens = [[enc.encoder['<|endoftext|>']]] * batch_size
        total_tokens = len(context_tokens[0])
        generated_once = False

        while False in truncated:
            num_tokens = 1023 - (len(context_tokens[0]))

            if generated_once:
                new_split_output_length = min(length - total_tokens,
                                              1023 - split_length)
                if new_split_output_length != split_output_length:
                    split_output = sample.sample_sequence(
                        hparams=hparams,
                        length=new_split_output_length,
                        start_token=enc.encoder['<|endoftext|>']
                        if not prefix else None,
                        context=context if prefix else None,
                        batch_size=batch_size,
                        temperature=temperature,
                        top_k=top_k,
                        top_p=top_p)[:, 1:]
                out = sess.run(split_output,
                               feed_dict={context: context_tokens})
            else:
                out = sess.run(output, feed_dict={context: context_tokens})

            total_tokens += num_tokens

            for i in range(batch_size):
                text = out[i]
                trunc_text = ""
                if prefix:
                    text = np.append(context_tokens[i][:1], text)
                if truncate or all(gen_text):
                    context_tokens[i] = out[i][(1023 - split_length - 1):]
                    if generated_once:
                        text = out[i][split_length:]

                if truncate:
                    to_trunc = enc.decode(text)
                    truncate_esc = re.escape(truncate)
                    if prefix and not include_prefix:
                        prefix_esc = re.escape(prefix)
                        pattern = '(?:{})(.*?)(?:{})'.format(
                            prefix_esc, truncate_esc)
                    else:
                        pattern = '(.*?)(?:{})'.format(truncate_esc)

                    trunc_text = re.search(pattern, to_trunc, re.S)
                    if trunc_text:
                        # Better to re-encode here than decode on every
                        # generation cycle.
                        text = enc.encode(trunc_text.group(1))

                if not truncated[i]:
                    gen_text[i] = np.concatenate((gen_text[i], text),
                                                 axis=None)
                    if trunc_text or (length is not None and
                                      total_tokens >= length - 1):
                        truncated[i] = True
                        gen = enc.decode(gen_text[i]).lstrip('\n')

                        if destination_path:
                            f.write("{}\n{}".format(gen, sample_delim))
                        if not return_as_list and not destination_path:
                            print("{}\n{}".format(gen, sample_delim), end='')
                        gen_texts.append(gen)

            generated_once = True
        generated += batch_size

    if destination_path:
        f.close()

    if return_as_list:
        return gen_texts
def generate(sess,
             run_name='run1',
             return_as_list=False,
             truncate=None,
             destination_path=None,
             sample_delim='=' * 20 + '\n',
             prefix=None,
             seed=None,
             nsamples=1,
             batch_size=1,
             length=1023,
             temperature=0.7,
             top_k=0,
             top_p=0.0,
             include_prefix=True,
             split_context=0.5,
             batch_prefix=None):
    """Generates text from a model loaded into memory.

    Adapted from https://github.com/openai/gpt-2/blob/master/src/interactive_conditional_samples.py
    """
    if batch_size is None:
        batch_size = 1
    assert nsamples % batch_size == 0

    if nsamples == 1:
        sample_delim = ''

    assert not (prefix and batch_prefix)
    assert not batch_prefix or len(batch_prefix) == batch_size

    if prefix == '':
        prefix = None
    if batch_prefix == []:
        batch_prefix = None

    if not length:
        assert truncate is not None, "If generating a non-fixed length \
sample, must have a truncation term."

    CHECKPOINT_DIR = 'checkpoint'
    SAMPLE_DIR = 'samples'

    checkpoint_path = os.path.join(CHECKPOINT_DIR, run_name)

    enc = encoder.get_encoder(checkpoint_path)
    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    context = tf.placeholder(tf.int32, [batch_size, None])

    if prefix:
        context_tokens = [enc.encode(prefix)] * batch_size
    elif batch_prefix:
        # context_tokens = [enc.encode(batch_prefix[0])] * batch_size
        context_tokens = [enc.encode(pre) for pre in batch_prefix]
        # print([enc.decode(c) for c in context_tokens])
        # assert all(len(context_tokens[0]) == len(c) for c in context_tokens)
        ml = max([len(p) for p in context_tokens])
        for p in context_tokens:
            while len(p) < ml:
                p.insert(0, 0)  # pad at the front
    else:
        context_tokens = [[enc.encoder['<|endoftext|>']]
                          for _ in range(batch_size)]

    np.random.seed(seed)
    tf.set_random_seed(seed)

    if destination_path:
        f = open(destination_path, 'w')

    generated = 0
    gen_texts = []
    while generated < nsamples:
        # gen_text = [''] * batch_size
        # Use a comprehension so each entry is a distinct list;
        # [[]] * batch_size would alias one shared list across the batch.
        gen_text = [[] for _ in range(batch_size)]
        truncated = [False] * batch_size
        total_tokens = 0

        while False in truncated:
            num_tokens = 1023 - (len(context_tokens[0]))
            output = sample.sample_sequence(
                hparams=hparams,
                length=min(length if length else 1023, num_tokens),
                context=context,
                batch_size=batch_size,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p)[:, 1:]

            out = sess.run(output, feed_dict={context: context_tokens})

            total_tokens += num_tokens

            for i in range(batch_size):
                if truncated[i]:
                    continue
                text = out[i]
                trunc_text = ""  # added to patch an otherwise unassigned variable
                if prefix or batch_prefix:
                    text = np.append(context_tokens[i][:1], text)
                if truncate or all(gen_text):
                    context_tokens[i] = out[i][int(len(out[i]) *
                                                   (1 - split_context)):]
                    if gen_text[i]:
                        text = out[i][int(len(out[i]) * (split_context)):]
                        # OLD, kept for cross-reference:
                        #   split = re.split('[.!?]', gen_text[i])
                        #   text = text.partition(list(filter(None, split))[-1])[-1]
                        # The idea was: split the cumulative gen_text, find
                        # its latest sentence in the new text, and keep
                        # everything after it. In practice that sometimes
                        # matched an empty string, so instead everything is
                        # kept as tokens until the end and token counts are
                        # used.
                if truncate:
                    to_trunc = enc.decode(text)
                    truncate_esc = re.escape(truncate)
                    if prefix and not include_prefix:
                        prefix_esc = re.escape(prefix)
                        pattern = '(?:{})(.*?)(?:{})'.format(prefix_esc,
                                                             truncate_esc)
                    else:
                        pattern = '(.*?)(?:{})'.format(truncate_esc)

                    trunc_text = re.search(pattern, to_trunc, re.S)
                    if trunc_text:
                        # Inefficient, but keeps the pipeline in token space.
                        text = enc.encode(trunc_text.group(1))
                if not truncated[i]:
                    gen_text[i] += [text]  # .lstrip('\n')
                    if trunc_text or (length is not None and
                                      total_tokens >= length - 1):
                        truncated[i] = True

        for gen in gen_text:
            gen = [enc.decode(g).lstrip('\n') for g in gen]
            gen = ''.join(gen)
            if destination_path:
                f.write("{}\n{}".format(gen, sample_delim))
            if not return_as_list and not destination_path:
                print("{}\n{}".format(gen, sample_delim), end='')
            gen_texts.append(gen)
        generated += batch_size

    if destination_path:
        f.close()

    if return_as_list:
        return gen_texts
def finetune(sess,
             dataset,
             steps=-1,
             model_name='117M',
             combine=50000,
             batch_size=1,
             learning_rate=0.0001,
             accumulate_gradients=5,
             restore_from='latest',
             run_name='run1',
             sample_every=100,
             sample_length=1023,
             sample_num=1,
             save_every=1000,
             print_every=1,
             max_checkpoints=1,
             use_memory_saving_gradients=False,
             only_train_transformer_layers=False,
             model_load=False):
    """Finetunes the model on the given dataset.

    Adapted from https://github.com/nshepperd/gpt-2/blob/finetuning/train.py.
    See that file for parameter definitions.
    """
    CHECKPOINT_DIR = 'checkpoint'
    SAMPLE_DIR = 'samples'

    checkpoint_path = os.path.join(CHECKPOINT_DIR, run_name)

    def maketree(path):
        try:
            os.makedirs(path)
        except FileExistsError:
            pass

    maketree(checkpoint_path)
    if not model_load:
        for file in ['hparams.json', 'encoder.json', 'vocab.bpe']:
            shutil.copyfile(os.path.join('models', model_name, file),
                            os.path.join(checkpoint_path, file))

    enc = encoder.get_encoder(checkpoint_path)
    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if sample_length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" %
                         hparams.n_ctx)

    if model_name != '117M':
        use_memory_saving_gradients = True
        only_train_transformer_layers = True
        accumulate_gradients = 1

    context = tf.placeholder(tf.int32, [batch_size, None])
    output = model.model(hparams=hparams, X=context)
    loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=context[:, 1:], logits=output['logits'][:, :-1]))

    tf_sample = sample.sample_sequence(hparams=hparams,
                                       length=sample_length,
                                       context=context,
                                       batch_size=batch_size,
                                       temperature=1.0,
                                       top_k=40)

    all_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
    train_vars = [v for v in all_vars if '/h' in v.name
                  ] if only_train_transformer_layers else all_vars

    if accumulate_gradients > 1:
        if use_memory_saving_gradients:
            exit(
                "Memory saving gradients are not implemented for gradient accumulation yet."
            )
        opt = AccumulatingOptimizer(
            opt=tf.train.AdamOptimizer(learning_rate=learning_rate),
            var_list=train_vars)
        opt_reset = opt.reset()
        opt_compute = opt.compute_gradients(loss)
        opt_apply = opt.apply_gradients()
        summary_loss = tf.summary.scalar('loss', opt_apply)
    else:
        opt = tf.train.AdamOptimizer(learning_rate=learning_rate)
        if use_memory_saving_gradients:
            opt_grads = memory_saving_gradients.gradients(loss, train_vars)
        else:
            opt_grads = tf.gradients(loss, train_vars)
        opt_grads = list(zip(opt_grads, train_vars))
        opt_apply = opt.apply_gradients(opt_grads)
        summary_loss = tf.summary.scalar('loss', loss)

    summary_log = tf.summary.FileWriter(checkpoint_path)

    saver = tf.train.Saver(var_list=all_vars, max_to_keep=max_checkpoints)
    sess.run(tf.global_variables_initializer())

    if restore_from == 'latest':
        ckpt = tf.train.latest_checkpoint(checkpoint_path)
        if ckpt is None:
            # Get fresh GPT weights if new run.
            ckpt = tf.train.latest_checkpoint(
                os.path.join('models', model_name))
    elif restore_from == 'fresh':
        ckpt = tf.train.latest_checkpoint(os.path.join('models', model_name))
    else:
        ckpt = tf.train.latest_checkpoint(restore_from)
    print('Loading checkpoint', ckpt)
    saver.restore(sess, ckpt)

    if model_load:
        return

    print('Loading dataset...')
    chunks = load_dataset(enc, dataset, combine)
    data_sampler = Sampler(chunks)
    print('dataset has', data_sampler.total_size, 'tokens')
    print('Training...')

    counter = 1
    counter_path = os.path.join(checkpoint_path, 'counter')
    if os.path.exists(counter_path) and restore_from == 'latest':
        # Load the step number if we're resuming a run
        # Add 1 so we don't immediately try to save again
        with open(counter_path, 'r') as fp:
            counter = int(fp.read()) + 1
    counter_base = counter

    def save():
        maketree(checkpoint_path)
        print('Saving',
              os.path.join(checkpoint_path, 'model-{}').format(counter - 1))
        saver.save(sess,
                   os.path.join(checkpoint_path, 'model'),
                   global_step=counter - 1)
        with open(counter_path, 'w') as fp:
            fp.write(str(counter - 1) + '\n')

    def generate_samples():
        context_tokens = data_sampler.sample(1)
        all_text = []
        index = 0
        while index < sample_num:
            out = sess.run(tf_sample,
                           feed_dict={context: batch_size * [context_tokens]})
            for i in range(min(sample_num - index, batch_size)):
                text = enc.decode(out[i])
                text = '======== SAMPLE {} ========\n{}\n'.format(
                    index + 1, text)
                all_text.append(text)
                index += 1
        print(text)
        maketree(os.path.join(SAMPLE_DIR, run_name))
        with open(
                os.path.join(SAMPLE_DIR, run_name,
                             'samples-{}').format(counter), 'w') as fp:
            fp.write('\n'.join(all_text))

    def sample_batch():
        return [data_sampler.sample(1024) for _ in range(batch_size)]

    avg_loss = (0.0, 0.0)
    start_time = time.time()

    try:
        while True:
            if steps > 0 and counter == (counter_base + steps):
                save()
                return
            if (counter - 1) % save_every == 0 and counter > 1:
                save()
            if (counter - 1) % sample_every == 0 and counter > 1:
                generate_samples()

            if accumulate_gradients > 1:
                sess.run(opt_reset)
                for _ in range(accumulate_gradients):
                    sess.run(opt_compute, feed_dict={context: sample_batch()})
                (v_loss, v_summary) = sess.run((opt_apply, summary_loss))
            else:
                (_, v_loss, v_summary) = sess.run(
                    (opt_apply, loss, summary_loss),
                    feed_dict={context: sample_batch()})

            summary_log.add_summary(v_summary, counter)

            if counter % print_every == 0:
                avg_loss = (avg_loss[0] * 0.99 + v_loss,
                            avg_loss[1] * 0.99 + 1.0)
                print(
                    '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                    .format(counter=counter,
                            time=time.time() - start_time,
                            loss=v_loss,
                            avg=avg_loss[0] / avg_loss[1]))

            counter += 1
    except KeyboardInterrupt:
        print('interrupted')
        save()
flags.DEFINE_integer(
    'num_tpu_cores', 8,
    'Only used if `use_tpu` is True. Total number of TPU cores to use.',
)

tf.flags.DEFINE_string('master', None, '[Optional] TensorFlow master URL.')

flags.DEFINE_bool('do_train', False, 'Whether to run training.')

flags.DEFINE_bool('do_eval', False, 'Whether to run eval on the dev set.')

flags.DEFINE_bool('use_tpu', True, 'Whether to use TPU or GPU/CPU.')

# https://storage.googleapis.com/gpt-2/models/117M/hparams.json
hparams = model.default_hparams()
with open('base-hparams.json') as f:
    hparams.override_from_dict(json.load(f))


def get_assignment_map_from_checkpoint(tvars, init_checkpoint):
    assignment_map = {}
    initialized_variable_names = {}

    name_to_variable = collections.OrderedDict()
    for var in tvars:
        name = var.name
        m = re.match('^(.*):\\d+$', name)
        if m is not None:
            name = m.group(1)
        name_to_variable[name] = var
def finetune(sess, dataset, steps=-1, model_name='124M', model_dir='models', combine=50000, batch_size=1, learning_rate=0.0001, accumulate_gradients=5, restore_from='latest', run_name='run1', checkpoint_dir='checkpoint', sample_every=100, sample_length=1023, sample_num=1, multi_gpu=False, save_every=1000, print_every=1, max_checkpoints=1, use_memory_saving_gradients=False, only_train_transformer_layers=False, optimizer='adam', overwrite=False): """Finetunes the model on the given dataset. Adapted from https://github.com/nshepperd/gpt-2/blob/finetuning/train.py. See that file for parameter definitions. """ assert model_name not in [ '774M', '1558M' ] or multi_gpu, "[ERROR] Currently, a modern single GPU cannot finetune the 774M GPT-2 model or larger." SAMPLE_DIR = 'samples' checkpoint_path = os.path.join(checkpoint_dir, run_name) def maketree(path): try: os.makedirs(path) except: pass maketree(checkpoint_path) files = [f for f in os.listdir(checkpoint_path)] for file in ['hparams.json', 'encoder.json', 'vocab.bpe']: try: shutil.copyfile(os.path.join(model_dir, model_name, file), os.path.join(checkpoint_path, file)) except FileNotFoundError as fnf_error: print( "[ERROR] You need to download the GPT-2 model first via download_gpt2()" ) raise (fnf_error) enc = encoder.get_encoder(checkpoint_path) hparams = model.default_hparams() with open(os.path.join(checkpoint_path, 'hparams.json')) as f: hparams.override_from_dict(json.load(f)) if sample_length > hparams.n_ctx: raise ValueError( "[ERROR] Can't get samples longer than window size: %s" % hparams.n_ctx) if model_name not in ['117M', '124M']: use_memory_saving_gradients = True only_train_transformer_layers = True accumulate_gradients = 1 context = tf.compat.v1.placeholder(tf.int32, [batch_size, None]) gpus = [] if multi_gpu: gpus = get_available_gpus() output = model.model(hparams=hparams, X=context, gpus=gpus) loss = tf.reduce_mean( input_tensor=tf.nn.sparse_softmax_cross_entropy_with_logits( labels=context[:, 1:], logits=output['logits'][:, :-1])) tf_sample = sample.sample_sequence(hparams=hparams, length=sample_length, context=context, batch_size=batch_size, temperature=1.0, top_k=40) all_vars = [ v for v in tf.compat.v1.trainable_variables() if 'model' in v.name ] train_vars = [v for v in all_vars if '/h' in v.name ] if only_train_transformer_layers else all_vars if optimizer == 'adam': opt = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate) elif optimizer == 'sgd': opt = tf.compat.v1.train.GradientDescentOptimizer( learning_rate=learning_rate) if accumulate_gradients > 1: if use_memory_saving_gradients: exit( "[INFO] Memory saving gradients are not implemented for gradient accumulation yet." 
) opt = AccumulatingOptimizer(opt=opt, var_list=train_vars) opt_reset = opt.reset() opt_compute = opt.compute_gradients(loss) opt_apply = opt.apply_gradients() summary_loss = tf.compat.v1.summary.scalar('loss', opt_apply) else: if use_memory_saving_gradients: opt_grads = memory_saving_gradients.gradients(loss, train_vars) else: opt_grads = tf.gradients(ys=loss, xs=train_vars) opt_grads = list(zip(opt_grads, train_vars)) opt_apply = opt.apply_gradients(opt_grads) summary_loss = tf.compat.v1.summary.scalar('loss', loss) summaries = [summary_loss] summary_op = tf.summary.merge(summaries) summary_writer = tf.compat.v1.summary.FileWriter(checkpoint_path, sess.graph) saver = tf.compat.v1.train.Saver(var_list=all_vars, max_to_keep=max_checkpoints) sess.run(tf.compat.v1.global_variables_initializer()) if restore_from == 'latest': ckpt = tf.train.latest_checkpoint(checkpoint_path) if ckpt is None: # Get fresh GPT weights if new run. ckpt = tf.train.latest_checkpoint( os.path.join(model_dir, model_name)) elif restore_from == 'fresh': ckpt = tf.train.latest_checkpoint(os.path.join(model_dir, model_name)) else: ckpt = tf.train.latest_checkpoint(restore_from) print('[INFO] Loading checkpoint', ckpt) saver.restore(sess, ckpt) print('[INFO] Loading dataset...') chunks = load_dataset(enc, dataset, combine) data_sampler = Sampler(chunks) print('[INFO] dataset has', data_sampler.total_size, 'tokens') print('[INFO] Training...') counter = 1 counter_path = os.path.join(checkpoint_path, 'counter') if os.path.exists(counter_path) and restore_from == 'latest': # Load the step number if we're resuming a run # Add 1 so we don't immediately try to save again with open(counter_path, 'r') as fp: counter = int(fp.read()) + 1 counter_base = counter def save(): maketree(checkpoint_path) print('[INFO] Saving', os.path.join(checkpoint_path, 'model-{}').format(counter - 1)) saver.save(sess, os.path.join(checkpoint_path, 'model'), global_step=counter - 1) with open(counter_path, 'w') as fp: fp.write(str(counter - 1) + '\n') def generate_samples(header=True): context_tokens = data_sampler.sample(1) all_text = [] index = 0 while index < sample_num: out = sess.run(tf_sample, feed_dict={context: batch_size * [context_tokens]}) for i in range(min(sample_num - index, batch_size)): text = enc.decode(out[i]) if header: text = '======== SAMPLE {} ========\n{}\n'.format( index + 1, text) all_text.append(text) index += 1 print(text) maketree(os.path.join(SAMPLE_DIR, run_name)) with open( os.path.join(SAMPLE_DIR, run_name, 'samples-{}').format(counter), 'w') as fp: fp.write('\n'.join(all_text)) return text def sample_batch(sample_size=1024): return [data_sampler.sample(sample_size) for _ in range(batch_size)] if overwrite and restore_from == 'latest': for file in files: if file.startswith('model') or file.startswith('events'): os.remove(os.path.join(checkpoint_path, file)) save() avg_loss = (0.0, 0.0) start_time = time.time() best_loss = 1000.00 if steps: steps = int(steps) try: while True: t = time.time() - start_time if accumulate_gradients > 1: sess.run(opt_reset) for _ in range(accumulate_gradients): sess.run(opt_compute, feed_dict={context: sample_batch()}) (v_loss, v_summary) = sess.run((opt_apply, summary_op)) else: (_, v_loss, v_summary) = sess.run( (opt_apply, loss, summary_op), feed_dict={context: sample_batch()}) monitoring.push_metric('loss', v_loss, counter, t) summary_writer.add_summary(v_summary, counter) if steps > 0 and counter == (counter_base + steps): save() return if (counter - 1) % save_every == 0 and counter > 1: 
                save()
                sample_text = generate_samples(False)
                summary_sample = tf.compat.v1.summary.text(
                    'generated_sample', tf.convert_to_tensor(sample_text))
                text = sess.run(summary_sample)
                summary_writer.add_summary(text, counter)
                monitoring.push_text('generated_text', sample_text, counter,
                                     v_loss, t)

            if v_loss < best_loss and v_loss > 1.00:
                best_loss = v_loss
                loss_text = generate_samples(False)
                loss_sample_text = '===== LOSS: {} ====='.format(
                    best_loss) + loss_text
                summary_loss_sample = tf.compat.v1.summary.text(
                    'best_loss_sample',
                    tf.convert_to_tensor(loss_sample_text))
                loss_text_ = sess.run(summary_loss_sample)
                summary_writer.add_summary(loss_text_, counter)
                monitoring.push_text('best_loss_text', loss_text, counter,
                                     best_loss, t)
                monitoring.push_metric('best_loss', best_loss, counter, t)

            if counter % print_every == 0:
                # Exponentially weighted moving average of the loss
                # (decay 0.99), stored as a (weighted_sum, weight) pair.
                avg_loss = (avg_loss[0] * 0.99 + v_loss,
                            avg_loss[1] * 0.99 + 1.0)
                monitoring.push_metric('avg_loss',
                                       avg_loss[0] / avg_loss[1], counter, t)
                print(
                    '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                    .format(counter=counter,
                            time=t,
                            loss=v_loss,
                            avg=avg_loss[0] / avg_loss[1]))

            counter += 1
    except KeyboardInterrupt:
        print('[STOP] interrupted')
        save()
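
# The running average printed by the loop above is an exponentially
# weighted moving average with decay 0.99, stored as a
# (weighted_sum, weight) pair so early steps are not biased toward zero.
# A minimal standalone sketch of that update (illustrative only; the
# helper name _ema_update is not part of the training code):
def _ema_update(avg, new_value, decay=0.99):
    """avg is a (weighted_sum, weight) pair; returns the updated pair."""
    return (avg[0] * decay + new_value, avg[1] * decay + 1.0)

# e.g. starting from (0.0, 0.0) and observing losses 3.0 then 2.0:
#   _ema_update((0.0, 0.0), 3.0) -> (3.0, 1.0),  mean 3.0
#   _ema_update((3.0, 1.0), 2.0) -> (4.97, 1.99), mean ~2.50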
def one_lr_cycle(sess,
                 dataset,
                 steps=10000,
                 model_name='117M',
                 combine=50000,
                 batch_size=1,
                 initial_lr=1e-10,
                 final_lr=1,
                 accumulate_gradients=5,
                 restore_from='fresh',
                 run_name='run1',
                 max_checkpoints=1,
                 use_memory_saving_gradients=False,
                 only_train_transformer_layers=False,
                 overwrite=False):
    """Does one LR half-cycle from initial_lr to final_lr over `steps`
    iterations, using the CLR algorithm (https://github.com/bckenstler/CLR).

    Adapted from https://github.com/nshepperd/gpt-2/blob/finetuning/train.py.
    See that file for parameter definitions.
    """
    CHECKPOINT_DIR = 'checkpoint'
    checkpoint_path = os.path.join(CHECKPOINT_DIR, run_name)

    def maketree(path):
        try:
            os.makedirs(path)
        except OSError:
            pass

    maketree(checkpoint_path)
    files = [f for f in os.listdir(checkpoint_path)]
    for file in ['hparams.json', 'encoder.json', 'vocab.bpe']:
        if file not in files:
            try:
                shutil.copyfile(os.path.join('models', model_name, file),
                                os.path.join(checkpoint_path, file))
            except FileNotFoundError as fnf_error:
                print('You need to download the GPT-2 model first via '
                      'download_gpt2()')
                raise fnf_error

    enc = encoder.get_encoder(checkpoint_path)
    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if model_name != '117M':
        use_memory_saving_gradients = True
        only_train_transformer_layers = True
        accumulate_gradients = 1

    context = tf.placeholder(tf.int32, [batch_size, None])
    output = model.model(hparams=hparams, X=context)
    loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=context[:, 1:], logits=output['logits'][:, :-1]))

    current_iter = 0
    learning_rate = tf.placeholder(tf.float32, shape=[])

    def get_lr():
        # Triangular CLR policy: the rate rises linearly from initial_lr
        # to final_lr over the first `steps` iterations (one half-cycle).
        cycle = np.floor(1 + current_iter / (2 * steps))
        x = np.abs(current_iter / steps - 2 * cycle + 1)
        lr = initial_lr + (final_lr - initial_lr) * np.maximum(
            0, (1 - x))  # * scale_fn(x)
        return lr

    all_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
    train_vars = [v for v in all_vars if '/h' in v.name
                  ] if only_train_transformer_layers else all_vars

    if accumulate_gradients > 1:
        if use_memory_saving_gradients:
            exit('Memory saving gradients are not implemented for '
                 'gradient accumulation yet.')
        # Note: the accumulation path wraps Adam, while the direct path
        # below uses plain gradient descent.
        opt = AccumulatingOptimizer(
            opt=tf.train.AdamOptimizer(learning_rate=learning_rate),
            var_list=train_vars)
        opt_reset = opt.reset()
        opt_compute = opt.compute_gradients(loss)
        opt_apply = opt.apply_gradients()
        summary_loss = tf.summary.scalar('loss', opt_apply)
    else:
        opt = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
        if use_memory_saving_gradients:
            opt_grads = memory_saving_gradients.gradients(loss, train_vars)
        else:
            opt_grads = tf.gradients(loss, train_vars)
        opt_grads = list(zip(opt_grads, train_vars))
        opt_apply = opt.apply_gradients(opt_grads)
        summary_loss = tf.summary.scalar('loss', loss)

    summary_log = tf.summary.FileWriter(checkpoint_path)

    saver = tf.train.Saver(var_list=all_vars, max_to_keep=max_checkpoints)
    sess.run(tf.global_variables_initializer())

    if restore_from == 'latest':
        ckpt = tf.train.latest_checkpoint(checkpoint_path)
        if ckpt is None:
            # Get fresh GPT weights if new run.
            ckpt = tf.train.latest_checkpoint(
                os.path.join('models', model_name))
    elif restore_from == 'fresh':
        ckpt = tf.train.latest_checkpoint(os.path.join('models', model_name))
    else:
        ckpt = tf.train.latest_checkpoint(restore_from)
    print('Loading checkpoint', ckpt)
    saver.restore(sess, ckpt)

    print('Loading dataset...')
    chunks = load_dataset(enc, dataset, combine)
    data_sampler = Sampler(chunks)
    print('dataset has', data_sampler.total_size, 'tokens')
    print('Training...')

    counter = 1
    counter_path = os.path.join(checkpoint_path, 'counter')
    if os.path.exists(counter_path) and restore_from == 'latest':
        # Load the step number if we're resuming a run.
        # Add 1 so we don't immediately try to save again.
        with open(counter_path, 'r') as fp:
            counter = int(fp.read()) + 1
    counter_base = counter

    def sample_batch():
        return [data_sampler.sample(1024) for _ in range(batch_size)]

    if overwrite and restore_from == 'latest':
        for file in files:
            if file.startswith('model') or file.startswith('events'):
                os.remove(os.path.join(checkpoint_path, file))

    start_time = time.time()

    try:
        while True:
            if steps > 0 and counter == (counter_base + steps):
                return
            if accumulate_gradients > 1:
                sess.run(opt_reset)
                for _ in range(accumulate_gradients):
                    sess.run(opt_compute,
                             feed_dict={
                                 context: sample_batch(),
                                 learning_rate: get_lr()
                             })
                # The apply op still depends on the learning_rate
                # placeholder, so it must be fed here as well.
                (v_loss, v_summary) = sess.run(
                    (opt_apply, summary_loss),
                    feed_dict={learning_rate: get_lr()})
            else:
                (_, v_loss, v_summary) = sess.run(
                    (opt_apply, loss, summary_loss),
                    feed_dict={
                        context: sample_batch(),
                        learning_rate: get_lr()
                    })

            summary_log.add_summary(v_summary, counter)

            print('[{counter} | {time:2.2f}] loss={loss:3.14f} lr={lr:2.14f}'.
                  format(counter=counter,
                         time=time.time() - start_time,
                         loss=v_loss,
                         lr=get_lr()))

            counter += 1
            current_iter += 1
    except KeyboardInterrupt:
        print('interrupted')
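
# A minimal, self-contained sketch of the triangular CLR schedule computed
# by get_lr() above (Smith's CLR, https://github.com/bckenstler/CLR), kept
# separate from the session code so it can be sanity-checked on its own.
# The helper name clr_triangular is not part of the training code; its
# arguments mirror one_lr_cycle's initial_lr, final_lr, and steps.
import numpy as np  # already imported at module scope; repeated so the
                    # sketch stands alone


def clr_triangular(iteration, steps, initial_lr, final_lr):
    """Learning rate at `iteration` for a triangular cycle of 2*steps."""
    cycle = np.floor(1 + iteration / (2 * steps))
    x = np.abs(iteration / steps - 2 * cycle + 1)
    return initial_lr + (final_lr - initial_lr) * np.maximum(0, 1 - x)

# Over one half-cycle (iterations 0..steps) the rate climbs linearly from
# initial_lr to final_lr, which is exactly what the LR range test needs:
#   clr_triangular(0,   100, 1e-10, 1.0) -> 1e-10
#   clr_triangular(50,  100, 1e-10, 1.0) -> ~0.5
#   clr_triangular(100, 100, 1e-10, 1.0) -> 1.0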
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)
    logger = tf.get_logger()
    logger.propagate = False

    if not FLAGS.do_train and not FLAGS.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    hparams = model.default_hparams()
    with tf.gfile.GFile(FLAGS.config_file) as f:
        hparams.override_from_dict(json.load(f))

    tf.gfile.MakeDirs(FLAGS.output_dir)

    input_files = []
    for input_pattern in FLAGS.input_file.split(","):
        input_files.extend(tf.gfile.Glob(input_pattern))

    # tf.logging.info("*** Input Files ***")
    # for input_file in input_files:
    #     tf.logging.info("  %s" % input_file)

    tpu_cluster_resolver = None
    if FLAGS.use_tpu and FLAGS.tpu_name:
        tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

    is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
    run_config = tf.contrib.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        master=FLAGS.master,
        model_dir=FLAGS.output_dir,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        keep_checkpoint_max=FLAGS.keep_checkpoint_max,
        tpu_config=tf.contrib.tpu.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            num_shards=FLAGS.num_tpu_cores,
            per_host_input_for_training=is_per_host,
        ),
    )

    model_fn = model_fn_builder(
        hparams=hparams,
        init_checkpoint=FLAGS.init_checkpoint,
        learning_rate=FLAGS.learning_rate,
        num_train_steps=FLAGS.num_train_steps,
        num_warmup_steps=FLAGS.num_warmup_steps,
        use_tpu=FLAGS.use_tpu,
        optimizer=FLAGS.optimizer,
        poly_power=FLAGS.poly_power,
        start_warmup_step=FLAGS.start_warmup_step,
        use_memory_saving_gradients=FLAGS.use_memory_saving_gradients)

    # If a TPU is not available, this will fall back to a normal Estimator
    # on CPU or GPU.
    estimator = tf.contrib.tpu.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=model_fn,
        config=run_config,
        train_batch_size=FLAGS.batch_size,
        eval_batch_size=FLAGS.eval_batch_size,
    )

    if FLAGS.do_train:
        tf.logging.info("***** Running training *****")
        tf.logging.info("  Batch size = %d", FLAGS.batch_size)
        train_input_fn = input_fn_builder(
            input_files=input_files,
            max_seq_length=FLAGS.max_seq_length,
            is_training=True,
        )
        estimator.train(input_fn=train_input_fn,
                        max_steps=FLAGS.num_train_steps)

    if FLAGS.do_eval:
        tf.logging.info("***** Running evaluation *****")
        tf.logging.info("  Batch size = %d", FLAGS.eval_batch_size)
        eval_input_fn = input_fn_builder(
            input_files=input_files,
            max_seq_length=FLAGS.max_seq_length,
            is_training=False,
        )
        result = estimator.evaluate(input_fn=eval_input_fn,
                                    steps=FLAGS.max_eval_steps)

        output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
        with tf.gfile.GFile(output_eval_file, "w") as writer:
            tf.logging.info("***** Eval results *****")
            for key in sorted(result.keys()):
                tf.logging.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))
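
# input_fn_builder is referenced above but defined elsewhere in this repo.
# The sketch below shows one plausible shape for it, modeled on the BERT
# pretraining input pipeline: TFRecords holding a single int64 'input_ids'
# feature of fixed length. The feature name, record format, and the helper
# name _example_input_fn_builder are assumptions, not the repo's API.
def _example_input_fn_builder(input_files, max_seq_length, is_training):
    name_to_features = {
        'input_ids': tf.FixedLenFeature([max_seq_length], tf.int64),
    }

    def input_fn(params):
        # TPUEstimator injects the per-core batch size via params.
        batch_size = params['batch_size']
        d = tf.data.TFRecordDataset(input_files)
        if is_training:
            d = d.repeat().shuffle(buffer_size=1000)

        def _decode(record):
            example = tf.parse_single_example(record, name_to_features)
            # TPUs only support int32, so cast the token ids down.
            example['input_ids'] = tf.cast(example['input_ids'], tf.int32)
            return example

        return d.map(_decode).batch(batch_size, drop_remainder=True)

    return input_fn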
def model_fn_builder(hparams,
                     init_checkpoint,
                     learning_rate,
                     num_train_steps,
                     num_warmup_steps,
                     use_tpu,
                     optimizer,
                     poly_power,
                     start_warmup_step,
                     use_memory_saving_gradients):
    # Signature aligned with the call in main(): hparams arrives already
    # built, so no config file is read here. optimizer, poly_power,
    # start_warmup_step, and use_memory_saving_gradients are accepted to
    # match that call but are not used by this minimal model_fn.

    def model_fn(features, labels, mode, params):
        tf.logging.info('*** Features ***')
        for name in sorted(features.keys()):
            tf.logging.info('  name = %s, shape = %s' %
                            (name, features[name].shape))

        input_ids = features['input_ids']
        is_training = mode == tf.estimator.ModeKeys.TRAIN

        output = model.model(hparams=hparams, X=input_ids)
        # Next-token language-modeling loss: predict token t+1 from the
        # logits at position t.
        loss = tf.reduce_mean(
            input_tensor=tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=input_ids[:, 1:], logits=output['logits'][:, :-1]))

        tvars = tf.trainable_variables()
        initialized_variable_names = {}
        scaffold_fn = None
        if init_checkpoint:
            (
                assignment_map,
                initialized_variable_names,
            ) = get_assignment_map_from_checkpoint(tvars, init_checkpoint)

            def tpu_scaffold():
                tf.train.init_from_checkpoint(init_checkpoint,
                                              assignment_map)
                return tf.train.Scaffold()

            scaffold_fn = tpu_scaffold

        tf.logging.info('**** Trainable Variables ****')
        for var in tvars:
            init_string = ''
            if var.name in initialized_variable_names:
                init_string = ', *INIT_FROM_CKPT*'
            tf.logging.info('  name = %s, shape = %s%s', var.name,
                            var.shape, init_string)

        output_spec = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            train_op = optimization.create_optimizer(
                loss, learning_rate, num_train_steps, num_warmup_steps,
                use_tpu)
            output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=loss,
                train_op=train_op,
                scaffold_fn=scaffold_fn,
            )
        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(loss, input_ids, logits):
                # Token-level accuracy of next-token prediction. This
                # replaces the BERT next-sentence metrics that were here,
                # which referenced tensors that do not exist in this model.
                predictions = tf.argmax(logits[:, :-1],
                                        axis=-1,
                                        output_type=tf.int32)
                labels = tf.reshape(input_ids[:, 1:], [-1])
                accuracy = tf.metrics.accuracy(
                    labels=labels,
                    predictions=tf.reshape(predictions, [-1]),
                )
                mean_loss = tf.metrics.mean(values=loss)
                return {
                    'token_accuracy': accuracy,
                    'mean_loss': mean_loss,
                }

            eval_metrics = (metric_fn, [loss, input_ids, output['logits']])
            output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=loss,
                eval_metrics=eval_metrics,
                scaffold_fn=scaffold_fn,
            )
        else:
            raise ValueError('Only TRAIN and EVAL modes are supported: %s' %
                             (mode))

        return output_spec

    return model_fn
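
# A hedged usage sketch (not part of the original script): building the
# model_fn for a local debugging run. Every literal value below is a
# placeholder, and the resulting model_fn is intended for a TPUEstimator
# constructed with use_tpu=False, as in main() above.
def _debug_model_fn():
    hparams = model.default_hparams()
    return model_fn_builder(
        hparams=hparams,
        init_checkpoint=None,  # no warm start; train from fresh weights
        learning_rate=1e-4,
        num_train_steps=1000,
        num_warmup_steps=10,
        use_tpu=False,
        optimizer='adamw',
        poly_power=1.0,
        start_warmup_step=0,
        use_memory_saving_gradients=False)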