def finetune(sess, dataset, steps=-1, model_name='117M', combine=50000,
             batch_size=1, learning_rate=0.0001, accumulate_gradients=5,
             restore_from='latest', run_name='run1', sample_every=100,
             sample_length=1023, sample_num=1, save_every=1000,
             print_every=1, max_checkpoints=1,
             use_memory_saving_gradients=False,
             only_train_transformer_layers=False, model_load=False):
    """Finetunes the model on the given dataset.

    Adapted from https://github.com/nshepperd/gpt-2/blob/finetuning/train.py.
    See that file for parameter definitions.
    """

    CHECKPOINT_DIR = 'checkpoint'
    SAMPLE_DIR = 'samples'

    checkpoint_path = os.path.join(CHECKPOINT_DIR, run_name)

    def maketree(path):
        try:
            os.makedirs(path)
        except:
            pass

    maketree(checkpoint_path)
    if not model_load:
        for file in ['hparams.json', 'encoder.json', 'vocab.bpe']:
            shutil.copyfile(os.path.join('models', model_name, file),
                            os.path.join(checkpoint_path, file))

    enc = encoder.get_encoder(checkpoint_path)
    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if sample_length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" %
                         hparams.n_ctx)

    if model_name != '117M':
        use_memory_saving_gradients = True
        only_train_transformer_layers = True
        accumulate_gradients = 1

    context = tf.placeholder(tf.int32, [batch_size, None])
    output = model.model(hparams=hparams, X=context)
    loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=context[:, 1:], logits=output['logits'][:, :-1]))

    tf_sample = sample.sample_sequence(hparams=hparams,
                                       length=sample_length,
                                       context=context,
                                       batch_size=batch_size,
                                       temperature=1.0,
                                       top_k=40)

    all_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
    train_vars = [v for v in all_vars if '/h' in v.name
                  ] if only_train_transformer_layers else all_vars

    if accumulate_gradients > 1:
        if use_memory_saving_gradients:
            exit("Memory saving gradients are not implemented for gradient accumulation yet.")
        opt = AccumulatingOptimizer(
            opt=tf.train.AdamOptimizer(learning_rate=learning_rate),
            var_list=train_vars)
        opt_reset = opt.reset()
        opt_compute = opt.compute_gradients(loss)
        opt_apply = opt.apply_gradients()
        summary_loss = tf.summary.scalar('loss', opt_apply)
    else:
        opt = tf.train.AdamOptimizer(learning_rate=learning_rate)
        if use_memory_saving_gradients:
            opt_grads = memory_saving_gradients.gradients(loss, train_vars)
        else:
            opt_grads = tf.gradients(loss, train_vars)
        opt_grads = list(zip(opt_grads, train_vars))
        opt_apply = opt.apply_gradients(opt_grads)
        summary_loss = tf.summary.scalar('loss', loss)

    summary_log = tf.summary.FileWriter(checkpoint_path)

    saver = tf.train.Saver(var_list=all_vars, max_to_keep=max_checkpoints)
    sess.run(tf.global_variables_initializer())

    if restore_from == 'latest':
        ckpt = tf.train.latest_checkpoint(checkpoint_path)
        if ckpt is None:
            # Get fresh GPT weights if new run.
            ckpt = tf.train.latest_checkpoint(
                os.path.join('models', model_name))
    elif restore_from == 'fresh':
        ckpt = tf.train.latest_checkpoint(os.path.join('models', model_name))
    else:
        ckpt = tf.train.latest_checkpoint(restore_from)
    print('Loading checkpoint', ckpt)
    saver.restore(sess, ckpt)

    if model_load:
        return

    print('Loading dataset...')
    chunks = load_dataset(enc, dataset, combine)
    data_sampler = Sampler(chunks)
    print('dataset has', data_sampler.total_size, 'tokens')
    print('Training...')

    counter = 1
    counter_path = os.path.join(checkpoint_path, 'counter')
    if os.path.exists(counter_path) and restore_from == 'latest':
        # Load the step number if we're resuming a run
        # Add 1 so we don't immediately try to save again
        with open(counter_path, 'r') as fp:
            counter = int(fp.read()) + 1
    counter_base = counter

    def save():
        maketree(checkpoint_path)
        print('Saving',
              os.path.join(checkpoint_path, 'model-{}').format(counter - 1))
        saver.save(sess,
                   os.path.join(checkpoint_path, 'model'),
                   global_step=counter - 1)
        with open(counter_path, 'w') as fp:
            fp.write(str(counter - 1) + '\n')

    def generate_samples():
        context_tokens = data_sampler.sample(1)
        all_text = []
        index = 0
        while index < sample_num:
            out = sess.run(
                tf_sample,
                feed_dict={context: batch_size * [context_tokens]})
            for i in range(min(sample_num - index, batch_size)):
                text = enc.decode(out[i])
                text = '======== SAMPLE {} ========\n{}\n'.format(
                    index + 1, text)
                all_text.append(text)
                index += 1
        print(text)
        maketree(os.path.join(SAMPLE_DIR, run_name))
        with open(
                os.path.join(SAMPLE_DIR, run_name,
                             'samples-{}').format(counter), 'w') as fp:
            fp.write('\n'.join(all_text))

    def sample_batch():
        return [data_sampler.sample(1024) for _ in range(batch_size)]

    avg_loss = (0.0, 0.0)
    start_time = time.time()

    try:
        while True:
            if steps > 0 and counter == (counter_base + steps):
                save()
                return
            if (counter - 1) % save_every == 0 and counter > 1:
                save()
            if (counter - 1) % sample_every == 0 and counter > 1:
                generate_samples()

            if accumulate_gradients > 1:
                sess.run(opt_reset)
                for _ in range(accumulate_gradients):
                    sess.run(opt_compute,
                             feed_dict={context: sample_batch()})
                (v_loss, v_summary) = sess.run((opt_apply, summary_loss))
            else:
                (_, v_loss, v_summary) = sess.run(
                    (opt_apply, loss, summary_loss),
                    feed_dict={context: sample_batch()})

            summary_log.add_summary(v_summary, counter)

            if counter % print_every == 0:
                avg_loss = (avg_loss[0] * 0.99 + v_loss,
                            avg_loss[1] * 0.99 + 1.0)
                print(
                    '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                    .format(counter=counter,
                            time=time.time() - start_time,
                            loss=v_loss,
                            avg=avg_loss[0] / avg_loss[1]))

            counter += 1
    except KeyboardInterrupt:
        print('interrupted')
        save()
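# Usage sketch for the function above (hedged: the file name and step counts
# below are hypothetical, and this assumes the 117M model has already been
# downloaded to models/117M and that a TF1-style graph/session is in use).
sess = tf.compat.v1.Session()
finetune(sess,
         dataset='corpus.txt',    # plain-text training file (hypothetical name)
         model_name='117M',
         steps=1000,              # stop after 1000 optimizer steps
         save_every=500,          # checkpoint twice during the run
         sample_every=250)        # print generated samples every 250 steps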
def finetune(sess, dataset, steps=-1, model_name='124M', model_dir='models', combine=50000, batch_size=1, learning_rate=0.0001, accumulate_gradients=5, restore_from='latest', run_name='run1', checkpoint_dir='checkpoint', sample_every=100, sample_length=1023, sample_num=1, multi_gpu=False, save_every=1000, print_every=1, max_checkpoints=1, use_memory_saving_gradients=False, only_train_transformer_layers=False, optimizer='adam', overwrite=False): """Finetunes the model on the given dataset. Adapted from https://github.com/nshepperd/gpt-2/blob/finetuning/train.py. See that file for parameter definitions. """ assert model_name not in [ '774M', '1558M' ] or multi_gpu, "[ERROR] Currently, a modern single GPU cannot finetune the 774M GPT-2 model or larger." SAMPLE_DIR = 'samples' checkpoint_path = os.path.join(checkpoint_dir, run_name) def maketree(path): try: os.makedirs(path) except: pass maketree(checkpoint_path) files = [f for f in os.listdir(checkpoint_path)] for file in ['hparams.json', 'encoder.json', 'vocab.bpe']: try: shutil.copyfile(os.path.join(model_dir, model_name, file), os.path.join(checkpoint_path, file)) except FileNotFoundError as fnf_error: print( "[ERROR] You need to download the GPT-2 model first via download_gpt2()" ) raise (fnf_error) enc = encoder.get_encoder(checkpoint_path) hparams = model.default_hparams() with open(os.path.join(checkpoint_path, 'hparams.json')) as f: hparams.override_from_dict(json.load(f)) if sample_length > hparams.n_ctx: raise ValueError( "[ERROR] Can't get samples longer than window size: %s" % hparams.n_ctx) if model_name not in ['117M', '124M']: use_memory_saving_gradients = True only_train_transformer_layers = True accumulate_gradients = 1 context = tf.compat.v1.placeholder(tf.int32, [batch_size, None]) gpus = [] if multi_gpu: gpus = get_available_gpus() output = model.model(hparams=hparams, X=context, gpus=gpus) loss = tf.reduce_mean( input_tensor=tf.nn.sparse_softmax_cross_entropy_with_logits( labels=context[:, 1:], logits=output['logits'][:, :-1])) tf_sample = sample.sample_sequence(hparams=hparams, length=sample_length, context=context, batch_size=batch_size, temperature=1.0, top_k=40) all_vars = [ v for v in tf.compat.v1.trainable_variables() if 'model' in v.name ] train_vars = [v for v in all_vars if '/h' in v.name ] if only_train_transformer_layers else all_vars if optimizer == 'adam': opt = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate) elif optimizer == 'sgd': opt = tf.compat.v1.train.GradientDescentOptimizer( learning_rate=learning_rate) if accumulate_gradients > 1: if use_memory_saving_gradients: exit( "[INFO] Memory saving gradients are not implemented for gradient accumulation yet." 
) opt = AccumulatingOptimizer(opt=opt, var_list=train_vars) opt_reset = opt.reset() opt_compute = opt.compute_gradients(loss) opt_apply = opt.apply_gradients() summary_loss = tf.compat.v1.summary.scalar('loss', opt_apply) else: if use_memory_saving_gradients: opt_grads = memory_saving_gradients.gradients(loss, train_vars) else: opt_grads = tf.gradients(ys=loss, xs=train_vars) opt_grads = list(zip(opt_grads, train_vars)) opt_apply = opt.apply_gradients(opt_grads) summary_loss = tf.compat.v1.summary.scalar('loss', loss) summaries = [summary_loss] summary_op = tf.summary.merge(summaries) summary_writer = tf.compat.v1.summary.FileWriter(checkpoint_path, sess.graph) saver = tf.compat.v1.train.Saver(var_list=all_vars, max_to_keep=max_checkpoints) sess.run(tf.compat.v1.global_variables_initializer()) if restore_from == 'latest': ckpt = tf.train.latest_checkpoint(checkpoint_path) if ckpt is None: # Get fresh GPT weights if new run. ckpt = tf.train.latest_checkpoint( os.path.join(model_dir, model_name)) elif restore_from == 'fresh': ckpt = tf.train.latest_checkpoint(os.path.join(model_dir, model_name)) else: ckpt = tf.train.latest_checkpoint(restore_from) print('[INFO] Loading checkpoint', ckpt) saver.restore(sess, ckpt) print('[INFO] Loading dataset...') chunks = load_dataset(enc, dataset, combine) data_sampler = Sampler(chunks) print('[INFO] dataset has', data_sampler.total_size, 'tokens') print('[INFO] Training...') counter = 1 counter_path = os.path.join(checkpoint_path, 'counter') if os.path.exists(counter_path) and restore_from == 'latest': # Load the step number if we're resuming a run # Add 1 so we don't immediately try to save again with open(counter_path, 'r') as fp: counter = int(fp.read()) + 1 counter_base = counter def save(): maketree(checkpoint_path) print('[INFO] Saving', os.path.join(checkpoint_path, 'model-{}').format(counter - 1)) saver.save(sess, os.path.join(checkpoint_path, 'model'), global_step=counter - 1) with open(counter_path, 'w') as fp: fp.write(str(counter - 1) + '\n') def generate_samples(header=True): context_tokens = data_sampler.sample(1) all_text = [] index = 0 while index < sample_num: out = sess.run(tf_sample, feed_dict={context: batch_size * [context_tokens]}) for i in range(min(sample_num - index, batch_size)): text = enc.decode(out[i]) if header: text = '======== SAMPLE {} ========\n{}\n'.format( index + 1, text) all_text.append(text) index += 1 print(text) maketree(os.path.join(SAMPLE_DIR, run_name)) with open( os.path.join(SAMPLE_DIR, run_name, 'samples-{}').format(counter), 'w') as fp: fp.write('\n'.join(all_text)) return text def sample_batch(sample_size=1024): return [data_sampler.sample(sample_size) for _ in range(batch_size)] if overwrite and restore_from == 'latest': for file in files: if file.startswith('model') or file.startswith('events'): os.remove(os.path.join(checkpoint_path, file)) save() avg_loss = (0.0, 0.0) start_time = time.time() best_loss = 1000.00 if steps: steps = int(steps) try: while True: t = time.time() - start_time if accumulate_gradients > 1: sess.run(opt_reset) for _ in range(accumulate_gradients): sess.run(opt_compute, feed_dict={context: sample_batch()}) (v_loss, v_summary) = sess.run((opt_apply, summary_op)) else: (_, v_loss, v_summary) = sess.run( (opt_apply, loss, summary_op), feed_dict={context: sample_batch()}) monitoring.push_metric('loss', v_loss, counter, t) summary_writer.add_summary(v_summary, counter) if steps > 0 and counter == (counter_base + steps): save() return if (counter - 1) % save_every == 0 and counter > 1: 
save() sample_text = generate_samples(False) summary_sample = tf.compat.v1.summary.text( 'generated_sample', tf.convert_to_tensor(sample_text)) text = sess.run(summary_sample) summary_writer.add_summary(text, counter) monitoring.push_text('generated_text', sample_text, counter, v_loss, t) if v_loss < best_loss and v_loss > 1.00: best_loss = v_loss loss_text = generate_samples(False) loss_sample_text = "===== LOSS : " + str( best_loss) + "=====" + loss_text summary_loss_sample = tf.compat.v1.summary.text( 'best_loss_sample', tf.convert_to_tensor(loss_sample_text)) loss_text_ = sess.run(summary_loss_sample) summary_writer.add_summary(loss_text_, counter) monitoring.push_text('best_loss_text', loss_text, counter, best_loss, t) monitoring.push_metric('best_loss', best_loss, counter, t) if counter % print_every == 0: avg_loss = (avg_loss[0] * 0.99 + v_loss, avg_loss[1] * 0.99 + 1.0) monitoring.push_metric('avg_loss', avg_loss[0] / avg_loss[1], counter, t) print( '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}' .format(counter=counter, time=t, loss=v_loss, avg=avg_loss[0] / avg_loss[1])) counter += 1 except KeyboardInterrupt: print('[STOP] interrupted') save()
def one_lr_cycle(sess, dataset, steps=10000, model_name='117M', combine=50000,
                 batch_size=1, intial_lr=1e-10, final_lr=1,
                 accumulate_gradients=5, restore_from='fresh', run_name='run1',
                 max_checkpoints=1, use_memory_saving_gradients=False,
                 only_train_transformer_layers=False, overwrite=False):
    """Does one LR half-cycle from initial to final over `steps` iterations,
    using the CLR algorithm: https://github.com/bckenstler/CLR

    Adapted from https://github.com/nshepperd/gpt-2/blob/finetuning/train.py.
    See that file for parameter definitions.
    """

    CHECKPOINT_DIR = 'checkpoint'
    checkpoint_path = os.path.join(CHECKPOINT_DIR, run_name)

    def maketree(path):
        try:
            os.makedirs(path)
        except:
            pass

    maketree(checkpoint_path)
    files = [f for f in os.listdir(checkpoint_path)]
    for file in ['hparams.json', 'encoder.json', 'vocab.bpe']:
        if file not in files:
            try:
                shutil.copyfile(os.path.join('models', model_name, file),
                                os.path.join(checkpoint_path, file))
            except FileNotFoundError as fnf_error:
                print("You need to download the GPT-2 model first via download_gpt2()")
                raise fnf_error

    enc = encoder.get_encoder(checkpoint_path)
    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if model_name != '117M':
        use_memory_saving_gradients = True
        only_train_transformer_layers = True
        accumulate_gradients = 1

    context = tf.placeholder(tf.int32, [batch_size, None])
    output = model.model(hparams=hparams, X=context)
    loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=context[:, 1:], logits=output['logits'][:, :-1]))

    current_iter = 0
    learning_rate = tf.placeholder(tf.float32, shape=[])

    def get_lr():
        cycle = np.floor(1 + current_iter / (2 * steps))
        x = np.abs(current_iter / steps - 2 * cycle + 1)
        lr = intial_lr + (final_lr - intial_lr) * np.maximum(
            0, (1 - x))  # * scale_fn(x)
        return lr

    all_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
    train_vars = [v for v in all_vars if '/h' in v.name
                  ] if only_train_transformer_layers else all_vars

    if accumulate_gradients > 1:
        if use_memory_saving_gradients:
            exit("Memory saving gradients are not implemented for gradient accumulation yet.")
        opt = AccumulatingOptimizer(
            opt=tf.train.AdamOptimizer(learning_rate=learning_rate),
            var_list=train_vars)
        opt_reset = opt.reset()
        opt_compute = opt.compute_gradients(loss)
        opt_apply = opt.apply_gradients()
        summary_loss = tf.summary.scalar('loss', opt_apply)
    else:
        opt = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
        if use_memory_saving_gradients:
            opt_grads = memory_saving_gradients.gradients(loss, train_vars)
        else:
            opt_grads = tf.gradients(loss, train_vars)
        opt_grads = list(zip(opt_grads, train_vars))
        opt_apply = opt.apply_gradients(opt_grads)
        summary_loss = tf.summary.scalar('loss', loss)

    summary_log = tf.summary.FileWriter(checkpoint_path)

    saver = tf.train.Saver(var_list=all_vars, max_to_keep=max_checkpoints)
    sess.run(tf.global_variables_initializer())

    if restore_from == 'latest':
        ckpt = tf.train.latest_checkpoint(checkpoint_path)
        if ckpt is None:
            # Get fresh GPT weights if new run.
            ckpt = tf.train.latest_checkpoint(
                os.path.join('models', model_name))
    elif restore_from == 'fresh':
        ckpt = tf.train.latest_checkpoint(os.path.join('models', model_name))
    else:
        ckpt = tf.train.latest_checkpoint(restore_from)
    print('Loading checkpoint', ckpt)
    saver.restore(sess, ckpt)

    print('Loading dataset...')
    chunks = load_dataset(enc, dataset, combine)
    data_sampler = Sampler(chunks)
    print('dataset has', data_sampler.total_size, 'tokens')
    print('Training...')

    counter = 1
    counter_path = os.path.join(checkpoint_path, 'counter')
    if os.path.exists(counter_path) and restore_from == 'latest':
        # Load the step number if we're resuming a run
        # Add 1 so we don't immediately try to save again
        with open(counter_path, 'r') as fp:
            counter = int(fp.read()) + 1
    counter_base = counter

    def sample_batch():
        return [data_sampler.sample(1024) for _ in range(batch_size)]

    if overwrite and restore_from == 'latest':
        for file in files:
            if file.startswith('model') or file.startswith('events'):
                os.remove(os.path.join(checkpoint_path, file))

    start_time = time.time()

    try:
        while True:
            if steps > 0 and counter == (counter_base + steps):
                return
            if accumulate_gradients > 1:
                sess.run(opt_reset)
                for _ in range(accumulate_gradients):
                    sess.run(opt_compute,
                             feed_dict={
                                 context: sample_batch(),
                                 learning_rate: get_lr()
                             })
                (v_loss, v_summary) = sess.run((opt_apply, summary_loss))
            else:
                (_, v_loss, v_summary) = sess.run(
                    (opt_apply, loss, summary_loss),
                    feed_dict={
                        context: sample_batch(),
                        learning_rate: get_lr()
                    })

            summary_log.add_summary(v_summary, counter)

            print('[{counter} | {time:2.2f}] loss={loss:3.14f} lr={lr:2.14f}'.
                  format(counter=counter,
                         time=time.time() - start_time,
                         loss=v_loss,
                         lr=get_lr()))

            counter += 1
            current_iter += 1
    except KeyboardInterrupt:
        print('interrupted')
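# Standalone NumPy sketch of the LR ramp computed by get_lr() above: over one
# half-cycle (current_iter in [0, steps]) the rate climbs linearly from
# intial_lr to final_lr, which is what an LR range test needs. The helper name
# below is hypothetical; the formula is copied from get_lr().
import numpy as np

def clr_half_cycle(current_iter, steps, intial_lr=1e-10, final_lr=1.0):
    cycle = np.floor(1 + current_iter / (2 * steps))
    x = np.abs(current_iter / steps - 2 * cycle + 1)
    return intial_lr + (final_lr - intial_lr) * np.maximum(0, 1 - x)

for it in (0, 2500, 5000, 7500, 10000):
    print(it, clr_half_cycle(it, steps=10000))  # ~0.0, 0.25, 0.5, 0.75, 1.0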
def finetune( sess, dataset, steps=-1, model_name="124M", model_dir="models", combine=50000, batch_size=1, learning_rate=0.0001, accumulate_gradients=5, restore_from="latest", run_name="run1", checkpoint_dir="checkpoint", sample_every=100, sample_length=1023, sample_num=1, multi_gpu=False, save_every=1000, print_every=1, max_checkpoints=1, use_memory_saving_gradients=False, only_train_transformer_layers=False, optimizer="adam", overwrite=False, ): """Finetunes the model on the given dataset. Adapted from https://github.com/nshepperd/gpt-2/blob/finetuning/train.py. See that file for parameter definitions. """ # assert model_name not in ['774M', '1558M'] or multi_gpu, "Currently, a modern single GPU cannot finetune the 774M GPT-2 model or larger." SAMPLE_DIR = "samples" checkpoint_path = os.path.join(checkpoint_dir, run_name) def maketree(path): try: os.makedirs(path) except FileExistsError: pass maketree(checkpoint_path) files = [f for f in os.listdir(checkpoint_path)] for file in ["hparams.json", "encoder.json", "vocab.bpe"]: try: shutil.copyfile( os.path.join(model_dir, model_name, file), os.path.join(checkpoint_path, file), ) except FileNotFoundError as fnf_error: print( "You need to download the GPT-2 model first via download_gpt2()" ) raise (fnf_error) enc = encoder.get_encoder(checkpoint_path) hparams = model.default_hparams() with open(os.path.join(checkpoint_path, "hparams.json")) as f: hparams.override_from_dict(json.load(f)) if sample_length > hparams.n_ctx: raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx) if model_name not in ["117M", "124M"]: use_memory_saving_gradients = True only_train_transformer_layers = True accumulate_gradients = 1 context = tf.compat.v1.placeholder(tf.int32, [batch_size, None]) gpus = [] if multi_gpu: gpus = get_available_gpus() output = model.model(hparams=hparams, X=context, gpus=gpus) loss = tf.reduce_mean( input_tensor=tf.nn.sparse_softmax_cross_entropy_with_logits( labels=context[:, 1:], logits=output["logits"][:, :-1])) tf_sample = sample.sample_sequence( hparams=hparams, length=sample_length, context=context, batch_size=batch_size, temperature=1.0, top_k=40, ) all_vars = [ v for v in tf.compat.v1.trainable_variables() if "model" in v.name ] train_vars = ([v for v in all_vars if "/h" in v.name] if only_train_transformer_layers else all_vars) if optimizer == "adam": opt = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate) elif optimizer == "sgd": opt = tf.compat.v1.train.GradientDescentOptimizer( learning_rate=learning_rate) if accumulate_gradients > 1: if use_memory_saving_gradients: exit( "Memory saving gradients are not implemented for gradient accumulation yet." ) opt = AccumulatingOptimizer(opt=opt, var_list=train_vars) opt_reset = opt.reset() opt_compute = opt.compute_gradients(loss) opt_apply = opt.apply_gradients() summary_loss = tf.compat.v1.summary.scalar("loss", opt_apply) else: if use_memory_saving_gradients: opt_grads = memory_saving_gradients.gradients(loss, train_vars) else: opt_grads = tf.gradients(ys=loss, xs=train_vars) opt_grads = list(zip(opt_grads, train_vars)) opt_apply = opt.apply_gradients(opt_grads) summary_loss = tf.compat.v1.summary.scalar("loss", loss) summary_log = tf.compat.v1.summary.FileWriter(checkpoint_path) saver = tf.compat.v1.train.Saver(var_list=all_vars, max_to_keep=max_checkpoints) sess.run(tf.compat.v1.global_variables_initializer()) if restore_from == "latest": ckpt = tf.train.latest_checkpoint(checkpoint_path) if ckpt is None: # Get fresh GPT weights if new run. 
ckpt = tf.train.latest_checkpoint( os.path.join(model_dir, model_name)) elif restore_from == "fresh": ckpt = tf.train.latest_checkpoint(os.path.join(model_dir, model_name)) else: ckpt = tf.train.latest_checkpoint(restore_from) print("Loading checkpoint", ckpt) saver.restore(sess, ckpt) print("Loading dataset...") chunks = load_dataset(enc, dataset, combine) data_sampler = Sampler(chunks) print("dataset has", data_sampler.total_size, "tokens") print("Training...") counter = 1 counter_path = os.path.join(checkpoint_path, "counter") if os.path.exists(counter_path) and restore_from == "latest": # Load the step number if we're resuming a run # Add 1 so we don't immediately try to save again with open(counter_path, "r") as fp: counter = int(fp.read()) + 1 counter_base = counter def save(): maketree(checkpoint_path) print("Saving", os.path.join(checkpoint_path, "model-{}").format(counter - 1)) saver.save(sess, os.path.join(checkpoint_path, "model"), global_step=counter - 1) with open(counter_path, "w") as fp: fp.write(str(counter - 1) + "\n") def generate_samples(): context_tokens = data_sampler.sample(1) all_text = [] index = 0 while index < sample_num: out = sess.run(tf_sample, feed_dict={context: batch_size * [context_tokens]}) for i in range(min(sample_num - index, batch_size)): text = enc.decode(out[i]) text = "======== SAMPLE {} ========\n{}\n".format( index + 1, text) all_text.append(text) index += 1 print(text) maketree(os.path.join(SAMPLE_DIR, run_name)) with open( os.path.join(SAMPLE_DIR, run_name, "samples-{}").format(counter), "w") as fp: fp.write("\n".join(all_text)) def sample_batch(): return [data_sampler.sample(1024) for _ in range(batch_size)] if overwrite and restore_from == "latest": for file in files: if file.startswith("model") or file.startswith("events"): os.remove(os.path.join(checkpoint_path, file)) save() avg_loss = (0.0, 0.0) start_time = time.time() if steps: steps = int(steps) try: while True: if steps > 0 and counter == (counter_base + steps): save() return if (counter - 1) % save_every == 0 and counter > 1: save() if (counter - 1) % sample_every == 0 and counter > 1: generate_samples() if accumulate_gradients > 1: sess.run(opt_reset) for _ in range(accumulate_gradients): sess.run(opt_compute, feed_dict={context: sample_batch()}) (v_loss, v_summary) = sess.run((opt_apply, summary_loss)) else: (_, v_loss, v_summary) = sess.run( (opt_apply, loss, summary_loss), feed_dict={context: sample_batch()}) summary_log.add_summary(v_summary, counter) if counter % print_every == 0: avg_loss = (avg_loss[0] * 0.99 + v_loss, avg_loss[1] * 0.99 + 1.0) print( "[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}" .format( counter=counter, time=time.time() - start_time, loss=v_loss, avg=avg_loss[0] / avg_loss[1], )) counter += 1 except KeyboardInterrupt: print("interrupted") save()
def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps,
                     use_tpu, optimizer="lamb", poly_power=1.0,
                     start_warmup_step=0, use_memory_saving_gradients=False):
    """Creates an optimizer training op."""
    global_step = tf.train.get_or_create_global_step()

    learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32)

    # Implements linear decay of the learning rate.
    learning_rate = tf.train.polynomial_decay(
        learning_rate,
        global_step,
        num_train_steps,
        end_learning_rate=0.0,
        power=poly_power,
        cycle=False,
    )

    # Implements linear warmup. I.e., if global_step - start_warmup_step <
    # num_warmup_steps, the learning rate will be
    # `(global_step - start_warmup_step)/num_warmup_steps * init_lr`.
    if num_warmup_steps:
        tf.logging.info("++++++ warmup starts at step " +
                        str(start_warmup_step) + ", for " +
                        str(num_warmup_steps) + " steps ++++++")
        global_steps_int = tf.cast(global_step, tf.int32)
        start_warm_int = tf.constant(start_warmup_step, dtype=tf.int32)
        global_steps_int = global_steps_int - start_warm_int
        warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32)

        global_steps_float = tf.cast(global_steps_int, tf.float32)
        warmup_steps_float = tf.cast(warmup_steps_int, tf.float32)

        warmup_percent_done = global_steps_float / warmup_steps_float
        warmup_learning_rate = init_lr * warmup_percent_done

        is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32)
        learning_rate = ((1.0 - is_warmup) * learning_rate +
                         is_warmup * warmup_learning_rate)

    # It is OK to use this optimizer for finetuning, since this is how the
    # model was trained (note that the Adam m/v variables are NOT loaded
    # from init_checkpoint).
    # It is also OK to use AdamW for finetuning even if the model was trained
    # with LAMB. As reported in the public BERT GitHub repo, the learning rate
    # for SQuAD 1.1 finetuning is 3e-5, 4e-5, or 5e-5. For LAMB, users can use
    # 3e-4, 4e-4, or 5e-4 for a batch size of 64 when finetuning.
    if optimizer == "adamw":
        tf.logging.info("using adamw")
        optimizer = AdamWeightDecayOptimizer(
            learning_rate=learning_rate,
            weight_decay_rate=0.01,
            beta_1=0.9,
            beta_2=0.999,
            epsilon=1e-6,
            exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
        )
    elif optimizer == "lamb":
        tf.logging.info("using lamb")
        optimizer = lamb_optimizer.LAMBOptimizer(
            learning_rate=learning_rate,
            weight_decay_rate=0.01,
            beta_1=0.9,
            beta_2=0.999,
            epsilon=1e-6,
            exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
        )
    else:
        raise ValueError("Unsupported optimizer: ", optimizer)

    if use_tpu:
        optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)

    tvars = tf.trainable_variables()
    if use_memory_saving_gradients:
        grads = memory_saving_gradients.gradients(loss, tvars)
    else:
        grads = tf.gradients(ys=loss, xs=tvars)

    # This is how BERT was pre-trained.
    # (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)

    train_op = optimizer.apply_gradients(zip(grads, tvars),
                                         global_step=global_step)

    # Normally the global step update is done inside of `apply_gradients`.
    # However, neither `AdamWeightDecayOptimizer` nor `LAMBOptimizer` does
    # this. If you use a different optimizer, you should probably take this
    # line out.
    new_global_step = global_step + 1
    train_op = tf.group(train_op, [global_step.assign(new_global_step)])
    return train_op
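# NumPy sketch of the learning-rate schedule assembled above, for inspection
# only: linear (power=1.0) polynomial decay of init_lr to 0 over
# num_train_steps, overridden by linear warmup while
# global_step - start_warmup_step < num_warmup_steps. The helper name is
# hypothetical; the real schedule is built from TF ops inside create_optimizer.
def scheduled_lr(step, init_lr, num_train_steps, num_warmup_steps,
                 poly_power=1.0, start_warmup_step=0):
    decayed = init_lr * (1.0 - min(step, num_train_steps) / num_train_steps) ** poly_power
    if num_warmup_steps and (step - start_warmup_step) < num_warmup_steps:
        return init_lr * (step - start_warmup_step) / num_warmup_steps
    return decayed

# e.g. init_lr=1e-4, 10k train steps, 1k warmup steps:
for s in (0, 500, 1000, 5000, 10000):
    print(s, scheduled_lr(s, 1e-4, 10000, 1000))  # 0.0, 5e-05, 9e-05, 5e-05, 0.0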
opt = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)

if accumulate_gradients > 1:
    if use_memory_saving_gradients:
        exit('Memory saving gradients are not implemented for gradient accumulation yet.')
    opt = AccumulatingOptimizer(opt=opt, var_list=train_vars)
    opt_reset = opt.reset()
    opt_compute = opt.compute_gradients(loss)
    opt_apply = opt.apply_gradients()
    summary_loss = tf.compat.v1.summary.scalar('loss', opt_apply)
else:
    if use_memory_saving_gradients:
        opt_grads = memory_saving_gradients.gradients(loss, train_vars)
    else:
        opt_grads = tf.gradients(ys=loss, xs=train_vars)
    opt_grads = list(zip(opt_grads, train_vars))
    opt_apply = opt.apply_gradients(opt_grads)
    summary_loss = tf.compat.v1.summary.scalar('loss', loss)

# opt = custom_optimization.create_optimizer(
#     loss, learning_rate, num_train_steps, num_warmup_steps
# )
# summary_loss = tf.compat.v1.summary.scalar('loss', loss)

def maketree(path):
    try:
        os.makedirs(path)
    except:
        pass
def finetune(sess, dataset, steps=-1, model_name='124M', model_dir='models', combine=50000, batch_size=1, learning_rate=0.0001, accumulate_gradients=5, restore_from='latest', run_name='run1', checkpoint_dir='checkpoint', sample_every=100, sample_length=1023, sample_num=1, multi_gpu=False, save_every=1000, print_every=1, max_checkpoints=1, use_memory_saving_gradients=False, only_train_transformer_layers=False, optimizer='adam', overwrite=False, val_dataset=None, val_batch_size=2, val_batch_count=40, val_every=0 ): """Finetunes the model on the given dataset. Adapted from https://github.com/nshepperd/gpt-2/blob/finetuning/train.py. See that file for parameter definitions. """ SAMPLE_DIR = 'samples' checkpoint_path = os.path.join(checkpoint_dir, run_name) def maketree(path): try: os.makedirs(path) except: pass maketree(checkpoint_path) files = [f for f in os.listdir(checkpoint_path)] for file in ['hparams.json', 'encoder.json', 'vocab.bpe']: try: shutil.copyfile(os.path.join(model_dir, model_name, file), os.path.join(checkpoint_path, file)) except FileNotFoundError as fnf_error: print("You need to download the GPT-2 model first via download_gpt2()") raise(fnf_error) enc = encoder.get_encoder(checkpoint_path) hparams = model.default_hparams() with open(os.path.join(checkpoint_path, 'hparams.json')) as f: hparams.override_from_dict(json.load(f)) if sample_length > hparams.n_ctx: raise ValueError( "Can't get samples longer than window size: %s" % hparams.n_ctx) context = tf.compat.v1.placeholder(tf.int32, [batch_size, None]) gpus = [] if multi_gpu: gpus = get_available_gpus() output = model.model(hparams=hparams, X=context, gpus=gpus) loss = tf.reduce_mean( input_tensor=tf.nn.sparse_softmax_cross_entropy_with_logits( labels=context[:, 1:], logits=output['logits'][:, :-1])) # validation code if val_every > 0: val_context = tf.placeholder(tf.int32, [val_batch_size, None]) val_output = model.model(hparams=hparams, X=val_context, reuse=True) # added reuse=True val_loss = tf.reduce_mean( tf.nn.sparse_softmax_cross_entropy_with_logits( labels=val_context[:, 1:], logits=val_output['logits'][:, :-1])) val_loss_summary = tf.summary.scalar('val_loss', val_loss) tf_sample = sample.sample_sequence( hparams=hparams, length=sample_length, context=context, batch_size=batch_size, temperature=1.0, top_k=40) all_vars = [v for v in tf.compat.v1.trainable_variables() if 'model' in v.name] train_vars = [v for v in all_vars if '/h' in v.name] if only_train_transformer_layers else all_vars if optimizer == 'adam': opt = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate) elif optimizer == 'sgd': opt = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=learning_rate) if accumulate_gradients > 1: if use_memory_saving_gradients: exit("Memory saving gradients are not implemented for gradient accumulation yet.") opt = AccumulatingOptimizer( opt=opt, var_list=train_vars) opt_reset = opt.reset() opt_compute = opt.compute_gradients(loss) opt_apply = opt.apply_gradients() summary_loss = tf.compat.v1.summary.scalar('loss', opt_apply) else: if use_memory_saving_gradients: opt_grads = memory_saving_gradients.gradients(loss, train_vars) else: opt_grads = tf.gradients(ys=loss, xs=train_vars) opt_grads = list(zip(opt_grads, train_vars)) opt_apply = opt.apply_gradients(opt_grads) summary_loss = tf.compat.v1.summary.scalar('loss', loss) summary_log = tf.compat.v1.summary.FileWriter(checkpoint_path) saver = tf.compat.v1.train.Saver( var_list=all_vars, max_to_keep=max_checkpoints) 
sess.run(tf.compat.v1.global_variables_initializer()) if restore_from == 'latest': ckpt = tf.train.latest_checkpoint(checkpoint_path) if ckpt is None: # Get fresh GPT weights if new run. ckpt = tf.train.latest_checkpoint( os.path.join(model_dir, model_name)) elif restore_from == 'fresh': ckpt = tf.train.latest_checkpoint( os.path.join(model_dir, model_name)) else: ckpt = tf.train.latest_checkpoint(restore_from) print('Loading checkpoint', ckpt) saver.restore(sess, ckpt) print('Loading dataset...') chunks = load_dataset(enc, dataset, combine) data_sampler = Sampler(chunks) # validation code if val_every > 0: if val_dataset: val_chunks = load_dataset(enc, val_dataset, combine) else: val_chunks = chunks print('dataset has', data_sampler.total_size, 'tokens') print('Training...') # validation code if val_every > 0: # Sample from validation set once with fixed seed to make # it deterministic during training as well as across runs. val_data_sampler = Sampler(val_chunks, seed=1) val_batches = [[val_data_sampler.sample(1024) for _ in range(val_batch_size)] for _ in range(val_batch_count)] counter = 1 counter_path = os.path.join(checkpoint_path, 'counter') if os.path.exists(counter_path) and restore_from == 'latest': # Load the step number if we're resuming a run # Add 1 so we don't immediately try to save again with open(counter_path, 'r') as fp: counter = int(fp.read()) + 1 counter_base = counter def save(): maketree(checkpoint_path) print( 'Saving', os.path.join(checkpoint_path, 'model-{}').format(counter-1)) saver.save( sess, os.path.join(checkpoint_path, 'model'), global_step=counter-1) with open(counter_path, 'w') as fp: fp.write(str(counter-1) + '\n') def generate_samples(): context_tokens = data_sampler.sample(1) all_text = [] index = 0 while index < sample_num: out = sess.run( tf_sample, feed_dict={context: batch_size * [context_tokens]}) for i in range(min(sample_num - index, batch_size)): text = enc.decode(out[i]) text = '======== SAMPLE {} ========\n{}\n'.format( index + 1, text) all_text.append(text) index += 1 print(text) maketree(os.path.join(SAMPLE_DIR, run_name)) with open( os.path.join(SAMPLE_DIR, run_name, 'samples-{}').format(counter), 'w') as fp: fp.write('\n'.join(all_text)) # validation code def validation(): print('Calculating validation loss...') losses = [] for batch in tqdm(val_batches): losses.append(sess.run(val_loss, feed_dict={val_context: batch})) v_val_loss = np.mean(losses) v_summary = sess.run(val_loss_summary, feed_dict={val_loss: v_val_loss}) summary_log.add_summary(v_summary, counter) summary_log.flush() print( '[{counter} | {time:2.2f}] validation loss = {loss:2.2f}' .format( counter=counter, time=time.time() - start_time, loss=v_val_loss)) def sample_batch(): return [data_sampler.sample(1024) for _ in range(batch_size)] if overwrite and restore_from == 'latest': for file in files: if file.startswith('model') or file.startswith('events'): os.remove(os.path.join(checkpoint_path, file)) save() avg_loss = (0.0, 0.0) start_time = time.time() if steps: steps = int(steps) try: while True: if steps > 0 and counter == (counter_base + steps): save() return if (counter - 1) % save_every == 0 and counter > 1: save() if (counter - 1) % sample_every == 0 and counter > 1: generate_samples() # validation code if val_every > 0 and (counter % val_every == 0 or counter == 1): validation() if accumulate_gradients > 1: sess.run(opt_reset) for _ in range(accumulate_gradients): sess.run( opt_compute, feed_dict={context: sample_batch()}) (v_loss, v_summary) = sess.run((opt_apply, 
summary_loss)) else: (_, v_loss, v_summary) = sess.run( (opt_apply, loss, summary_loss), feed_dict={context: sample_batch()}) summary_log.add_summary(v_summary, counter) if (counter % print_every == 0) or counter == 1: avg_loss = (avg_loss[0] * 0.99 + v_loss, avg_loss[1] * 0.99 + 1.0) print( '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}' .format( counter=counter, time=time.time() - start_time, loss=v_loss, avg=avg_loss[0] / avg_loss[1])) counter += 1 except KeyboardInterrupt: print('interrupted') save()
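# Usage sketch for the validation-enabled variant above (hedged: file names and
# step counts are hypothetical). With val_every > 0, a held-out loss is computed
# on a fixed, seeded set of batches so it is comparable across steps and runs.
sess = tf.compat.v1.Session()
finetune(sess,
         dataset='train.txt',
         val_dataset='valid.txt',  # held-out text; falls back to the training chunks if omitted
         val_every=200,            # evaluate validation loss every 200 steps
         val_batch_size=2,
         val_batch_count=40,       # 40 fixed batches per evaluation
         steps=2000)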
def opt_compute_gradients(self,
                          loss,
                          var_list=None,
                          gate_gradients=GATE_OP,
                          aggregation_method=None,
                          colocate_gradients_with_ops=False,
                          grad_loss=None):
    """Compute gradients of `loss` for the variables in `var_list`.

    This is the first part of `minimize()`. It returns a list of
    (gradient, variable) pairs where "gradient" is the gradient for
    "variable". Note that "gradient" can be a `Tensor`, an `IndexedSlices`,
    or `None` if there is no gradient for the given variable.

    Args:
      loss: A Tensor containing the value to minimize or a callable taking no
        arguments which returns the value to minimize. When eager execution
        is enabled it must be a callable.
      var_list: Optional list or tuple of `tf.Variable` to update to minimize
        `loss`. Defaults to the list of variables collected in the graph
        under the key `GraphKeys.TRAINABLE_VARIABLES`.
      gate_gradients: How to gate the computation of gradients. Can be
        `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
      aggregation_method: Specifies the method used to combine gradient
        terms. Valid values are defined in the class `AggregationMethod`.
      colocate_gradients_with_ops: If True, try colocating gradients with the
        corresponding op.
      grad_loss: Optional. A `Tensor` holding the gradient computed for
        `loss`.

    Returns:
      A list of (gradient, variable) pairs. Variable is always present, but
      gradient can be `None`.

    Raises:
      TypeError: If `var_list` contains anything else than `Variable`
        objects.
      ValueError: If some arguments are invalid.
      RuntimeError: If called with eager execution enabled and `loss` is
        not callable.

    @compatibility(eager)
    When eager execution is enabled, `gate_gradients`, `aggregation_method`,
    and `colocate_gradients_with_ops` are ignored.
    @end_compatibility
    """
    if callable(loss):
        with backprop.GradientTape() as tape:
            if var_list is not None:
                tape.watch(var_list)
            loss_value = loss()

            # Scale loss if using a "mean" loss reduction and multiple
            # replicas. Have to be careful to call
            # distribute_lib.get_loss_reduction() *after* loss() is
            # evaluated, so we know what loss reduction it uses.
            # TODO(josh11b): Test that we handle weight decay in a
            # reasonable way.
            loss_value = self._scale_loss(loss_value)

        if var_list is None:
            var_list = tape.watched_variables()
        # TODO(jhseu): Figure out why GradientTape's gradients don't require
        # loss to be executed.
        with ops.control_dependencies([loss_value]):
            grads = tape.gradient(loss_value, var_list, grad_loss)
        return list(zip(grads, var_list))

    # Non-callable/Tensor loss case
    if context.executing_eagerly():
        raise RuntimeError("`loss` passed to Optimizer.compute_gradients "
                           "should be a function when eager execution is "
                           "enabled.")

    # Scale loss if using a "mean" loss reduction and multiple replicas.
    loss = self._scale_loss(loss)

    if gate_gradients not in [
            Optimizer.GATE_NONE, Optimizer.GATE_OP, Optimizer.GATE_GRAPH
    ]:
        raise ValueError("gate_gradients must be one of: Optimizer.GATE_NONE, "
                         "Optimizer.GATE_OP, Optimizer.GATE_GRAPH. Not %s" %
                         gate_gradients)
    self._assert_valid_dtypes([loss])
    if grad_loss is not None:
        self._assert_valid_dtypes([grad_loss])
    if var_list is None:
        var_list = (
            variables.trainable_variables() +
            ops.get_collection(ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
    else:
        var_list = nest.flatten(var_list)
    # pylint: disable=protected-access
    var_list += ops.get_collection(ops.GraphKeys._STREAMING_MODEL_PORTS)
    # pylint: enable=protected-access
    processors = [_get_processor(v) for v in var_list]
    if not var_list:
        raise ValueError("No variables to optimize.")
    var_refs = [p.target() for p in processors]
    # Compute gradients with gradient checkpointing (memory_saving_gradients)
    # instead of the stock tf.gradients used by Optimizer.compute_gradients.
    grads = memory_saving_gradients.gradients(
        loss,
        var_refs,
        grad_ys=grad_loss,
        gate_gradients=(gate_gradients == Optimizer.GATE_OP),
        aggregation_method=aggregation_method,
        colocate_gradients_with_ops=colocate_gradients_with_ops)
    if gate_gradients == Optimizer.GATE_GRAPH:
        grads = control_flow_ops.tuple(grads)
    grads_and_vars = list(zip(grads, var_list))
    self._assert_valid_dtypes([
        v for g, v in grads_and_vars
        if g is not None and v.dtype != dtypes.resource
    ])
    return grads_and_vars
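# One plausible way to wire in the override above (an assumption; the patching
# code is not shown in this listing): replace compute_gradients on the TF1
# Optimizer base class so that any optimizer built afterwards computes its
# backward pass through memory_saving_gradients (gradient checkpointing).
tf.compat.v1.train.Optimizer.compute_gradients = opt_compute_gradients
opt = tf.compat.v1.train.AdamOptimizer(learning_rate=1e-4)
# opt.minimize(loss) / opt.compute_gradients(loss) now use the patched version.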
def finetune(sess, dataset, steps=-1, model_name='124M', model_dir='models', combine=50000, batch_size=1, learning_rate=0.0001, accumulate_gradients=5, restore_from='latest', run_name='run1', checkpoint_dir='checkpoint', sample_every=100, sample_length=1023, sample_num=1, save_every=1000, print_every=1, max_checkpoints=1, use_memory_saving_gradients=False, only_train_transformer_layers=False, optimizer='adam', overwrite=False, mixed_precision=False): """Finetunes the model on the given dataset. Adapted from https://github.com/nshepperd/gpt-2/blob/finetuning/train.py. See that file for parameter definitions. """ SAMPLE_DIR = 'samples' checkpoint_path = os.path.join(checkpoint_dir, run_name) def maketree(path): try: os.makedirs(path) except: pass maketree(checkpoint_path) files = [f for f in os.listdir(checkpoint_path)] for file in ['hparams.json', 'encoder.json', 'vocab.bpe']: try: shutil.copyfile(os.path.join(model_dir, model_name, file), os.path.join(checkpoint_path, file)) except FileNotFoundError as fnf_error: print( "You need to download the GPT-2 model first via download_gpt2()" ) raise (fnf_error) enc = encoder.get_encoder(checkpoint_path) hparams = model.default_hparams() with open(os.path.join(checkpoint_path, 'hparams.json')) as f: hparams.override_from_dict(json.load(f)) if sample_length > hparams.n_ctx: raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx) if model_name not in ['117M', '124M']: use_memory_saving_gradients = True only_train_transformer_layers = True accumulate_gradients = 1 context = tf.compat.v1.placeholder(tf.int32, [batch_size, None]) output = model.model(hparams=hparams, X=context) loss = tf.reduce_mean( input_tensor=tf.nn.sparse_softmax_cross_entropy_with_logits( labels=context[:, 1:], logits=output['logits'][:, :-1])) tf_sample = sample.sample_sequence(hparams=hparams, length=sample_length, context=context, batch_size=batch_size, temperature=1.0, top_k=40) all_vars = [ v for v in tf.compat.v1.trainable_variables() if 'model' in v.name ] train_vars = [v for v in all_vars if '/h' in v.name ] if only_train_transformer_layers else all_vars if optimizer == 'adam': opt = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate) elif optimizer == 'sgd': opt = tf.compat.v1.train.GradientDescentOptimizer( learning_rate=learning_rate) elif optimizer == 'sm3': opt = SM3Optimizer(learning_rate=learning_rate, momentum=0.9) elif optimizer == 'adafactor': opt = AdafactorOptimizer(learning_rate=learning_rate) def mp_check_tf_version(): # check TensorFlow >= 1.14 tf_version_list = tf.__version__.split(".") if int(tf_version_list[0]) < 2: return int(tf_version_list[1]) >= 14 def mp_check_tensor_core_gpu_present(): from tensorflow.python.client import device_lib # check Compute Capability >= 7.0 local_device_protos = device_lib.list_local_devices() for line in local_device_protos: if "compute capability" in str(line): compute_capability = float( line.physical_device_desc.split("compute capability: ") [-1]) if compute_capability >= 7.0: return True if mixed_precision and mp_check_tf_version( ) and mp_check_tensor_core_gpu_present(): if isinstance(opt, tf.keras.optimizers.Optimizer) or isinstance( opt, tf.compat.v1.train.Optimizer): opt = tf.compat.v1.train.experimental.enable_mixed_precision_graph_rewrite( opt) else: mixed_precision = False else: mixed_precision = False print('Mixed precision enabled: %s' % mixed_precision) # Fix warning: 
https://github.com/tensorflow/tensorflow/blob/233d3d/tensorflow/python/training/experimental/mixed_precision.py#L355-L357 if sess is None: sess = start_tf_sess() if accumulate_gradients > 1: if use_memory_saving_gradients: exit( "Memory saving gradients are not implemented for gradient accumulation yet." ) opt = AccumulatingOptimizer(opt=opt, var_list=train_vars) opt_reset = opt.reset() opt_compute = opt.compute_gradients(loss) opt_apply = opt.apply_gradients() summary_loss = tf.compat.v1.summary.scalar('loss', opt_apply) else: if use_memory_saving_gradients: opt_grads = memory_saving_gradients.gradients(loss, train_vars) else: opt_grads = tf.gradients(ys=loss, xs=train_vars) opt_grads = list(zip(opt_grads, train_vars)) opt_apply = opt.apply_gradients(opt_grads) summary_loss = tf.compat.v1.summary.scalar('loss', loss) summary_log = tf.compat.v1.summary.FileWriter(checkpoint_path) saver = tf.compat.v1.train.Saver(var_list=all_vars, max_to_keep=max_checkpoints) sess.run(tf.compat.v1.global_variables_initializer()) if restore_from == 'latest': ckpt = tf.train.latest_checkpoint(checkpoint_path) if ckpt is None: # Get fresh GPT weights if new run. ckpt = tf.train.latest_checkpoint( os.path.join(model_dir, model_name)) elif restore_from == 'fresh': ckpt = tf.train.latest_checkpoint(os.path.join(model_dir, model_name)) else: ckpt = tf.train.latest_checkpoint(restore_from) print('Loading checkpoint', ckpt) saver.restore(sess, ckpt) print('Loading dataset...') chunks = load_dataset(enc, dataset, combine) data_sampler = Sampler(chunks) print('dataset has', data_sampler.total_size, 'tokens') print('Training...') counter = 1 counter_path = os.path.join(checkpoint_path, 'counter') if os.path.exists(counter_path) and restore_from == 'latest': # Load the step number if we're resuming a run # Add 1 so we don't immediately try to save again with open(counter_path, 'r') as fp: counter = int(fp.read()) + 1 counter_base = counter def save(): maketree(checkpoint_path) print('Saving', os.path.join(checkpoint_path, 'model-{}').format(counter - 1)) saver.save(sess, os.path.join(checkpoint_path, 'model'), global_step=counter - 1) with open(counter_path, 'w') as fp: fp.write(str(counter - 1) + '\n') def generate_samples(): context_tokens = data_sampler.sample(1) all_text = [] index = 0 while index < sample_num: out = sess.run(tf_sample, feed_dict={context: batch_size * [context_tokens]}) for i in range(min(sample_num - index, batch_size)): text = enc.decode(out[i]) text = '======== SAMPLE {} ========\n{}\n'.format( index + 1, text) all_text.append(text) index += 1 print(text) maketree(os.path.join(SAMPLE_DIR, run_name)) with open( os.path.join(SAMPLE_DIR, run_name, 'samples-{}').format(counter), 'w') as fp: fp.write('\n'.join(all_text)) def sample_batch(): return [data_sampler.sample(1024) for _ in range(batch_size)] if overwrite and restore_from == 'latest': for file in files: if file.startswith('model') or file.startswith('events'): os.remove(os.path.join(checkpoint_path, file)) save() avg_loss = (0.0, 0.0) start_time = time.time() if steps: steps = int(steps) try: while True: if steps > 0 and counter == (counter_base + steps): save() return if (counter - 1) % save_every == 0 and counter > 1: save() if (counter - 1) % sample_every == 0 and counter > 1: generate_samples() if accumulate_gradients > 1: sess.run(opt_reset) for _ in range(accumulate_gradients): sess.run(opt_compute, feed_dict={context: sample_batch()}) (v_loss, v_summary) = sess.run((opt_apply, summary_loss)) else: (_, v_loss, v_summary) = sess.run( 
(opt_apply, loss, summary_loss), feed_dict={context: sample_batch()}) summary_log.add_summary(v_summary, counter) if counter % print_every == 0: avg_loss = (avg_loss[0] * 0.99 + v_loss, avg_loss[1] * 0.99 + 1.0) print( '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}' .format(counter=counter, time=time.time() - start_time, loss=v_loss, avg=avg_loss[0] / avg_loss[1])) counter += 1 except KeyboardInterrupt: print('interrupted') save()
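# Usage sketch for the mixed-precision variant above (hedged: file name and step
# count are hypothetical). mixed_precision only takes effect on TF >= 1.14 with
# a Tensor Core GPU (compute capability >= 7.0); otherwise the flag is reset to
# False and training proceeds in float32.
sess = start_tf_sess()   # helper referenced above when sess is None
finetune(sess,
         dataset='corpus.txt',
         model_name='124M',
         steps=1000,
         optimizer='adam',       # 'adam', 'sgd', 'sm3', or 'adafactor'
         mixed_precision=True)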