def main():
    args = parser.parse_args()
    enc = encoder.get_encoder(args.model_name)
    hparams = model.default_hparams()
    with open(os.path.join('models', args.model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if args.sample_length > hparams.n_ctx:
        raise ValueError(
            "Can't get samples longer than window size: %s" % hparams.n_ctx)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        context = tf.placeholder(tf.int32, [args.batch_size, None])
        output = model.model(hparams=hparams, X=context)
        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=context[:, 1:], logits=output['logits'][:, :-1]))

        tf_sample = sample.sample_sequence(
            hparams=hparams,
            length=args.sample_length,
            context=context,
            batch_size=args.batch_size,
            temperature=1.0,
            top_k=40)

        train_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
        if args.accumulate_gradients > 1:
            opt = AccumulatingOptimizer(
                opt=tf.train.AdamOptimizer(learning_rate=args.learning_rate),
                var_list=train_vars)
            opt_reset = opt.reset()
            opt_compute = opt.compute_gradients(loss)
            opt_apply = opt.apply_gradients()
            summary_loss = tf.summary.scalar('loss', opt_apply)
        else:
            opt_apply = tf.train.AdamOptimizer(
                learning_rate=args.learning_rate).minimize(
                    loss, var_list=train_vars)
            summary_loss = tf.summary.scalar('loss', loss)

        summary_log = tf.summary.FileWriter(
            os.path.join(CHECKPOINT_DIR, args.run_name))

        saver = tf.train.Saver(
            var_list=train_vars,
            max_to_keep=5,
            keep_checkpoint_every_n_hours=2)
        sess.run(tf.global_variables_initializer())

        if args.restore_from == 'latest':
            ckpt = tf.train.latest_checkpoint(
                os.path.join(CHECKPOINT_DIR, args.run_name))
            if ckpt is None:
                # Get fresh GPT weights if new run.
                ckpt = tf.train.latest_checkpoint(
                    os.path.join('models', args.model_name))
        elif args.restore_from == 'fresh':
            ckpt = tf.train.latest_checkpoint(
                os.path.join('models', args.model_name))
        else:
            ckpt = tf.train.latest_checkpoint(args.restore_from)
        print('Loading checkpoint', ckpt)
        saver.restore(sess, ckpt)

        print('Loading dataset...')
        chunks = load_dataset(enc, args.dataset, args.combine)
        data_sampler = Sampler(chunks)
        print('dataset has', data_sampler.total_size, 'tokens')
        print('Training...')

        counter = 1
        counter_path = os.path.join(CHECKPOINT_DIR, args.run_name, 'counter')
        if os.path.exists(counter_path):
            # Load the step number if we're resuming a run.
            # Add 1 so we don't immediately try to save again.
            with open(counter_path, 'r') as fp:
                counter = int(fp.read()) + 1

        def save():
            maketree(os.path.join(CHECKPOINT_DIR, args.run_name))
            print(
                'Saving',
                os.path.join(CHECKPOINT_DIR, args.run_name,
                             'model-{}').format(counter))
            saver.save(
                sess,
                os.path.join(CHECKPOINT_DIR, args.run_name, 'model'),
                global_step=counter)
            with open(counter_path, 'w') as fp:
                fp.write(str(counter) + '\n')
            print('uploading to Google Drive...')
            gdriveUp = os.popen(
                'cp -r /content/gpt-2/checkpoint/ /content/drive/My\\ Drive/'
                + GDRIVE_DIR).read()
            print(gdriveUp)

        def generate_samples():
            context_tokens = data_sampler.sample(1)
            all_text = []
            index = 0
            while index < args.sample_num:
                out = sess.run(
                    tf_sample,
                    feed_dict={context: args.batch_size * [context_tokens]})
                for i in range(min(args.sample_num - index, args.batch_size)):
                    text = enc.decode(out[i])
                    text = '======== SAMPLE {} ========\n{}\n'.format(
                        index + 1, text)
                    all_text.append(text)
                    index += 1
                    print(text)
            maketree(os.path.join(SAMPLE_DIR, args.run_name))
            with open(
                    os.path.join(SAMPLE_DIR, args.run_name,
                                 'samples-{}').format(counter), 'w') as fp:
                fp.write('\n'.join(all_text))

        def sample_batch():
            return [data_sampler.sample(1024) for _ in range(args.batch_size)]

        avg_loss = (0.0, 0.0)
        start_time = time.time()

        try:
            while True:
                if counter % args.save_every == 0:
                    save()
                if counter % args.sample_every == 0:
                    generate_samples()

                if args.accumulate_gradients > 1:
                    sess.run(opt_reset)
                    for _ in range(args.accumulate_gradients):
                        sess.run(
                            opt_compute,
                            feed_dict={context: sample_batch()})
                    (v_loss, v_summary) = sess.run((opt_apply, summary_loss))
                else:
                    (_, v_loss, v_summary) = sess.run(
                        (opt_apply, loss, summary_loss),
                        feed_dict={context: sample_batch()})

                summary_log.add_summary(v_summary, counter)

                avg_loss = (avg_loss[0] * 0.99 + v_loss,
                            avg_loss[1] * 0.99 + 1.0)

                print(
                    '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                    .format(
                        counter=counter,
                        time=time.time() - start_time,
                        loss=v_loss,
                        avg=avg_loss[0] / avg_loss[1]))

                counter += 1
        except KeyboardInterrupt:
            print('interrupted')
            save()
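# NOTE: AccumulatingOptimizer is not defined in this file. The training loops
# above rely on two of its properties: compute_gradients() folds a microbatch's
# gradients into accumulator variables, and apply_gradients() returns the
# averaged loss (which is why tf.summary.scalar('loss', opt_apply) works).
# A minimal sketch consistent with that usage, assuming TF1-style variables;
# the real helper may differ in detail.
class AccumulatingOptimizer(object):
    def __init__(self, opt, var_list):
        self.opt = opt
        self.var_list = var_list
        # One accumulator per trainable variable, plus loss/count scalars.
        self.accum_vars = {tv: tf.Variable(tf.zeros_like(tv.initialized_value()),
                                           trainable=False)
                           for tv in var_list}
        self.total_loss = tf.Variable(tf.zeros(shape=[], dtype=tf.float32))
        self.count_loss = tf.Variable(tf.zeros(shape=[], dtype=tf.float32))

    def reset(self):
        # Zero all accumulators before a new accumulation cycle.
        updates = [tv.assign(tf.zeros_like(tv)) for tv in self.accum_vars.values()]
        updates.append(self.total_loss.assign(tf.zeros(shape=[], dtype=tf.float32)))
        updates.append(self.count_loss.assign(tf.zeros(shape=[], dtype=tf.float32)))
        with tf.control_dependencies(updates):
            return tf.no_op()

    def compute_gradients(self, loss):
        # Add this microbatch's gradients (and its loss) into the accumulators.
        grads = self.opt.compute_gradients(loss, self.var_list)
        updates = [self.accum_vars[v].assign_add(g) for (g, v) in grads]
        updates.append(self.total_loss.assign_add(loss))
        updates.append(self.count_loss.assign_add(1.0))
        with tf.control_dependencies(updates):
            return tf.no_op()

    def apply_gradients(self):
        # Apply the averaged gradients and return the averaged loss.
        grads = [(g / self.count_loss, v) for (v, g) in self.accum_vars.items()]
        with tf.control_dependencies([self.opt.apply_gradients(grads)]):
            return self.total_loss / self.count_loss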
def main():
    args = parser.parse_args()
    enc = encoder.get_encoder(args.model_name)
    hparams = model.default_hparams()
    with open(os.path.join('models', args.model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if args.sample_length > hparams.n_ctx:
        raise ValueError(
            "Can't get samples longer than window size: %s" % hparams.n_ctx)

    if args.model_name == '355M':
        args.memory_saving_gradients = True
        if args.optimizer == 'adam':
            args.only_train_transformer_layers = True

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.graph_options.rewrite_options.layout_optimizer = rewriter_config_pb2.RewriterConfig.OFF
    with tf.Session(config=config) as sess:
        context = tf.placeholder(tf.int32, [args.batch_size, None])
        context_in = randomize(context, hparams, args.noise)
        output = model.model(hparams=hparams, X=context_in)
        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=context[:, 1:], logits=output['logits'][:, :-1]))

        if args.val_every > 0:
            val_context = tf.placeholder(tf.int32, [args.val_batch_size, None])
            val_output = model.model(hparams=hparams, X=val_context)
            val_loss = tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=val_context[:, 1:],
                    logits=val_output['logits'][:, :-1]))
            val_loss_summary = tf.summary.scalar('val_loss', val_loss)

        tf_sample = sample.sample_sequence(
            hparams=hparams,
            length=args.sample_length,
            context=context,
            batch_size=args.batch_size,
            temperature=1.0,
            top_k=args.top_k,
            top_p=args.top_p)

        all_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
        train_vars = [v for v in all_vars if '/h' in v.name
                      ] if args.only_train_transformer_layers else all_vars

        if args.optimizer == 'adam':
            opt = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
        elif args.optimizer == 'sgd':
            opt = tf.train.GradientDescentOptimizer(
                learning_rate=args.learning_rate)
        else:
            exit('Bad optimizer: ' + args.optimizer)

        if args.accumulate_gradients > 1:
            if args.memory_saving_gradients:
                exit("Memory saving gradients are not implemented for gradient accumulation yet.")
            opt = AccumulatingOptimizer(opt=opt, var_list=train_vars)
            opt_reset = opt.reset()
            opt_compute = opt.compute_gradients(loss)
            opt_apply = opt.apply_gradients()
            summary_loss = tf.summary.scalar('loss', opt_apply)
        else:
            if args.memory_saving_gradients:
                opt_grads = memory_saving_gradients.gradients(loss, train_vars)
            else:
                opt_grads = tf.gradients(loss, train_vars)
            opt_grads = list(zip(opt_grads, train_vars))
            opt_apply = opt.apply_gradients(opt_grads)
            summary_loss = tf.summary.scalar('loss', loss)

        summary_lr = tf.summary.scalar('learning_rate', args.learning_rate)
        summaries = tf.summary.merge([summary_lr, summary_loss])

        summary_log = tf.summary.FileWriter(
            os.path.join(CHECKPOINT_DIR, args.run_name))

        saver = tf.train.Saver(
            var_list=all_vars,
            max_to_keep=5,
            keep_checkpoint_every_n_hours=2)
        sess.run(tf.global_variables_initializer())

        if args.restore_from == 'latest':
            ckpt = tf.train.latest_checkpoint(
                os.path.join(CHECKPOINT_DIR, args.run_name))
            if ckpt is None:
                # Get fresh GPT weights if new run.
                ckpt = tf.train.latest_checkpoint(
                    os.path.join('models', args.model_name))
        elif args.restore_from == 'fresh':
            ckpt = tf.train.latest_checkpoint(
                os.path.join('models', args.model_name))
        else:
            ckpt = tf.train.latest_checkpoint(args.restore_from)
        print('Loading checkpoint', ckpt)
        saver.restore(sess, ckpt)

        print('Loading dataset...')
        chunks = load_dataset(enc, args.dataset, args.combine,
                              encoding=args.encoding)
        data_sampler = Sampler(chunks)
        if args.val_every > 0:
            if args.val_dataset:
                val_chunks = load_dataset(enc, args.val_dataset, args.combine,
                                          encoding=args.encoding)
            else:
                val_chunks = chunks
        print('dataset has', data_sampler.total_size, 'tokens')
        print('Training...')

        if args.val_every > 0:
            # Sample from validation set once with fixed seed to make
            # it deterministic during training as well as across runs.
            val_data_sampler = Sampler(val_chunks, seed=1)
            val_batches = [[val_data_sampler.sample(1024)
                            for _ in range(args.val_batch_size)]
                           for _ in range(args.val_batch_count)]

        counter = 1
        counter_path = os.path.join(CHECKPOINT_DIR, args.run_name, 'counter')
        if os.path.exists(counter_path):
            # Load the step number if we're resuming a run.
            # Add 1 so we don't immediately try to save again.
            with open(counter_path, 'r') as fp:
                counter = int(fp.read()) + 1

        def save():
            maketree(os.path.join(CHECKPOINT_DIR, args.run_name))
            print(
                'Saving',
                os.path.join(CHECKPOINT_DIR, args.run_name,
                             'model-{}').format(counter))
            saver.save(
                sess,
                os.path.join(CHECKPOINT_DIR, args.run_name, 'model'),
                global_step=counter)
            with open(counter_path, 'w') as fp:
                fp.write(str(counter) + '\n')

        def generate_samples():
            print('Generating samples...')
            context_tokens = data_sampler.sample(1)
            all_text = []
            index = 0
            while index < args.sample_num:
                out = sess.run(
                    tf_sample,
                    feed_dict={context: args.batch_size * [context_tokens]})
                for i in range(min(args.sample_num - index, args.batch_size)):
                    text = enc.decode(out[i])
                    text = '======== SAMPLE {} ========\n{}\n'.format(
                        index + 1, text)
                    all_text.append(text)
                    index += 1
                    print(text.encode('utf8'))
            maketree(os.path.join(SAMPLE_DIR, args.run_name))
            with open(
                    os.path.join(SAMPLE_DIR, args.run_name,
                                 'samples-{}').format(counter),
                    'w', encoding=args.encoding) as fp:
                fp.write('\n'.join(all_text))

        def validation():
            print('Calculating validation loss...')
            losses = []
            for batch in tqdm.tqdm(val_batches):
                losses.append(
                    sess.run(val_loss, feed_dict={val_context: batch}))
            v_val_loss = np.mean(losses)
            v_summary = sess.run(val_loss_summary,
                                 feed_dict={val_loss: v_val_loss})
            summary_log.add_summary(v_summary, counter)
            summary_log.flush()
            print(
                '[{counter} | {time:2.2f}] validation loss = {loss:2.2f}'
                .format(
                    counter=counter,
                    time=time.time() - start_time,
                    loss=v_val_loss))

        def sample_batch():
            return [data_sampler.sample(1024) for _ in range(args.batch_size)]

        avg_loss = (0.0, 0.0)
        start_time = time.time()

        try:
            while True:
                if counter % args.save_every == 0:
                    save()
                if counter % args.sample_every == 0:
                    generate_samples()
                if args.val_every > 0 and (counter % args.val_every == 0
                                           or counter == 1):
                    validation()

                if args.accumulate_gradients > 1:
                    sess.run(opt_reset)
                    for _ in range(args.accumulate_gradients):
                        sess.run(
                            opt_compute,
                            feed_dict={context: sample_batch()})
                    (v_loss, v_summary) = sess.run((opt_apply, summaries))
                else:
                    (_, v_loss, v_summary) = sess.run(
                        (opt_apply, loss, summaries),
                        feed_dict={context: sample_batch()})

                summary_log.add_summary(v_summary, counter)

                avg_loss = (avg_loss[0] * 0.99 + v_loss,
                            avg_loss[1] * 0.99 + 1.0)

                print(
                    '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                    .format(
                        counter=counter,
                        time=time.time() - start_time,
                        loss=v_loss,
                        avg=avg_loss[0] / avg_loss[1]))

                counter += 1
        except KeyboardInterrupt:
            print('interrupted')
            save()
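# NOTE: randomize() is not defined in this file. The main() above feeds
# randomize(context, hparams, args.noise) to the model while the loss is still
# computed against the clean tokens, i.e. input-noise regularization. A minimal
# sketch, assuming token-level corruption with probability p:
def randomize(context, hparams, p):
    # With probability p, replace each position by a uniformly random token.
    if p > 0:
        mask = tf.random.uniform(shape=tf.shape(context)) < p
        noise = tf.random.uniform(shape=tf.shape(context), minval=0,
                                  maxval=hparams.n_vocab, dtype=tf.int32)
        return tf.where(mask, noise, context)
    else:
        return context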
def main():
    args = parser.parse_args()
    try:
        logdir = os.path.join(CHECKPOINT_DIR, args.run_name)
        with open('logdir.txt', 'w') as z:
            z.write(logdir)
    except Exception:
        pass

    enc = get_encoder(model_name)
    hparams = model.default_hparams()
    with open(os.path.join(model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if args.sample_length > hparams.n_ctx:
        raise ValueError(
            "Can't get samples longer than window size: %s" % hparams.n_ctx)

    if args.model_name == '345M':
        args.memory_saving_gradients = True
        args.only_train_transformer_layers = True

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.graph_options.rewrite_options.layout_optimizer = rewriter_config_pb2.RewriterConfig.OFF
    with tf.Session(config=config) as sess:
        context = tf.placeholder(tf.int32, [args.batch_size, None])
        output = model.model(hparams=hparams, X=context)
        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=context[:, 1:], logits=output['logits'][:, :-1]))

        if args.val_every > 0:
            val_context = tf.placeholder(tf.int32, [args.val_batch_size, None])
            val_output = model.model(hparams=hparams, X=val_context)
            val_loss = tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=val_context[:, 1:],
                    logits=val_output['logits'][:, :-1]))
            val_loss_summary = tf.summary.scalar('val_loss', val_loss)

        tf_sample = sample.sample_sequence(
            hparams=hparams,
            length=args.sample_length,
            context=context,
            batch_size=args.batch_size,
            temperature=1.0,
            top_k=40)

        all_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
        train_vars = [v for v in all_vars if '/h' in v.name
                      ] if args.only_train_transformer_layers else all_vars

        if args.accumulate_gradients > 1:
            if args.memory_saving_gradients:
                exit("Memory saving gradients are not implemented for gradient accumulation yet.")
            opt = AccumulatingOptimizer(
                opt=tf.train.AdamOptimizer(learning_rate=args.learning_rate),
                var_list=train_vars)
            opt_reset = opt.reset()
            opt_compute = opt.compute_gradients(loss)
            opt_apply = opt.apply_gradients()
            summary_loss = tf.summary.scalar('loss', opt_apply)
        else:
            opt = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
            if args.memory_saving_gradients:
                opt_grads = memory_saving_gradients.gradients(loss, train_vars)
            else:
                opt_grads = tf.gradients(loss, train_vars)
            opt_grads = list(zip(opt_grads, train_vars))
            opt_apply = opt.apply_gradients(opt_grads)
            summary_loss = tf.summary.scalar('loss', loss)

        summary_log = tf.summary.FileWriter(
            os.path.join(CHECKPOINT_DIR, args.run_name))

        saver = tf.train.Saver(
            var_list=all_vars,
            max_to_keep=5,
            keep_checkpoint_every_n_hours=2)
        sess.run(tf.global_variables_initializer())

        if args.restore_from == 'latest':
            ckpt = tf.train.latest_checkpoint(
                os.path.join(CHECKPOINT_DIR, args.run_name))
            if ckpt is None:
                # Get fresh GPT weights if new run.
                ckpt = tf.train.latest_checkpoint(os.path.join(model_name))
        elif args.restore_from == 'fresh':
            ckpt = tf.train.latest_checkpoint(os.path.join(model_name))
        else:
            ckpt = tf.train.latest_checkpoint(args.restore_from)
        print('Loading checkpoint', ckpt)
        saver.restore(sess, ckpt)

        #print('Loading dataset...')
        #chunks = load_dataset(enc, args.dataset, args.combine)
        #data_sampler = Sampler(chunks)
        print('Loading train dataset...')
        from_name, ques_name, to_name = name_parts(args.dataset)
        trn_chunks_from = load_dataset(
            enc, from_name, args.combine)  #if args.dataset else chunks
        #trn_chunks_ques = load_dataset(enc, ques_name, args.combine) if args.dataset else chunks
        trn_chunks_to = load_dataset(
            enc, to_name, args.combine)  #if args.dataset else chunks

        skip_delimeter = True
        char = '\t'
        trn_data_sampler_from = SamplerVal(trn_chunks_from, enc, char=char,
                                           skip_delimeter=skip_delimeter)
        #trn_data_sampler_ques = SamplerVal(trn_chunks_ques, enc, char=char, skip_delimeter=skip_delimeter)
        trn_data_sampler_to = SamplerVal(trn_chunks_to, enc, char=char,
                                         skip_delimeter=skip_delimeter)

        len_v = 0
        data_sampler = []
        for i in range(trn_data_sampler_from.total_size):
            v = (
                #enc.encode('\nQ: ') +
                trn_data_sampler_from.get(i) +
                #enc.encode('. \nA: ') +
                trn_data_sampler_to.get(i)  # +
                #enc.encode('. ')
            )
            v = v[:HIDDEN_SIZE - 1]
            len_v += len(v)
            # Collect each (from + to) pair as one training example.
            data_sampler.append(v)

        if len_v < HIDDEN_SIZE:
            # The dataset is too small for one window; duplicate it until
            # it covers the window, then wrap it in a Sampler.
            mult = HIDDEN_SIZE // len_v + 1
            for i in range(mult):
                x = data_sampler[:]
                data_sampler.extend(x)
            data_sampler = Sampler([np.array(data_sampler)])
        #if not args.train_special and len_v >= HIDDEN_SIZE:
        #    data_sampler = Sampler([np.array(data_sampler)])

        if args.val_every > 0 and False:
            val_chunks = load_dataset(
                enc, args.val_dataset, args.combine) if args.val_dataset else chunks
        if not isinstance(data_sampler, list):
            print('dataset has', data_sampler.total_size, 'tokens')
        print('Training...')

        if args.val_every > 0:
            # Sample from validation set once with fixed seed to make
            # it deterministic during training as well as across runs.
            val_data_sampler = Sampler(val_chunks, seed=1)
            val_batches = [[val_data_sampler.sample(1024)
                            for _ in range(args.val_batch_size)]
                           for _ in range(args.val_batch_count)]

        counter = 1
        counter_path = os.path.join(CHECKPOINT_DIR, args.run_name, 'counter')
        if os.path.exists(counter_path):
            # Load the step number if we're resuming a run.
            # Add 1 so we don't immediately try to save again.
            with open(counter_path, 'r') as fp:
                counter = int(fp.read()) + 1

        def save():
            maketree(os.path.join(CHECKPOINT_DIR, args.run_name))
            print(
                'Saving',
                os.path.join(CHECKPOINT_DIR, args.run_name,
                             'model-{}').format(counter))
            saver.save(
                sess,
                os.path.join(CHECKPOINT_DIR, args.run_name, 'model'),
                global_step=counter)
            with open(counter_path, 'w') as fp:
                fp.write(str(counter) + '\n')
            #print(model_name, 'mn')
            # Ship the tokenizer files along with the checkpoint.
            GPT2_DIR_X = model_name
            cd = CHECKPOINT_DIR + "/" + args.run_name
            if not os.path.isfile(cd + '/' + 'encoder.json'):
                os.system("cp " + GPT2_DIR_X + '/' + 'encoder.json ' + cd + '/.')
                os.system('cp ' + GPT2_DIR_X + "/" + 'vocab.bpe ' + cd + '/.')

        def generate_samples():
            print('Generating samples...')
            #context_tokens = data_sampler.sample(1)
            #context_tokens = data_sampler[0]
            context_tokens = trn_data_sampler_from.get(
                random.randint(0, trn_data_sampler_from.total_size - 1))
            #print(enc.decode(context_tokens), len(context_tokens))
            #print(args.batch_size * [context_tokens])
            all_text = []
            index = 0
            while index < args.sample_num:
                out = sess.run(
                    tf_sample,
                    feed_dict={context: args.batch_size * [context_tokens]})
                for i in range(min(args.sample_num - index, args.batch_size)):
                    text = enc.decode(out[i])
                    text = '======== SAMPLE {} ========\n{}\n'.format(
                        index + 1, text)
                    all_text.append(text)
                    index += 1
                    print(text)
            maketree(os.path.join(SAMPLE_DIR, args.run_name))
            with open(
                    os.path.join(SAMPLE_DIR, args.run_name,
                                 'samples-{}').format(counter), 'w') as fp:
                fp.write('\n'.join(all_text))

        def validation():
            print('Calculating validation loss...')
            losses = []
            for batch in tqdm.tqdm(val_batches):
                losses.append(
                    sess.run(val_loss, feed_dict={val_context: batch}))
            v_val_loss = np.mean(losses)
            v_summary = sess.run(val_loss_summary,
                                 feed_dict={val_loss: v_val_loss})
            summary_log.add_summary(v_summary, counter)
            summary_log.flush()
            print(
                '[{counter} | {time:2.2f}] validation loss = {loss:2.2f}'
                .format(
                    counter=counter,
                    time=time.time() - start_time,
                    loss=v_val_loss))

        def sample_batch():
            #z = [data_sampler.sample(1024) for _ in range(args.batch_size)]
            #print(len(data_sampler))
            #print(len(data_sampler[0]))
            # Pick one example at random; data_sampler is a plain list here.
            z = [data_sampler[random.randint(0, len(data_sampler) - 1)]]
            #print(enc.decode(z[0]))
            #print(z[1],'\n1' ,z[2],'\n2' ,z[3] ,len(data_sampler[0]))
            #exit()
            return z

        avg_loss = (0.0, 0.0)
        start_time = time.time()

        try:
            while counter != args.stop_after:
                if counter % args.save_every == 0:
                    save()
                if counter % args.sample_every == 0:
                    generate_samples()
                if args.val_every > 0 and (counter % args.val_every == 0
                                           or counter == 1):
                    validation()

                if args.accumulate_gradients > 1:
                    sess.run(opt_reset)
                    for _ in range(args.accumulate_gradients):
                        sess.run(
                            opt_compute,
                            feed_dict={context: sample_batch()})
                    (v_loss, v_summary) = sess.run((opt_apply, summary_loss))
                else:
                    (_, v_loss, v_summary) = sess.run(
                        (opt_apply, loss, summary_loss),
                        feed_dict={context: sample_batch()})

                summary_log.add_summary(v_summary, counter)

                avg_loss = (avg_loss[0] * 0.99 + v_loss,
                            avg_loss[1] * 0.99 + 1.0)

                print(
                    '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                    .format(
                        counter=counter,
                        time=time.time() - start_time,
                        loss=v_loss,
                        avg=avg_loss[0] / avg_loss[1]))

                counter += 1
        except KeyboardInterrupt:
            print('\ninterrupted')
        finally:
            save()
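# NOTE: name_parts() is not defined in this file. The main() above expects it
# to map one dataset name to three parallel files (source, question, target).
# A plausible sketch, assuming fixed filename suffixes (the suffixes are
# hypothetical):
def name_parts(dataset):
    # e.g. 'data/train' -> ('data/train.from', 'data/train.ques', 'data/train.to')
    return dataset + '.from', dataset + '.ques', dataset + '.to'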
def main():
    args = parser.parse_args()
    enc = encoder.get_encoder(args.model_name)
    hparams = model.default_hparams()
    hparams.res_dropout = args.dropout
    hparams.attn_dropout = args.dropout
    epsilon = -1e10
    if args.dtype == 'float32':
        hparams.dtype = tf.float32
    elif args.dtype == 'float16':
        hparams.dtype = tf.float16
        epsilon = -65500
    elif args.dtype == 'bfloat16':
        hparams.dtype = tf.bfloat16
        epsilon = -65500
    else:
        print('Unknown dtype', args.dtype)
    if args.float16:
        hparams.dtype = tf.bfloat16
        epsilon = -65500

    with open(os.path.join('models', args.model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if args.n_ctx >= 0:
        hparams.n_ctx = args.n_ctx
    if args.n_embd >= 0:
        hparams.n_embd = args.n_embd
    if args.n_head >= 0:
        hparams.n_head = args.n_head
    if args.n_layer >= 0:
        hparams.n_layer = args.n_layer

    if args.sample_length < 0:
        args.sample_length = hparams.n_ctx - 1
    if args.sample_length > hparams.n_ctx:
        raise ValueError(
            "Can't get samples longer than window size: %s" % hparams.n_ctx)
    if args.sample_ctx < 0:
        args.sample_ctx = hparams.n_ctx

    if args.model_name == '345M':
        args.memory_saving_gradients = True
        if args.optimizer == 'adam':
            args.only_train_transformer_layers = True

    config = tf.ConfigProto()
    if args.allow_growth:
        config.gpu_options.allow_growth = True
    if args.disable_layout_optimizer:
        config.graph_options.rewrite_options.layout_optimizer = rewriter_config_pb2.RewriterConfig.OFF
    with tflex.Session(config=config, init_tpu=args.init_tpu) as sess:
        context = tf.placeholder(tf.int32, [args.batch_size, None])
        context_in = randomize(context, hparams, args.noise)
        output = model.model(hparams=hparams, X=context_in)
        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=context[:, 1:], logits=output['logits'][:, :-1]))

        if args.val_every > 0:
            val_context = tf.placeholder(tf.int32, [args.val_batch_size, None])
            val_output = model.model(hparams=hparams, X=val_context)
            val_loss = tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=val_context[:, 1:],
                    logits=val_output['logits'][:, :-1]))
            val_loss_summary = tf.summary.scalar('val_loss', val_loss)

        tf_sample = sample.sample_sequence(
            hparams=hparams,
            length=args.sample_length,
            context=context,
            batch_size=args.batch_size,
            temperature=1.0,
            top_k=args.top_k,
            top_p=args.top_p,
            epsilon=epsilon)

        all_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
        train_vars = [v for v in all_vars if '/h' in v.name] if args.only_train_transformer_layers else all_vars

        parameter_count = sum([np.prod(v.shape.as_list()) for v in train_vars])
        print("This model is using %d parameters (%.2fM)"
              % (parameter_count, parameter_count / (1024.0 * 1024.0)))

        with tf.variable_scope(tf.get_variable_scope().name, reuse=tf.AUTO_REUSE):
            global_step = tflex.get_variable('global_step') or tf.get_variable(
                'global_step', shape=(), dtype=tf.int32, trainable=False)
            current_step = args.learning_rate_initial_step
            global_step.load(current_step, session=sess)
            if args.learning_rate_cos:
                lr = tflex_sgdr.sgdr_decay_with_warmup(
                    args.learning_rate,
                    global_step,
                    warmup_steps=args.learning_rate_warmup,
                    initial_period_steps=args.learning_rate_period,
                    learning_rate_min=args.learning_rate_min)
            else:
                lr = tflex.get_variable('learn_rate') or tf.get_variable(
                    'learn_rate', shape=(), dtype=tf.float32, trainable=False)
                lr.load(args.learning_rate, session=sess)

        def update_lr(rate=None, step=None):
            if not args.learning_rate_cos:
                if step is None:
                    step = global_step.eval(session=sess)
                if rate is None:
                    rate = args.learning_rate
                if callable(rate):
                    rate = rate(step)
                lr.load(rate, session=sess)
            return lr.eval(session=sess)

        @tflex.register_command
        def set_learning_rate():
            print("Current learn rate: %0.8f" % update_lr())
            print("New learn rate?")
            rate = input('')
            if not rate:
                print("Empty input; not changing anything.")
            else:
                try:
                    rate = float(rate)
                except ValueError:
                    print("Invalid input; must be a float")
                    return
                print("Setting learn rate to %0.8f" % rate)
                args.learning_rate = rate

        if args.optimizer == 'adam':
            opt = tf.train.AdamOptimizer(learning_rate=lr)
        elif args.optimizer == 'sgd':
            opt = tf.train.GradientDescentOptimizer(learning_rate=lr)
        elif args.optimizer == 'ada':
            import tensor2tensor.utils.optimize
            from tensor2tensor.utils import hparam
            import tensor2tensor.models.research
            from tensor2tensor.utils import registry
            ada_hparams = registry.hparams('afx_mimic_adam')
            ada_hparams.optimizer_adafactor_beta1 = 0.0
            ada_hparams.optimizer_adafactor_factored = True
            opt = tensor2tensor.utils.optimize.adafactor(
                learning_rate=lr, hparams=ada_hparams)
        else:
            exit('Bad optimizer: ' + args.optimizer)

        #if tpu_addr:
        #    # https://pulsejet.github.io/blog/posts/tpu-without-estimator/
        #    from tensorflow.contrib.tpu.python.tpu import tpu_function
        #    tpu_function.get_tpu_context().set_number_of_shards(8)
        #    opt = tf.contrib.tpu.CrossShardOptimizer(opt)

        if args.accumulate_gradients > 1:
            if args.memory_saving_gradients:
                exit("Memory saving gradients are not implemented for gradient accumulation yet.")
            opt = AccumulatingOptimizer(opt=opt, var_list=train_vars)
            opt_reset = opt.reset()
            opt_compute = opt.compute_gradients(loss)
            opt_apply = opt.apply_gradients()
            summary_loss = tf.summary.scalar('loss', opt_apply)
        else:
            if args.memory_saving_gradients:
                opt_grads = memory_saving_gradients.gradients(loss, train_vars)
            else:
                opt_grads = tf.gradients(loss, train_vars)
            opt_grads = list(zip(opt_grads, train_vars))
            opt_apply = opt.apply_gradients(opt_grads)
            summary_loss = tf.summary.scalar('loss', loss)

        summary_lr = tf.summary.scalar('learning_rate', lr)
        summaries = tf.summary.merge([summary_lr, summary_loss])

        summary_log = tf.summary.FileWriter(
            os.path.join(CHECKPOINT_DIR, args.run_name))
        if args.save_graph:
            summary_log.add_graph(tf.get_default_graph())

        saver = tflex.Saver(
            var_list=all_vars,
            max_to_keep=args.max_to_keep,
            keep_checkpoint_every_n_hours=100000,
            reshape=args.truncate_weights)
        sess.run(tf.global_variables_initializer())

        if args.restore_from == 'latest':
            ckpt = tflex.latest_checkpoint(
                os.path.join(CHECKPOINT_DIR, args.run_name))
            if ckpt is None:
                # Get fresh GPT weights if new run.
                ckpt = tflex.latest_checkpoint(
                    os.path.join('models', args.model_name))
        elif args.restore_from == 'fresh':
            ckpt = tflex.latest_checkpoint(
                os.path.join('models', args.model_name))
        else:
            ckpt = tflex.latest_checkpoint(args.restore_from)
        print('Loading snapshot %s...' % ckpt)
        t0 = time.time()
        if not args.fresh_model:
            saver.restore(sess, ckpt)
        t1 = time.time()
        print('Loaded in %f seconds' % (t1 - t0))

        def make_sampler(dataset, enc, seed, combine):
            if os.path.isdir(dataset) or dataset.endswith('.npz'):
                chunks = load_dataset(enc, dataset, combine)
                data_sampler = Sampler(chunks, seed=seed)
                print('dataset has', data_sampler.total_size, 'tokens',
                      len(chunks), 'chunks')
            else:
                data_sampler = TextSampler(dataset, enc, seed=seed)
            return data_sampler

        print('Loading dataset...')
        seed = None if args.seed < 0 else args.seed
        data_sampler = make_sampler(dataset=args.dataset, enc=enc, seed=seed,
                                    combine=args.combine)
        if args.val_every > 0:
            # Sample from validation set once with fixed seed to make
            # it deterministic during training as well as across runs.
            val_dataset = args.val_dataset if args.val_dataset else args.dataset
            val_data_sampler = make_sampler(dataset=val_dataset, enc=enc,
                                            seed=1, combine=args.combine)
            val_batches = [[val_data_sampler.sample(hparams.n_ctx)
                            for _ in range(args.val_batch_size)]
                           for _ in range(args.val_batch_count)]

        print('Training...')
        counter = 1
        counter_path = os.path.join(CHECKPOINT_DIR, args.run_name, 'counter')
        if os.path.exists(counter_path):
            # Load the step number if we're resuming a run.
            # Add 1 so we don't immediately try to save again.
            with open(counter_path, 'r') as fp:
                counter = int(fp.read()) + 1

        @tflex.register_command
        def get_tarfile_name(checkpoint_folder):
            """Converts a folder path into a filename for a .tar archive"""
            tarfile_name = checkpoint_folder.replace(os.path.sep, '_') + '.tar'
            return tarfile_name

        def copy_checkpoint_to_gdrive(run_name='run1', copy_folder=False):
            """Copies the checkpoint folder to a mounted Google Drive."""
            #is_mounted()
            checkpoint_folder = os.path.join('checkpoint', run_name)
            if copy_folder:
                shutil.copytree(checkpoint_folder,
                                "/content/drive/My Drive/" + checkpoint_folder)
            else:
                file_path = get_tarfile_name(checkpoint_folder)
                # Reference: https://stackoverflow.com/a/17081026
                with tarfile.open(file_path, 'w') as tar:
                    tar.add(checkpoint_folder)
                shutil.copyfile(file_path,
                                "/content/drive/My Drive/" + file_path)

        @tflex.register_command
        def save():
            maketree(os.path.join(CHECKPOINT_DIR, args.run_name))
            print(
                'Saving',
                os.path.join(CHECKPOINT_DIR, args.run_name,
                             'model-{}').format(counter))
            t0 = time.time()
            saver.save(
                sess,
                os.path.join(CHECKPOINT_DIR, args.run_name, 'model'),
                global_step=counter)
            t1 = time.time()
            print('Saved in %f seconds' % (t1 - t0))
            with open(counter_path, 'w') as fp:
                fp.write(str(counter) + '\n')
            #copy_checkpoint_to_gdrive()

        @tflex.register_command
        def generate_samples():
            print('Generating samples...')
            context_tokens = data_sampler.sample(1)
            all_text = []
            index = 0
            while index < args.sample_num:
                out = sess.run(
                    tf_sample,
                    feed_dict={context: args.batch_size * [context_tokens]})
                for i in range(min(args.sample_num - index, args.batch_size)):
                    text = enc.decode(out[i])
                    text = '======== SAMPLE {} ========\n{}\n'.format(
                        index + 1, text)
                    print(text)
                    all_text.append(text)
                    index += 1
            maketree(os.path.join(SAMPLE_DIR, args.run_name))
            with open(
                    os.path.join(SAMPLE_DIR, args.run_name,
                                 'samples-{}').format(counter), 'w') as fp:
                fp.write('\n'.join(all_text))

        @tflex.register_command
        def validation():
            if args.val_every <= 0:
                return
            print('Calculating validation loss...')
            losses = []
            for batch in tqdm.tqdm(val_batches):
                losses.append(
                    sess.run(val_loss, feed_dict={val_context: batch}))
            v_val_loss = np.mean(losses)
            v_summary = sess.run(val_loss_summary,
                                 feed_dict={val_loss: v_val_loss})
            summary_log.add_summary(v_summary, counter)
            summary_log.flush()
            print(
                '{stamp} [{counter} | {time:2.4f}] validation loss = {loss:2.4f}'
                .format(
                    stamp=timestamp(),
                    counter=counter,
                    time=time.time() - start_time,
                    loss=v_val_loss))

        start_time = time.time()

        def elapsed():
            return time.time() - start_time

        def say(msg):
            print('{stamp} [{counter} | {time:2.4f}] {msg}'.format(
                counter=counter, time=elapsed(), msg=msg, stamp=timestamp()))

        def sample_batch():
            #return [data_sampler.sample(args.sample_ctx) for _ in range(args.batch_size)]
            #say('Sampling batch...')
            r = []
            times = []
            for _ in range(args.batch_size):
                start = time.time()
                sample = data_sampler.sample(args.sample_ctx)
                end = time.time()
                # Per-sample timing.
                duration = end - start
                r += [sample]
                times += [duration]
            total = sum(times)
            avg = total / len(times)
            #say('Sampled %d batches in %.4f seconds (avg per batch: %.4f)' % (args.batch_size, total, avg))
            return r

        prev_time = time.time()
        avg_loss = (0.0, 0.0)

        if args.debug_before_training:
            import pdb
            pdb.set_trace()

        last_saved_time = elapsed()
        while True:
            try:
                now = elapsed()
                if args.save_time > 0 and (((now - last_saved_time) / 60.0) >= args.save_time):
                    save()
                    last_saved_time = now
                elif args.save_every > 0 and (counter % args.save_every == 0):
                    save()
                if counter % args.sample_every == 0:
                    generate_samples()
                if args.val_every > 0 and (counter % args.val_every == 0
                                           or counter == 1):
                    validation()

                v_rate = update_lr()

                if args.accumulate_gradients > 1:
                    #say('Running opt_reset...')
                    sess.run(opt_reset)
                    for _ in range(args.accumulate_gradients):
                        batch = sample_batch()
                        say('Running opt_compute...')
                        sess.run(opt_compute, feed_dict={context: batch})
                    say('Running opt_apply...')
                    (v_loss, v_summary) = sess.run((opt_apply, summaries))
                else:
                    batch = sample_batch()
                    say('Running opt_apply...')
                    (_, v_loss, v_summary) = sess.run(
                        (opt_apply, loss, summaries),
                        feed_dict={context: batch})

                if args.float16:
                    v_loss = tf.to_float(v_loss).eval()

                summary_log.add_summary(v_summary, counter)
                summary_log.flush()

                avg_loss = (avg_loss[0] * 0.99 + v_loss,
                            avg_loss[1] * 0.99 + 1.0)

                now = time.time()
                print('{stamp} [{counter} | {time:2.4f} | {delta:2.2f}s | {ops:2.6f}tokens/s] loss={loss:2.4f} avg={avg:2.4f} rate={rate:0.7f} step={step}'
                      .format(
                          stamp=timestamp(),
                          counter=counter,
                          time=now - start_time,
                          delta=now - prev_time,
                          ops=args.sample_ctx * args.batch_size / (now - prev_time),
                          rate=v_rate,
                          loss=v_loss,
                          avg=avg_loss[0] / avg_loss[1],
                          step=current_step,
                      ))

                counter += 1
                current_step += 1
                global_step.load(current_step, session=sess)

                tflex.check_commands_with_args(
                    session=sess,
                    stamp=timestamp(),
                    counter=counter,
                    time=now - start_time,
                    delta=now - prev_time,
                    ops=args.batch_size / (now - prev_time),
                    rate=v_rate,
                    loss=v_loss,
                    avg=avg_loss[0] / avg_loss[1],
                    avg_loss=avg_loss,
                    step=current_step,
                    train_vars=train_vars,
                    all_vars=all_vars,
                    args=args,
                    data_sampler=data_sampler,
                    ckpt=ckpt,
                    saver=saver,
                )
                if tflex.should_quit():
                    break
                prev_time = now

                if args.debug_print_all_vars:
                    print('all variables:')
                    print('name/shape/parameter_count')
                    param_count = 0
                    for x in tf.all_variables():
                        shape = x.shape.as_list()
                        count = np.prod(shape)
                        print(x.name, shape, count)
                        param_count += count
                    print('Total parameters:', param_count)
                    args.debug_print_all_vars = False

                if args.debug_print_trainable_vars:
                    print('trainable variables:')
                    print('name/shape/parameter_count')
                    param_count = 0
                    for x in tf.trainable_variables():
                        shape = x.shape.as_list()
                        count = np.prod(shape)
                        print(x.name, shape, count)
                        param_count += count
                    print('Total parameters:', param_count)
                    args.debug_print_trainable_vars = False
            except KeyboardInterrupt:
                print('interrupted')
                if args.save_on_ctrlc:
                    save()
                if args.debug_on_ctrlc:
                    import pdb
                    pdb.set_trace()
                else:
                    break
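# NOTE: maketree() and timestamp() are used throughout but not defined in this
# file. Minimal sketches matching their usage; the exact timestamp format is
# an assumption.
import datetime

def maketree(path):
    # mkdir -p semantics: succeed silently if the directory already exists.
    try:
        os.makedirs(path)
    except FileExistsError:
        pass

def timestamp():
    # Wall-clock stamp for the log lines above.
    return datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')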
def train(dataset,
          model_in_path,
          model_out_path,
          model_name='117M',
          steps=1000,
          combine=50000,
          batch_size=1,
          learning_rate=0.00002,
          accumulate_gradients=1,
          use_memory_saving_gradients=False,
          only_train_transformer_layers=False,
          optimizer='adam',
          noise=0.0,
          top_k=40,
          top_p=0.0,
          restore_from='latest',
          sample_every=100,
          sample_length=1023,
          sample_num=1,
          save_every=1000,
          val_dataset=None):
    # Reset the TF computation graph.
    tf.reset_default_graph()

    # Get the checkpoint and sample directories.
    #checkpoint_dir = os.path.dirname(model_path)
    #sample_dir = checkpoint_dir
    #run_name = os.path.basename(model_path)

    # Load the encoder.
    enc = get_encoder(model_in_path)
    hparams = model.default_hparams()
    with open(os.path.join(model_in_path, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if sample_length > hparams.n_ctx:
        raise ValueError(
            "Can't get samples longer than window size: %s" % hparams.n_ctx)

    # Size matters.
    if model_name == '345M':
        use_memory_saving_gradients = True
        if optimizer == 'adam':
            only_train_transformer_layers = True

    # Configure TF.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.graph_options.rewrite_options.layout_optimizer = rewriter_config_pb2.RewriterConfig.OFF

    # Start the session.
    with tf.Session(config=config) as sess:
        context = tf.placeholder(tf.int32, [batch_size, None])
        context_in = randomize(context, hparams, noise)
        output = model.model(hparams=hparams, X=context_in)
        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=context[:, 1:], logits=output['logits'][:, :-1]))

        tf_sample = sample.sample_sequence(
            hparams=hparams,
            length=sample_length,
            context=context,
            batch_size=batch_size,
            temperature=1.0,
            top_k=top_k,
            top_p=top_p)

        all_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
        train_vars = [v for v in all_vars if '/h' in v.name
                      ] if only_train_transformer_layers else all_vars

        if optimizer == 'adam':
            opt = tf.train.AdamOptimizer(learning_rate=learning_rate)
        elif optimizer == 'sgd':
            opt = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
        else:
            exit('Bad optimizer: ' + optimizer)

        if accumulate_gradients > 1:
            if use_memory_saving_gradients:
                exit("Memory saving gradients are not implemented for gradient accumulation yet.")
            opt = AccumulatingOptimizer(opt=opt, var_list=train_vars)
            opt_reset = opt.reset()
            opt_compute = opt.compute_gradients(loss)
            opt_apply = opt.apply_gradients()
            summary_loss = tf.summary.scalar('loss', opt_apply)
        else:
            if use_memory_saving_gradients:
                # memory_saving_gradients here is the gradient-checkpointing module.
                opt_grads = memory_saving_gradients.gradients(loss, train_vars)
            else:
                opt_grads = tf.gradients(loss, train_vars)
            opt_grads = list(zip(opt_grads, train_vars))
            opt_apply = opt.apply_gradients(opt_grads)
            summary_loss = tf.summary.scalar('loss', loss)

        summary_lr = tf.summary.scalar('learning_rate', learning_rate)
        summaries = tf.summary.merge([summary_lr, summary_loss])

        summary_log = tf.summary.FileWriter(
            #os.path.join(checkpoint_dir, run_name)
            model_out_path)

        saver = tf.train.Saver(var_list=all_vars, max_to_keep=1)
        sess.run(tf.global_variables_initializer())

        if restore_from == 'latest':
            ckpt = tf.train.latest_checkpoint(
                #os.path.join(checkpoint_dir, run_name)
                model_in_path)
            if ckpt is None:
                # Get fresh GPT weights if new run.
                ckpt = tf.train.latest_checkpoint(
                    model_in_path)  #os.path.join('models', model_name))
        elif restore_from == 'fresh':
            ckpt = tf.train.latest_checkpoint(
                model_in_path)  #os.path.join('models', model_name))
        else:
            ckpt = tf.train.latest_checkpoint(restore_from)
        print('Loading checkpoint', ckpt)
        saver.restore(sess, ckpt)

        print('Loading dataset...')
        chunks = load_dataset(enc, dataset, combine)
        data_sampler = Sampler(chunks)
        print('dataset has', data_sampler.total_size, 'tokens')
        print('Training...')

        counter = 1
        counter_path = os.path.join(
            model_in_path, 'counter')  #os.path.join(checkpoint_dir, run_name, 'counter')
        if restore_from == 'latest' and os.path.exists(counter_path):
            # Load the step number if we're resuming a run.
            # Add 1 so we don't immediately try to save again.
            with open(counter_path, 'r') as fp:
                counter = int(fp.read()) + 1

        def save():
            #maketree(os.path.join(checkpoint_dir, run_name))
            maketree(model_out_path)
            print(
                'Saving',
                #os.path.join(checkpoint_dir, run_name, 'model-{}').format(counter)
                os.path.join(model_out_path, 'model-{}').format(counter))
            saver.save(
                sess,
                #os.path.join(checkpoint_dir, run_name, 'model'),
                os.path.join(model_out_path, 'model'),
                global_step=counter)
            with open(os.path.join(model_out_path, 'counter'), 'w') as fp:
                fp.write(str(counter) + '\n')

        def generate_samples():
            print('Generating samples...')
            context_tokens = data_sampler.sample(1)
            all_text = []
            index = 0
            while index < sample_num:
                out = sess.run(
                    tf_sample,
                    feed_dict={context: batch_size * [context_tokens]})
                for i in range(min(sample_num - index, batch_size)):
                    text = enc.decode(out[i])
                    text = '======== SAMPLE {} ========\n{}\n'.format(
                        index + 1, text)
                    all_text.append(text)
                    index += 1
                    print(text)
            #maketree(os.path.join(sample_dir, run_name))
            maketree(model_out_path)
            with open(
                    os.path.join(model_out_path,
                                 'samples-{}').format(counter), 'w') as fp:
                fp.write('\n'.join(all_text))

        def sample_batch():
            return [data_sampler.sample(1024) for _ in range(batch_size)]

        avg_loss = (0.0, 0.0)
        start_time = time.time()
        stop = steps + counter

        try:
            while counter < stop + 1:
                if counter % save_every == 0:
                    save()
                #if counter % sample_every == 0:
                #    generate_samples()

                if accumulate_gradients > 1:
                    sess.run(opt_reset)
                    for _ in range(accumulate_gradients):
                        sess.run(
                            opt_compute,
                            feed_dict={context: sample_batch()})
                    (v_loss, v_summary) = sess.run((opt_apply, summaries))
                else:
                    (_, v_loss, v_summary) = sess.run(
                        (opt_apply, loss, summaries),
                        feed_dict={context: sample_batch()})

                summary_log.add_summary(v_summary, counter)

                avg_loss = (avg_loss[0] * 0.99 + v_loss,
                            avg_loss[1] * 0.99 + 1.0)

                print(
                    '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                    .format(
                        counter=counter,
                        time=time.time() - start_time,
                        loss=v_loss,
                        avg=avg_loss[0] / avg_loss[1]))

                counter += 1
            print('done!')
            save()
        except KeyboardInterrupt:
            print('interrupted')
            save()
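# Example use of the train() wrapper above (paths are hypothetical): fine-tune
# the weights found in model_in_path on a text file for a fixed number of
# steps, writing checkpoints, the step counter, and samples to model_out_path.
def example_finetune_run():
    train(dataset='data/corpus.txt',
          model_in_path='models/117M',
          model_out_path='checkpoint/run1',
          model_name='117M',
          steps=1000,
          batch_size=1,
          learning_rate=2e-5)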
def main(): enc = encoder.get_encoder(args.model_name) hparams = model.default_hparams() hparams.batch_size=args.batch_size hparams.seq_len=args.seq_len ##data_path args.train_data_path=args.data_dir+args.dataset+'/train.txt' args.eval_data_path=args.data_dir+args.dataset+'/dev.txt' args.test_data_path=args.data_dir+args.dataset+'/test.txt' args.eval_data_path=args.test_data_path ###Test mode only! args.gpt_save_path=args.gpt_save_dir+args.dataset+'/' args.dis_save_path=args.dis_save_dir+args.dataset+'/' args.gpt_sample_dir2=args.gpt_sample_dir+args.dataset+'/' args.dis_sample_dir2=args.dis_sample_dir+args.dataset+'/' args.log_path=args.log_dir+args.dataset+'/' maketree(args.gpt_save_dir) maketree(args.dis_save_dir) maketree(args.gpt_save_path) maketree(args.dis_save_path) maketree(args.gpt_sample_dir) maketree(args.dis_sample_dir) maketree(args.gpt_sample_dir2) maketree(args.dis_sample_dir2) maketree(args.log_dir) maketree(args.log_path) with open(os.path.join('models', args.model_name, 'hparams.json')) as f: hparams.override_from_dict(json.load(f)) if args.sample_length > hparams.n_ctx: raise ValueError( "Can't get samples longer than window size: %s" % hparams.n_ctx) if args.model_name == '345M': args.memory_saving_gradients = True if args.optimizer == 'adam': args.only_train_transformer_layers = True config = tf.ConfigProto() config.gpu_options.allow_growth = True config.graph_options.rewrite_options.layout_optimizer = rewriter_config_pb2.RewriterConfig.OFF with tf.Session(config=config) as sess: scope_discri='distri' def get_dis_logit_and_prob_single_step(context, scope): with tf.variable_scope(scope, reuse=tf.AUTO_REUSE): context=tf.reshape(context, [-1, args.seq_len]) emb=tf.get_variable(name='emb', initializer=tf.random.normal([hparams.n_vocab, 32], 0, 0.02)) context_emb=tf.nn.embedding_lookup(emb, context) logit=dis(context_emb, scope=scope_discri) prob=tf.sigmoid(logit+1e-7) return logit, prob def get_dis_logit_and_prob(context, context_len, scope): ##Pay attention to context_len here. temporary changes!!!!!!!!!!!!!!!!!!! 
context_mask=(1-tf.sequence_mask(context_len-1, args.seq_len-1, dtype=tf.float32))*1e3 context_mask2=tf.sequence_mask(context_len-1, args.seq_len-1, dtype=tf.float32) ones=tf.ones(shape=[tf.shape(context_len)[0], args.seq_len], dtype=tf.int32)*enc.encoder['<|endoftext|>'] input_tensor_list=[] for i in range(1, args.seq_len): input_tensor_list.append(tf.concat([context[:, :i+1], ones[:,i+1:]], axis=1)) input_tensor=tf.concat(input_tensor_list, axis=0) log_prob, _=get_dis_logit_and_prob_single_step(input_tensor, scope=scope) log_prob=tf.transpose(tf.reshape(log_prob, [args.seq_len-1, -1])) log_prob+=tf.cast(context_mask, tf.float32) log_prob_min=tf.reduce_min(log_prob, axis=1) prob_min=tf.exp(log_prob_min) return log_prob_min, prob_min, log_prob ##Build discriminator def build_dis_layer(scope): context_pos_discri = tf.placeholder(tf.int32, [None, args.seq_len]) context_pos_discri_len = tf.placeholder(tf.int32, [None]) context_neg_discri = tf.placeholder(tf.int32, [None, args.seq_len]) context_neg_discri_len = tf.placeholder(tf.int32, [None]) label_pos_discri=tf.ones([tf.shape(context_pos_discri_len)[0]], dtype=tf.float32) label_neg_discri=tf.zeros([tf.shape(context_neg_discri_len)[0]], dtype=tf.float32) logit_pos_discri, prob_pos_discri, mask=get_dis_logit_and_prob(context_pos_discri, context_pos_discri_len, scope=scope) logit_neg_discri, _, _=get_dis_logit_and_prob(context_neg_discri, context_neg_discri_len, scope=scope) loss_pre_pos_discri=tf.nn.sigmoid_cross_entropy_with_logits(labels=label_pos_discri, logits=logit_pos_discri) loss_pos_discri=tf.reduce_mean(loss_pre_pos_discri) loss_pre_neg_discri=tf.nn.sigmoid_cross_entropy_with_logits(labels=label_neg_discri, logits=logit_neg_discri) loss_neg_discri=tf.reduce_mean(loss_pre_neg_discri) loss_discri=(loss_pos_discri*args.pos_loss_weight+loss_neg_discri)/(1+args.pos_loss_weight) train_var_list_discri=[x for x in tf.global_variables() if scope in x.name] train_op_discri=tf.train.AdamOptimizer().minimize(loss_discri, var_list=train_var_list_discri) var_list_discri=[x for x in tf.global_variables() if scope in x.name] initializer_discri=tf.variables_initializer(var_list_discri) saver_discri=tf.train.Saver(var_list=var_list_discri, max_to_keep=1) print('discri: {} build succeed!'.format(scope)) return context_pos_discri,context_pos_discri_len, context_neg_discri,context_neg_discri_len, loss_pos_discri, loss_neg_discri, loss_discri, train_op_discri, initializer_discri, saver_discri, prob_pos_discri, mask, logit_pos_discri class dis_class: def __init__(self, layer_num=1, scope=scope_discri): self.model=[] self.dis=np.zeros([layer_num], dtype=np.float32) print(layer_num) for i in range(layer_num): layer={'scope': scope+str(i)} layer['context_pos_discri'],layer['context_pos_discri_len'], layer['context_neg_discri'],layer['context_neg_discri_len'], layer['loss_pos_discri'], layer['loss_neg_discri'], layer['loss_discri'], layer['train_op_discri'], layer['initializer_discri'], layer['saver_discri'], layer['prob_pos_discri'], layer['mask'], layer['logit_pos_discri'] = build_dis_layer(scope+str(i)) self.model.append(layer) def prob(self, context, context_len, layer=-1): if layer==-1: layer=len(self.model) prob_final=tf.ones(tf.shape(context)[0], dtype=tf.float32) for i in range(layer): item=self.model[i] scope=item['scope'] _, prob, _=get_dis_logit_and_prob(context, context_len, scope=scope) prob_final*=prob return prob_final def log_prob_step(self, context, layer=-1): if layer==-1: layer=len(self.model) prob_final=tf.ones(tf.shape(context)[0], 
dtype=tf.float32) log_prob_list=[] for i in range(layer): item=self.model[i] scope=item['scope'] log_prob, prob=get_dis_logit_and_prob_single_step(context, scope=scope) log_prob_list.append(tf.expand_dims(log_prob, 1)) log_prob_final=tf.concat(log_prob_list, axis=1) return log_prob_final Dis=dis_class(layer_num=args.layer_num) context = tf.placeholder(tf.int32, [None, None]) context_len=tf.placeholder(tf.int32, [None]) context_mask=tf.sequence_mask(context_len-1, args.seq_len-1, dtype=tf.float32) context_in=context output = model.model(hparams=hparams, X=context_in) loss_tensor = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=context[:, 1:], logits=output['logits'][:, :-1])*context_mask loss=tf.reduce_sum(loss_tensor, axis=1)/(tf.reduce_sum(context_mask, axis=1)+1e-7) loss_sen=tf.reduce_sum(loss) loss=tf.reduce_mean(loss) if args.val_every > 0: def transform_np(x, lift=args.exponential_param): x=x-0.5 x=x+np.abs(x) return lift*x**2 def transform(x, lift=args.exponential_param): x=x-0.5 x=x+tf.abs(x) return lift*x**2 val_context = tf.placeholder(tf.int32, [args.val_batch_size, args.seq_len]) val_context_len=tf.placeholder(tf.int32, [args.batch_size]) NLL_bias=tf.placeholder(tf.float32, []) val_context_mask=tf.sequence_mask(val_context_len-1, args.seq_len-1, dtype=tf.float32) val_output = model.model(hparams=hparams, X=val_context) val_loss_tensor =tf.nn.sparse_softmax_cross_entropy_with_logits(labels=val_context[:, 1:], logits=val_output['logits'][:, :-1])*val_context_mask val_context_prob_cut=Dis.prob(val_context, val_context_len) val_NLL_cut=tf.log(val_context_prob_cut+1e-7) val_loss=tf.reduce_sum(val_loss_tensor, axis=1)/(tf.reduce_sum(val_context_mask, axis=1)+1e-7) val_loss_cut=(tf.reduce_sum(val_loss_tensor, axis=1)+NLL_bias)/(tf.reduce_sum(val_context_mask, axis=1)+1e-7)-val_NLL_cut/tf.cast(val_context_len, tf.float32) val_loss_sum=tf.reduce_sum(val_loss_tensor, axis=1) val_loss_cut_sum=(tf.reduce_sum(val_loss_tensor, axis=1)+NLL_bias)-val_NLL_cut val_loss_mean=tf.reduce_mean(val_loss) val_loss_cut_mean=tf.reduce_mean(val_loss_cut) val_loss_summary = tf.summary.scalar('val_loss', val_loss_mean) tf_sample = sample.sample_sequence( hparams=hparams, length=args.seq_len, context=context, batch_size=args.batch_size, temperature=1.0, top_k=args.top_k, top_p=args.top_p, start_token=enc.encoder['<|endoftext|>']) start_token=enc.encoder['<|endoftext|>'] all_vars = [v for v in tf.trainable_variables() if 'model' in v.name] train_vars = [v for v in all_vars if '/h' in v.name] if args.only_train_transformer_layers else all_vars if args.optimizer == 'adam': opt = tf.train.AdamOptimizer(learning_rate=args.learning_rate) elif args.optimizer == 'sgd': opt = tf.train.GradientDescentOptimizer(learning_rate=args.learning_rate) else: exit('Bad optimizer:', args.optimizer) if args.accumulate_gradients > 1: if args.memory_saving_gradients: exit("Memory saving gradients are not implemented for gradient accumulation yet.") opt = AccumulatingOptimizer( opt=opt, var_list=train_vars) opt_reset = opt.reset() opt_compute = opt.compute_gradients(loss) opt_apply = opt.apply_gradients() summary_loss = tf.summary.scalar('loss', opt_apply) else: if args.memory_saving_gradients: opt_grads = memory_saving_gradients.gradients(loss, train_vars) else: opt_grads = tf.gradients(loss, train_vars) opt_grads = list(zip(opt_grads, train_vars)) opt_apply = opt.apply_gradients(opt_grads) summary_loss = tf.summary.scalar('loss', loss) summary_lr = tf.summary.scalar('learning_rate', args.learning_rate) summaries = 
tf.summary.merge([summary_lr, summary_loss]) summary_log = tf.summary.FileWriter( os.path.join(CHECKPOINT_DIR, args.run_name)) saver = tf.train.Saver(var_list=all_vars, max_to_keep=1) sess.run(tf.global_variables_initializer()) if args.restore_from == 'latest': ckpt = tf.train.latest_checkpoint( os.path.join(CHECKPOINT_DIR, args.run_name)) if ckpt is None: # Get fresh GPT weights if new run. ckpt = tf.train.latest_checkpoint( os.path.join('models', args.model_name)) elif args.restore_from == 'fresh': ckpt = tf.train.latest_checkpoint( os.path.join('models', args.model_name)) else: ckpt = tf.train.latest_checkpoint(args.restore_from) print('Loading checkpoint', ckpt) saver.restore(sess, ckpt) print('Loading dataset...') data_list, data_len = load_dataset(enc, args.train_data_path, args.seq_len) data_sampler = Sampler(data_list, data_len ) if args.val_every > 0: val_data_list, val_data_len = load_dataset(enc, args.eval_data_path, args.seq_len) print('dataset has', data_sampler.total_size, 'tokens') print('Training...') if args.val_every > 0: # Sample from validation set once with fixed seed to make # it deterministic during training as well as across runs. val_data_sampler = Sampler(val_data_list, val_data_len, seed=1) val_batches = [val_data_sampler.sample(args.batch_size) for _ in range(args.val_batch_count)] counter = 0 counter_path = os.path.join(CHECKPOINT_DIR, args.run_name, 'counter') if os.path.exists(counter_path): # Load the step number if we're resuming a run # Add 1 so we don't immediately try to save again with open(counter_path, 'r') as fp: counter = int(fp.read()) + 1 def save(): maketree(os.path.join(CHECKPOINT_DIR, args.run_name)) print( 'Saving', os.path.join(CHECKPOINT_DIR, args.run_name, 'model-{}').format(counter)) saver.save( sess, os.path.join(CHECKPOINT_DIR, args.run_name, 'model'), global_step=counter) with open(counter_path, 'w') as fp: fp.write(str(counter) + '\n') def train_step_discri(layer_id=0, mask_train_epoch=0): pos_samples, pos_samples_len=data_sampler.sample(args.batch_size) neg_samples=generate_negative_sample(layer_id=layer_id) neg_samples_len=get_array_len(neg_samples) _, loss=sess.run([Dis.model[layer_id]['train_op_discri'], Dis.model[layer_id]['loss_discri']], feed_dict={Dis.model[layer_id]['context_pos_discri']: pos_samples,Dis.model[layer_id]['context_pos_discri_len']: pos_samples_len, Dis.model[layer_id]['context_neg_discri']: neg_samples, Dis.model[layer_id]['context_neg_discri_len']: neg_samples_len}) return loss def generate_negative_samples(layer_id, generate_num=args.batch_size): result_list=[] generate_num_now=0 samples_mem=[] while generate_num_now<generate_num: t=time.time() sample_id=generate_negative_sample(layer_id=layer_id) samples=[] t1=time.time() selected_id_list=np.arange(len(sample_id)) t2=time.time() result_list.append(sample_id[selected_id_list]) generate_num_now+=len(selected_id_list) return np.concatenate(result_list, axis=0)[:generate_num] def get_array_len(sample_array): lens=[] for item in sample_array: for i in range(1, len(item)): if item[i]==enc.encoder['<|endoftext|>']: break lens.append(i) return np.array(lens).astype(np.int32) def generate_discri_sample3(layer_id=-1, sample_size=10000, save_path='/mnt/cephfs_new_wj/mlnlp/miaoning/Experiment/gpt-2-sep/samples/discri/sample2.txt'): samples=[] while len(samples)<sample_size: sample_id=generate_negative_sample(layer_id) for i in range(len(sample_id)): sample_tem=enc.decode(sample_id[i]).split('<|endoftext|>')[1].split('\n')[0] samples.append(sample_tem) print(len(samples)) 
    # Tail of the sample-dump helper above (generate_discri_sample3):
    # write the generated samples to disk.
    with open(save_path, 'w') as g:
        g.write('\n'.join(samples))


def eval_discri_NLL(layer_id=0):
    """Average discriminator NLL on real (pos) and generated (neg) samples."""
    losses_pos = []
    losses_neg = []
    for batch in tqdm.tqdm(val_batches):
        pos_samples, pos_samples_len = batch
        neg_samples = generate_negative_sample(layer_id=layer_id)
        neg_samples_len = get_array_len(neg_samples)
        loss_pos, mask = sess.run(
            [Dis.model[layer_id]['loss_pos_discri'], Dis.model[layer_id]['mask']],
            feed_dict={Dis.model[layer_id]['context_pos_discri']: pos_samples,
                       Dis.model[layer_id]['context_pos_discri_len']: pos_samples_len})
        #print(mask)
        loss_neg = sess.run(
            Dis.model[layer_id]['loss_neg_discri'],
            feed_dict={Dis.model[layer_id]['context_neg_discri']: neg_samples,
                       Dis.model[layer_id]['context_neg_discri_len']: neg_samples_len})
        losses_pos.append(loss_pos)
        losses_neg.append(loss_neg)
    return np.mean(losses_pos), np.mean(losses_neg)


def get_discri_quantile(layer_id=0, quantile=0.85):
    """Return the discriminator logit below which (1 - quantile) of real data falls."""
    logits_list = []
    for batch in tqdm.tqdm(val_batches):
        pos_samples, pos_samples_len = batch
        logits, mask = sess.run(
            [Dis.model[layer_id]['logit_pos_discri'], Dis.model[layer_id]['mask']],
            feed_dict={Dis.model[layer_id]['context_pos_discri']: pos_samples,
                       Dis.model[layer_id]['context_pos_discri_len']: pos_samples_len})
        print(np.min(mask, axis=1)[:10])
        print(logits[:10])
        with open('mask.pkl', 'wb') as g:
            pkl.dump(mask, g)
        logits_list.extend(list(logits))
        break  # only the first validation batch is used to estimate the quantile
    with open('logits.pkl', 'wb') as g:
        pkl.dump(sorted(logits_list), g)
    #print(sorted(logits_list))
    print('finish')
    return sorted(logits_list)[int(len(logits_list) * (1 - quantile))]


def train_discri(train_step, eval_every, train_layer_list=list(range(len(Dis.model)))):
    #sess.run(initializer_discri)
    print('Start Discri training')
    train_losses = []
    for layer_id in train_layer_list:
        flag = 0
        for epoch in range(train_step):
            if epoch % eval_every == 0:
                if train_losses:
                    # The original computed this mean and then discarded it;
                    # report it instead.
                    print('layer_id:{} discri mean train loss:{}'.format(
                        layer_id, np.mean(train_losses)))
                train_losses = []
                eval_NLL_pos, eval_NLL_neg = eval_discri_NLL(layer_id)
                eval_loss = (eval_NLL_pos * args.pos_loss_weight + eval_NLL_neg) \
                    / (args.pos_loss_weight + 1)
                print('layer_id:{} discri eval loss:{}'.format(layer_id, eval_loss))
                print('layer_id:{} discri NLL pos: {}, discri NLL neg: {}'.format(
                    layer_id, eval_NLL_pos, eval_NLL_neg))
                print(epoch)
                if epoch == 0:
                    eval_loss_old = eval_loss
                else:
                    print(eval_loss, eval_loss_old)
                    if eval_loss < eval_loss_old:
                        # Keep the best checkpoint for this layer.
                        eval_loss_old = eval_loss
                        save_path = args.dis_save_path + str(layer_id) + '/'
                        if not os.path.isdir(save_path):
                            os.mkdir(save_path)
                        Dis.model[layer_id]['saver_discri'].save(sess, save_path + 'a')
                        print('model discri saved!')
                        flag = 0
                    else:
                        if epoch >= 200:
                            flag += 1
                        if flag >= 4:
                            break  # stop after repeated non-improvement
            train_loss = train_step_discri(layer_id)
            print('layer_id:{} discri train loss:{}'.format(layer_id, train_loss))
            train_losses.append(train_loss)
    return eval_loss_old


tf_sample_0 = sample_link.sample_sequence(
    hparams=hparams, length=args.seq_len, context=context,
    batch_size=args.batch_size, temperature=1.0, top_k=args.top_k,
    top_p=args.top_p, start_token=enc.encoder['<|endoftext|>'])
tf_sample_dict = {}


def generate_negative_sample(layer_id=0):
    ## Output the filtered result of layer layer_id - 1.
    if layer_id == 0:
        tf_sample = tf_sample_0
    else:
        if layer_id == -1:
            layer_id = len(Dis.model)
        if layer_id in tf_sample_dict:
            tf_sample = tf_sample_dict[layer_id]
        else:
            tf_sample = sample_link.sample_sequence_ISMC_threshold(
                Dis=Dis, layer=layer_id, hparams=hparams,
                length=args.seq_len, context=context,
                batch_size=args.batch_size, temperature=1.0,
                top_k=args.top_k, top_p=args.top_p,
                start_token=enc.encoder['<|endoftext|>'])
            tf_sample_dict[layer_id] = tf_sample
    sample = data_sampler.sample(args.batch_size)[0][:, 0:1]
    out = sess.run(tf_sample, feed_dict={context: sample})[:, :args.seq_len]
    # Overwrite everything after the second <|endoftext|> with start tokens.
    for i in range(len(out)):
        flag = 0
        for j in range(len(out[i])):
            if flag == 2:
                out[i][j] = start_token
                continue
            if out[i][j] == start_token:
                flag += 1
    return out


def validation():
    print('Calculating validation loss...')
    start_time = time.time()
    losses = []
    rates = []
    for batch in tqdm.tqdm(val_batches):
        losses.append(sess.run(
            val_loss_mean,
            feed_dict={val_context: batch[0], val_context_len: batch[1]}))
    v_val_loss = np.mean(losses)
    v_summary = sess.run(val_loss_summary, feed_dict={val_loss_mean: v_val_loss})
    summary_log.add_summary(v_summary, counter)
    summary_log.flush()
    print('[{counter} | {time:2.2f}] validation loss = {loss:2.2f}'.format(
        counter=counter, time=time.time() - start_time, loss=v_val_loss))
    return v_val_loss


def validation_cut(NLL_bias_0=0):
    print('Calculating validation loss...')
    start_time = time.time()  # missing in the original; used in the print below
    losses = []
    rates = []
    for batch in tqdm.tqdm(val_batches):
        losses.append(sess.run(
            val_loss_cut_mean,
            feed_dict={val_context: batch[0],
                       val_context_len: batch[1],
                       NLL_bias: NLL_bias_0}))
    v_val_loss = np.mean(losses)
    print('[{counter} | {time:2.2f}] validation cut loss = {loss:2.2f}'.format(
        counter=counter, time=time.time() - start_time, loss=v_val_loss))
    return v_val_loss


def sample_batch():
    return [data_sampler.sample(1024) for _ in range(args.batch_size)]


def train_gpt():
    val_loss_old = 10000.0
    avg_loss = (0.0, 0.0)
    start_time = time.time()
    counter = 0
    while True:  # pretraining
        if counter % args.save_every == 0:
            pass  #save()
        if counter % args.sample_every == 0:
            pass  #generate_samples()
        if args.val_every > 0 and (counter % args.val_every == 0 or counter == 1):
            val_loss_1 = validation()
            print(str(counter // args.val_every))
            if val_loss_1 >= val_loss_old:
                # Stop as soon as validation loss stops improving.
                print('pre-training ends!')
                break
            else:
                val_loss_old = val_loss_1
                saver.save(sess, args.gpt_save_path + 'a')
                print('save succeed!')
        if args.accumulate_gradients > 1:
            sess.run(opt_reset)
            for _ in range(args.accumulate_gradients):
                batch, batch_len = data_sampler.sample(args.batch_size)
                sess.run(opt_compute,
                         feed_dict={context: batch, context_len: batch_len})
            (v_loss, v_summary) = sess.run((opt_apply, summaries))
        else:
            batch, batch_len = data_sampler.sample(args.batch_size)
            (_, v_loss, v_summary) = sess.run(
                (opt_apply, loss, summaries),
                feed_dict={context: batch, context_len: batch_len})
        summary_log.add_summary(v_summary, counter)
        avg_loss = (avg_loss[0] * 0.9 + v_loss, avg_loss[1] * 0.9 + 1.0)
        print('[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'.format(
            counter=counter, time=time.time() - start_time,
            loss=v_loss, avg=avg_loss[0] / avg_loss[1]))
        counter += 1


class log_writer:
    def __init__(self, path):
        self.path = path
        with open(path, 'w') as g:
            g.write('')

    def __call__(self, string, verbose=False):
        with open(self.path, 'a') as g:
            g.write(string + '\n')
        if verbose:
            print(string)


try:
    if args.finetune:
        # Finetune GPT-2.
        train_gpt()
    if True:
        # Restore the finetuned model.
        save_path = tf.train.latest_checkpoint(args.gpt_save_path)
        saver.restore(sess, save_path)
        print('Load gpt2 succeeded!')
    if args.evaluate_finetune:
        # Evaluate the finetuning baseline.
        print(validation())
    if args.evaluate_finetune:
        # Calculate reverse-ppl for the finetuning baseline.
        sample_path = args.gpt_sample_dir2 + 'sample.txt'
        generate_discri_sample3(layer_id=0, sample_size=3000, save_path=sample_path)
        rev_ppl = train.file_f(train_data_path=sample_path,
                               val_data_path=args.eval_data_path)
        Log_writer = log_writer(args.log_path + 'finetune')
        Log_writer('finetuning_rev_ppl: {}'.format(rev_ppl), verbose=True)

    ## Begin tailoring.
    if True:
        Log_writer = log_writer(args.log_path + 'discri')
        for layer in range(args.layer_num):
            print(layer)
            if args.train_tailor:
                # Train the ratio estimator.
                train_discri(500, 10, [layer])
            if True:
                # Restore the ratio estimators trained so far.
                for layer_id in range(layer + 1):
                    save_path = args.dis_save_path + str(layer_id) + '/'
                    print(save_path)
                    save_path = tf.train.latest_checkpoint(save_path)
                    print(save_path)
                    Dis.model[layer_id]['saver_discri'].restore(sess, save_path)
                if False:
                    # Load a previously saved quantile for analysis (disabled).
                    with open(args.dis_sample_dir2 + 'quantile.pkl', 'rb') as f:
                        pkl.load(f)
                print('Load dis model succeeded!')
            if True:
                # Calibrate the acceptance threshold for this layer.
                quantile = 0.85 if layer == 0 else 0.9
                Dis.dis[layer] = get_discri_quantile(layer, quantile)
                with open(args.dis_sample_dir2 + 'quantile.pkl', 'wb') as g:
                    pkl.dump(Dis.dis, g)
                print(Dis.dis)
            if args.evaluate_tailor:
                # Generate samples for ERS and calculate reverse-ppl.
                sample_path = args.dis_sample_dir2 + '_sample_layer_' + str(layer)
                generate_discri_sample3(layer_id=layer + 1, sample_size=3000,
                                        save_path=sample_path)
                rev_ppl = train.file_f(train_data_path=sample_path,
                                       val_data_path=args.eval_data_path)
                Log_writer('layer: {}, dis_rev_ppl: {}'.format(layer, rev_ppl),
                           verbose=True)
except KeyboardInterrupt:
    print('interrupted')
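
# ---------------------------------------------------------------------------
# A minimal, self-contained sketch of the quantile-threshold filtering idea
# used by get_discri_quantile() above. It assumes (this is not shown in this
# file) that sample_sequence_ISMC_threshold rejects candidates whose
# discriminator logit falls below the stored threshold Dis.dis[layer];
# `score_fn` below is a hypothetical stand-in for that per-sequence logit.
# ---------------------------------------------------------------------------
def quantile_threshold(real_logits, quantile=0.85):
    # Cutoff such that roughly `quantile` of real-data logits lie above it.
    ordered = sorted(real_logits)
    return ordered[int(len(ordered) * (1 - quantile))]

def rejection_filter(candidates, score_fn, threshold):
    # Keep only candidates that score at least as "real" as the cutoff.
    return [c for c in candidates if score_fn(c) >= threshold]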
def main():
    args = parser.parse_args()
    folder_id = get_id(args.gdir)
    #xmpp = SendMsgBot(jid, password, to, "Starting GPT-2")
    #xmpp.register_plugin('xep_0030')  # Service Discovery
    #xmpp.register_plugin('xep_0199')  # XMPP Ping
    #xmpp.connect()
    #threading = Thread(target=xmpp.process, daemon=True).start()
    download_checkpoint(folder_id)
    #send_m('checkpoint downloaded')
    enc = encoder.get_encoder(args.model_name)
    hparams = model.default_hparams()
    with open(os.path.join('models', args.model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if args.sample_length > hparams.n_ctx:
        raise ValueError(
            "Can't get samples longer than window size: %s" % hparams.n_ctx)

    if args.model_name == '345M':
        args.memory_saving_gradients = True
        # if args.optimizer == 'adam':
        #     args.only_train_transformer_layers = True

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.graph_options.rewrite_options.layout_optimizer = rewriter_config_pb2.RewriterConfig.OFF
    with tf.Session(config=config) as sess:
        context = tf.placeholder(tf.int32, [args.batch_size, None])
        # Training inputs, optionally perturbed with token noise by randomize().
        context_in = randomize(context, hparams, args.noise)
        output = model.model(hparams=hparams, X=context_in)
        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=context[:, 1:], logits=output['logits'][:, :-1]))

        if args.val_every > 0:
            val_context = tf.placeholder(tf.int32, [args.val_batch_size, None])
            val_output = model.model(hparams=hparams, X=val_context)
            val_loss = tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=val_context[:, 1:],
                    logits=val_output['logits'][:, :-1]))
            val_loss_summary = tf.summary.scalar('val_loss', val_loss)

        tf_sample = sample.sample_sequence(
            hparams=hparams,
            length=args.sample_length,
            context=context,
            batch_size=args.batch_size,
            temperature=1.0,
            top_k=args.top_k,
            top_p=args.top_p)

        all_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
        train_vars = [v for v in all_vars if '/h' in v.name] \
            if args.only_train_transformer_layers else all_vars

        if args.optimizer == 'adam':
            opt = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
        elif args.optimizer == 'sgd':
            opt = tf.train.GradientDescentOptimizer(learning_rate=args.learning_rate)
        else:
            # exit() takes a single argument; the original passed two.
            exit('Bad optimizer: ' + args.optimizer)

        if args.accumulate_gradients > 1:
            if args.memory_saving_gradients:
                exit("Memory saving gradients are not implemented for gradient accumulation yet.")
            opt = AccumulatingOptimizer(opt=opt, var_list=train_vars)
            opt_reset = opt.reset()
            opt_compute = opt.compute_gradients(loss)
            opt_apply = opt.apply_gradients()
            summary_loss = tf.summary.scalar('loss', opt_apply)
        else:
            if args.memory_saving_gradients:
                opt_grads = memory_saving_gradients.gradients(loss, train_vars)
            else:
                opt_grads = tf.gradients(loss, train_vars)
            opt_grads = list(zip(opt_grads, train_vars))
            opt_apply = opt.apply_gradients(opt_grads)
            summary_loss = tf.summary.scalar('loss', loss)

        summary_lr = tf.summary.scalar('learning_rate', args.learning_rate)
        summaries = tf.summary.merge([summary_lr, summary_loss])

        summary_log = tf.summary.FileWriter(
            os.path.join(CHECKPOINT_DIR, args.run_name))

        saver = tf.train.Saver(
            var_list=all_vars,
            max_to_keep=5,
            keep_checkpoint_every_n_hours=2)
        sess.run(tf.global_variables_initializer())

        if args.restore_from == 'latest':
            ckpt = tf.train.latest_checkpoint(
                os.path.join(CHECKPOINT_DIR, args.run_name))
            if ckpt is None:
                # Get fresh GPT weights if new run.
                ckpt = tf.train.latest_checkpoint(
                    os.path.join('models', args.model_name))
        elif args.restore_from == 'fresh':
            ckpt = tf.train.latest_checkpoint(
                os.path.join('models', args.model_name))
        else:
            ckpt = tf.train.latest_checkpoint(args.restore_from)
        print('Loading checkpoint', ckpt)
        #send_m('Loading ' + str(ckpt))
        saver.restore(sess, ckpt)

        print('Loading dataset...')
        #send_m('Loading dataset...')
        #chunks = load_dataset(enc, args.dataset, args.combine)
        ds_path = f'{CHECKPOINT_DIR}//run1//{args.dataset}'
        chunks = load_dataset(enc, ds_path, args.combine)
        data_sampler = Sampler(chunks)
        print(f'{ds_path} has', data_sampler.total_size, 'tokens')
        if args.val_every > 0:
            val_chunks = load_dataset(enc, args.val_dataset, args.combine) \
                if args.val_dataset else chunks
        if args.enc:
            print(colored(f'Trying writing Data.npz encoded from this dataset to {args.enc}', 'red'))
            np.savez_compressed(args.enc, *chunks)
            upload_npz(args.enc, folder_id)
        #send_m(f'{args.dataset} has ' + str(data_sampler.total_size) + ' tokens' + ' Start training...')
        print('Training...')

        if args.val_every > 0:
            # Sample from validation set once with fixed seed to make
            # it deterministic during training as well as across runs.
            val_data_sampler = Sampler(val_chunks, seed=1)
            val_batches = [[val_data_sampler.sample(1024)
                            for _ in range(args.val_batch_size)]
                           for _ in range(args.val_batch_count)]

        counter = 1
        counter_path = os.path.join(CHECKPOINT_DIR, args.run_name, 'counter')
        if os.path.exists(counter_path):
            # Load the step number if we're resuming a run
            # Add 1 so we don't immediately try to save again
            with open(counter_path, 'r') as fp:
                counter = int(fp.read()) + 1

        def save():
            maketree(os.path.join(CHECKPOINT_DIR, args.run_name))
            print(
                'Saving',
                os.path.join(CHECKPOINT_DIR, args.run_name,
                             'model-{}').format(counter))
            saver.save(
                sess,
                os.path.join(CHECKPOINT_DIR, args.run_name, 'model'),
                global_step=counter)
            with open(counter_path, 'w') as fp:
                fp.write(str(counter) + '\n')
            save_gdisk(counter, folder_id)

        def generate_samples():
            print('Generating samples...')
            #send_m('Generating samples...')
            context_tokens = data_sampler.sample(1)
            all_text = []
            index = 0
            while index < args.sample_num:
                out = sess.run(
                    tf_sample,
                    feed_dict={context: args.batch_size * [context_tokens]})
                for i in range(min(args.sample_num - index, args.batch_size)):
                    text = enc.decode(out[i])
                    text = '======== SAMPLE {} ========\n{}\n'.format(
                        index + 1, text)
                    all_text.append(text)
                    index += 1
                print(text)
                #send_m(text)
            maketree(os.path.join(SAMPLE_DIR, args.run_name))
            with open(
                    os.path.join(SAMPLE_DIR, args.run_name,
                                 'samples-{}').format(counter), 'w') as fp:
                fp.write('\n'.join(all_text))

        def validation():
            print('Calculating validation loss...')
            losses = []
            for batch in tqdm.tqdm(val_batches):
                losses.append(sess.run(val_loss, feed_dict={val_context: batch}))
            v_val_loss = np.mean(losses)
            v_summary = sess.run(val_loss_summary,
                                 feed_dict={val_loss: v_val_loss})
            summary_log.add_summary(v_summary, counter)
            summary_log.flush()
            print(
                '[{counter} | {time:2.2f}] validation loss = {loss:2.2f}'
                .format(counter=counter,
                        time=time.time() - start_time,
                        loss=v_val_loss))

        def sample_batch():
            return [data_sampler.sample(1024) for _ in range(args.batch_size)]

        avg_loss = (0.0, 0.0)
        start_time = time.time()
        last_time = time.time()
        cur_counter, min_loss = 1, 2.0
        print(colored(f'Model >>> {args.gdir}\nLearning rate is {args.learning_rate}', 'blue'))
        print(colored(f'model optimizer >>> {args.optimizer}\nRestricted to train only transformer layer={args.only_train_transformer_layers}', 'blue'))
        #send_m(f'Model >>> {args.model_name}\nLearning rate is {args.learning_rate}')

        try:
            while True:
                if counter % args.save_every == 0:
                    save()
                    if check_quota():
                        return  # exit training when the Drive quota check trips
                if counter % args.sample_every == 0:
                    generate_samples()
                if args.val_every > 0 and (counter % args.val_every == 0 or counter == 1):
                    validation()

                if args.accumulate_gradients > 1:
                    sess.run(opt_reset)
                    for _ in range(args.accumulate_gradients):
                        sess.run(opt_compute, feed_dict={context: sample_batch()})
                    (v_loss, v_summary) = sess.run((opt_apply, summaries))
                else:
                    (_, v_loss, v_summary) = sess.run(
                        (opt_apply, loss, summaries),
                        feed_dict={context: sample_batch()})

                summary_log.add_summary(v_summary, counter)

                avg_loss = (avg_loss[0] * 0.99 + v_loss, avg_loss[1] * 0.99 + 1.0)
                a_loss = avg_loss[0] / avg_loss[1]
                time_all = int((time.time() - start_time) / 60)
                time_iter = time.time() - last_time
                stats = f'[{counter} | {cur_counter} | {time_all}m | {time_iter:2.2f}s] loss={v_loss:2.2f} avg={a_loss:2.2f}'
                if not (cur_counter % 50):
                    print(colored(stats, 'red' if a_loss > min_loss else 'yellow'))
                if a_loss < min_loss:
                    min_loss = a_loss
                    #send_m(stats)
                last_time = time.time()
                counter += 1
                cur_counter += 1
        except Exception as e:
            #send_m('Stopped ' + str(e.__class__))
            print('Stopped', e.__class__)
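
# ---------------------------------------------------------------------------
# The (numerator, denominator) pair tracked in avg_loss above is a debiased
# exponential moving average: avg = sum_i d^i * loss_{t-i} / sum_i d^i.
# A minimal sketch of the same bookkeeping, runnable on its own:
# ---------------------------------------------------------------------------
def ema_update(avg_loss, new_loss, decay=0.99):
    """Return the updated (weighted_sum, weight_total) pair."""
    return (avg_loss[0] * decay + new_loss, avg_loss[1] * decay + 1.0)

avg = (0.0, 0.0)
for step_loss in [3.2, 3.0, 2.9, 2.7]:
    avg = ema_update(avg, step_loss)
    print('avg loss: {:.3f}'.format(avg[0] / avg[1]))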
def main():
    args = parser.parse_args()
    enc = get_encoder(model_name)
    hparams = model.default_hparams()
    with open(os.path.join(model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if args.sample_length > hparams.n_ctx:
        raise ValueError(
            "Can't get samples longer than window size: %s" % hparams.n_ctx)

    if args.model_name == '345M':
        args.memory_saving_gradients = True
        args.only_train_transformer_layers = True

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.graph_options.rewrite_options.layout_optimizer = rewriter_config_pb2.RewriterConfig.OFF

    acc_total = 0
    acc_over_time = []
    loss_avg_over_time = []

    if args.val_every > 0:
        # val_context = tf.placeholder(tf.int32, [args.val_batch_size, None])
        val_context = tf.placeholder(tf.int32, [1, None])
        val_output = model.model(hparams=hparams, X=val_context)
        val_loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=val_context[:, 1:],
                logits=val_output['logits'][:, :-1]))
        val_loss_summary = tf.summary.scalar('val_loss', val_loss)
        # With top_k=1 sampling is effectively greedy, so the temperature
        # setting has no effect here.
        tf_sample_val = sample.sample_sequence(
            hparams=hparams,
            length=1,  #args.sample_length,
            context=val_context,
            batch_size=1,  #args.batch_size,
            temperature=10.001,
            top_k=1)

    with tf.Session(config=config) as sess:
        context = tf.placeholder(tf.int32, [args.batch_size, None])
        output = model.model(hparams=hparams, X=context)
        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=context[:, 1:], logits=output['logits'][:, :-1]))

        tf_sample = sample.sample_sequence(
            hparams=hparams,
            length=args.sample_length,
            context=context,
            batch_size=args.batch_size,
            temperature=1.0,
            top_k=40)

        all_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
        train_vars = [v for v in all_vars if '/h' in v.name] \
            if args.only_train_transformer_layers else all_vars

        if args.accumulate_gradients > 1:
            if args.memory_saving_gradients:
                exit("Memory saving gradients are not implemented for gradient accumulation yet.")
            opt = AccumulatingOptimizer(
                opt=tf.train.AdamOptimizer(learning_rate=args.learning_rate),
                var_list=train_vars)
            opt_reset = opt.reset()
            opt_compute = opt.compute_gradients(loss)
            opt_apply = opt.apply_gradients()
            summary_loss = tf.summary.scalar('loss', opt_apply)
        else:
            opt = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
            if args.memory_saving_gradients:
                opt_grads = memory_saving_gradients.gradients(loss, train_vars)
            else:
                opt_grads = tf.gradients(loss, train_vars)
            opt_grads = list(zip(opt_grads, train_vars))
            opt_apply = opt.apply_gradients(opt_grads)
            summary_loss = tf.summary.scalar('loss', loss)

        summary_log = tf.summary.FileWriter(
            os.path.join(CHECKPOINT_DIR, args.run_name))

        saver = tf.train.Saver(
            var_list=all_vars, max_to_keep=5, keep_checkpoint_every_n_hours=2)
        sess.run(tf.global_variables_initializer())

        if args.restore_from == 'latest':
            ckpt = tf.train.latest_checkpoint(
                os.path.join(CHECKPOINT_DIR, args.run_name))
            if ckpt is None:
                # Get fresh GPT weights if new run.
                ckpt = tf.train.latest_checkpoint(os.path.join(model_name))
        elif args.restore_from == 'fresh':
            ckpt = tf.train.latest_checkpoint(os.path.join(model_name))
        else:
            ckpt = tf.train.latest_checkpoint(args.restore_from)
        print('Loading checkpoint', ckpt)
        saver.restore(sess, ckpt)

        print('Loading train dataset...')
        from_name, ques_name, to_name = name_parts(args.dataset)  #'../data/train.from'
        # NOTE: the `chunks` fallback is only defined if the commented-out
        # load_dataset call below is restored.
        trn_chunks_from = load_dataset(enc, from_name, args.combine) \
            if args.val_dataset else chunks
        trn_chunks_ques = load_dataset(enc, ques_name, args.combine) \
            if args.val_dataset else chunks
        trn_chunks_to = load_dataset(enc, to_name, args.combine) \
            if args.val_dataset else chunks
        skip_delimeter = True
        trn_data_sampler_from = SamplerVal(trn_chunks_from, enc,
                                           skip_delimeter=skip_delimeter)
        trn_data_sampler_ques = SamplerVal(trn_chunks_ques, enc,
                                           skip_delimeter=skip_delimeter)
        trn_data_sampler_to = SamplerVal(trn_chunks_to, enc,
                                         skip_delimeter=skip_delimeter)
        data_sampler = []
        for i in range(trn_data_sampler_from.total_size):
            # Each training example is: context + question + '. ' + answer.
            v = (trn_data_sampler_from.get(i) +
                 trn_data_sampler_ques.get(i) +
                 enc.encode('. ') +
                 trn_data_sampler_to.get(i)  # + enc.encode('<|endoftext|>')
                 )
            # v += [enc.encode(' ')[0] for _ in range(HIDDEN_SIZE - len(v))]
            if len(v) >= HIDDEN_SIZE - GENERATE_SIZE:
                continue
            v = v[:HIDDEN_SIZE - 1]
            data_sampler.append(v)
        #chunks = load_dataset(enc, args.dataset, args.combine)
        if not args.train_special:
            data_sampler = Sampler([np.array(data_sampler)])
        if args.val_every > 0:
            print('Loading validation dataset...')
            #val_chunks = load_dataset(enc, args.val_dataset, args.combine) if args.val_dataset else chunks
            from_name, ques_name, to_name = name_parts(args.val_dataset)
            val_chunks_from = load_dataset(enc, from_name, args.combine) \
                if args.val_dataset else chunks
            val_chunks_ques = load_dataset(enc, ques_name, args.combine) \
                if args.val_dataset else chunks
            val_chunks_to = load_dataset(enc, to_name, args.combine) \
                if args.val_dataset else chunks
        if not args.train_special:
            print('train dataset has', data_sampler.total_size, 'tokens')
        else:
            print('train dataset has', len(data_sampler), 'tokens')
        print('Training...')

        if args.val_every > 0:
            val_data_sampler_from = SamplerVal(val_chunks_from, enc)
            val_data_sampler_ques = SamplerVal(val_chunks_ques, enc)
            val_data_sampler_to = SamplerVal(val_chunks_to, enc)
            if args.val_batch_count == -1:
                args.val_batch_count = val_data_sampler_from.total_size
            val_batches = []
            for i in range(args.val_batch_count):
                # Validation inputs stop before the answer (the answer is what
                # validation_by_sample() tries to generate).
                v = (val_data_sampler_from.get(i) +
                     val_data_sampler_ques.get(i) +
                     enc.encode('. '))  #+ val_data_sampler_to.get(i)
                #v += [enc.encode(' ')[0] for _ in range(HIDDEN_SIZE - len(v))]
                if len(v) >= HIDDEN_SIZE - GENERATE_SIZE:
                    continue
                v = v[:HIDDEN_SIZE]
                val_batches.append(v)
            print('val dataset has', len(val_batches), 'tokens')

        counter = 1
        counter_path = os.path.join(CHECKPOINT_DIR, args.run_name, 'counter')
        if os.path.exists(counter_path):
            # Load the step number if we're resuming a run
            # Add 1 so we don't immediately try to save again
            with open(counter_path, 'r') as fp:
                counter = int(fp.read()) + 1

        txt_file_path = os.path.join(CHECKPOINT_DIR, args.run_name,
                                     args.run_name + '.summary.txt')

        def save_summary(message=None):
            if message is None:
                txt = ''
                fmt = '{valid:2.2f}'
                if not os.path.exists(txt_file_path):
                    a = vars(args)
                    txt += 'Summary for ' + args.run_name + '\n'
                    txt += str(datetime.datetime.now()) + '\n\n'
                    txt += json.dumps(a) + '\n'
                    txt += '-----\n'
                txt += str(datetime.datetime.now()) + '\n'
                txt += 'acc: ' + ', '.join(
                    [fmt.format(valid=i) for i in acc_over_time]) + '\n'
                txt += 'loss: ' + ', '.join(
                    [fmt.format(valid=i) for i in loss_avg_over_time]) + '\n'
                txt += 'counter: ' + str(counter) + '\n'
                txt += 'time elapsed: ' + str(time.time() - start_time) + '\n'
                txt += '-----\n'
            else:
                txt = message
            print(txt)
            with open(txt_file_path, 'a') as f:
                f.write(txt + '\n')

        def save():
            if args.test:
                return
            maketree(os.path.join(CHECKPOINT_DIR, args.run_name))
            print(
                'Saving',
                os.path.join(CHECKPOINT_DIR, args.run_name,
                             'model-{}').format(counter))
            saver.save(
                sess,
                os.path.join(CHECKPOINT_DIR, args.run_name, 'model'),
                global_step=counter)
            with open(counter_path, 'w') as fp:
                fp.write(str(counter) + '\n')

        '''
        def generate_samples():
            print('Generating samples...')
            context_tokens = data_sampler.sample(1)
            all_text = []
            index = 0
            while index < args.sample_num:
                out = sess.run(
                    tf_sample,
                    feed_dict={context: args.batch_size * [context_tokens]})
                for i in range(min(args.sample_num - index, args.batch_size)):
                    text = enc.decode(out[i])
                    text = '======== SAMPLE {} ========\n{}\n'.format(
                        index + 1, text)
                    all_text.append(text)
                    index += 1
                print(text)
            maketree(os.path.join(SAMPLE_DIR, args.run_name))
            with open(
                    os.path.join(SAMPLE_DIR, args.run_name,
                                 'samples-{}').format(counter), 'w') as fp:
                fp.write('\n'.join(all_text))
        '''

        def print_status(word=None, acc_total_in=0, size=0, v_loss_in=0.0,
                         shorten=False):
            v_loss = v_loss_in
            acc_out = 0
            acc_total = 0
            loss_out = 0.0
            if word is None:
                word = 'progress'
            if acc_total_in != 0 and size != 0:
                acc_out = acc_total_in / size * 100
                acc_total = size
            if avg_loss[1] == 0.0 or avg_loss[0] == 0.0:
                loss_out = 0.0
                v_loss = 0.0
            elif not np.isnan(avg_loss[0]) or True:  # and not np.isnan(avg_loss[1]):
                loss_out = avg_loss[0] / avg_loss[1]
            print(word + ' [' + args.run_name + ']' +
                  ' [{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'.format(
                      counter=counter,
                      time=time.time() - start_time,
                      loss=v_loss,
                      avg=loss_out),
                  'acc=' + str(acc_out), end=' ')
            print('total=' + str(acc_total), end=' ')
            if len(acc_over_time) > 0 and not shorten:
                print('last-acc=' + str(acc_over_time[-1]))
            else:
                print()

        def sample_batch(counter=0, randomize=False, pad_start=False):
            #print(enc.encode('<|endoftext|>'), 'eot')
            #print(data_sampler.sample(1024))
            if not args.train_special:
                return [data_sampler.sample(HIDDEN_SIZE)[0]
                        for _ in range(args.batch_size)]
            else:
                num = 0
                z = []
                while (len(z) > HIDDEN_SIZE or len(z) == 0) and num <= 5:
                    if randomize:
                        r = random.randint(1, 4)
                    else:
                        r = 0
                    #print('train special', r)
                    if pad_start:
                        pad = HIDDEN_SIZE - (len(data_sampler[counter]) - r)
                    else:
                        pad = 0
                    if randomize:
                        z = [[enc.encode(' ')[0] for _ in range(pad)] +
                             data_sampler[counter][:-r]
                             for _ in range(args.batch_size)]
                    if not randomize:
                        z = [data_sampler[counter]]
                        #print(enc.decode(z[0]))
                    num += 1
                    if num == 5:
                        z = z[len(z) - HIDDEN_SIZE:]
                        print('cannot get sample_batch')
                        break
                return z

        def validation_by_sample():
            print('Generating validation...')
            global acc_total
            if args.val_with_loss:
                losses = []
                for batch in tqdm.tqdm(val_batches):
                    batch = np.reshape(batch, [1, -1])
                    v = sess.run(val_loss, feed_dict={val_context: batch})
                    #print(v, 'v')
                    losses.append(v)
                v_val_loss = np.mean(losses)
                v_summary = sess.run(val_loss_summary,
                                     feed_dict={val_loss: v_val_loss})
                summary_log.add_summary(v_summary, counter)
                summary_log.flush()
                print(
                    '[{counter} | {time:2.2f}] validation loss = {loss:2.2f}'.format(
                        counter=counter,
                        time=time.time() - start_time,
                        loss=v_val_loss))
            acc_total = 0
            generated = 0
            for _ in range(len(val_batches)):
                val_batches_in = val_batches[generated]
                val_batches_in = val_batches_in[:1024]
                context_tokens = np.reshape(val_batches_in, [1, -1])
                #print(val_batches_in)
                text_in = enc.decode(val_batches_in)
                #print(text_in)
                #print(context_tokens, 'ct1')
                # Decode greedily, one token per run, feeding the output back
                # in as the new context.
                for x in range(GENERATE_SIZE):
                    out = sess.run(tf_sample_val,
                                   feed_dict={val_context: context_tokens})
                    #print(out[0][-x:])
                    #print(enc.decode(out[0][-x:]))
                    context_tokens = out
                compare = enc.decode(
                    val_data_sampler_to.get(generated))  # + ' <|endoftext|>'
                compare = ' '.join(compare.split(' '))
                generated += 1
                text = enc.decode(out[0])
                text_returned = ''
                text_original = ''
                if text.startswith(text_in):
                    text_returned = text[len(text_in):]
                    #print('-', text_returned, '-')
                if args.train_special:
                    text_original = text
                    text = text_returned
                if text.strip().endswith('.'):
                    ## remove trailing period
                    text = text.strip()[:-1]
                if text.strip().endswith('<|endoftext|>'):
                    text = text.strip()[:-len('<|endoftext|>')]
                t_vals = text.split(' ')
                if '<' in t_vals[-1] or '>' in t_vals[-1]:
                    t_vals = t_vals[:-1]
                num = 0
                while t_vals[-1] == '' and num < 10:
                    t_vals = t_vals[:-1]
                    num += 1
                #print(t_vals)
                t_vals = [i for i in t_vals if i != '']
                #print(t_vals)
                text = ' '.join(t_vals)
                if compare.strip().endswith('.'):
                    compare = compare.strip()[:-1]
                if compare.strip().endswith('<|endoftext|>'):
                    compare = compare.strip()[:-len('<|endoftext|>')]
                notification = ''
                len_bar = 40
                if text.strip().lower().endswith(compare.strip().lower()):
                    acc_total += 1
                    notification = 'vv CORRECT vv'
                    len_bar = 40 - len(notification)
                elif text_returned.strip().lower().startswith(
                        compare.strip().lower()):
                    acc_total += 1
                    notification = 'vv CORRECT_INITIAL vv'
                    len_bar = 40 - len(notification)
                print(notification + "=" * len_bar + " SAMPLE " + str(generated) +
                      " " + "=" * len_bar + notification)
                if args.train_special:
                    print(text_original)
                else:
                    print(text)
                print_status('old values', acc_total_in=acc_total, size=generated)
            print("=" * 80)
            return acc_total

        avg_loss = (0.0, 0.0)
        start_time = time.time()
        count_success = 0
        count_success_with_skips = 0
        acc = 0.0

        try:
            if args.test:
                v_loss = 0.0
                dataset = re.sub('train', 'test', args.dataset)
                print(dataset)
                from_name, ques_name, to_name = name_parts(dataset)
                test_chunks_from = load_dataset(enc, from_name, args.combine)
                test_chunks_ques = load_dataset(enc, ques_name, args.combine)
                test_chunks_to = load_dataset(enc, to_name, args.combine)
                val_data_sampler_from = SamplerVal(test_chunks_from, enc)
                val_data_sampler_ques = SamplerVal(test_chunks_ques, enc)
                val_data_sampler_to = SamplerVal(test_chunks_to, enc)
                if args.val_batch_count == -1:
                    args.val_batch_count = val_data_sampler_from.total_size
                val_batches = []
                for i in range(args.val_batch_count):
                    v = (val_data_sampler_from.get(i) +
                         val_data_sampler_ques.get(i) +
                         enc.encode('. '))  # + val_data_sampler_to.get(i)
                    # v += [enc.encode(' ')[0] for _ in range(HIDDEN_SIZE - len(v))]
                    if len(v) >= HIDDEN_SIZE - GENERATE_SIZE:
                        continue
                    val_batches.append(v)
                acc_total = validation_by_sample()
                acc = acc_total / len(val_batches) * 100
                print(acc, 'test accuracy')
                save_summary('Accuracy with test set ' + str(acc) + '\n')
                exit()

            while counter != args.stop_after:
                #model_summary()
                if counter % args.save_every == 0:
                    save()
                if counter % args.sample_every == 0:
                    #generate_samples()
                    pass
                if args.val_every > 0 and (counter % args.val_every == 0):  # or counter == 1):
                    acc_total = validation_by_sample()
                    acc = acc_total / len(val_batches) * 100
                    acc_over_time.append(acc)
                    if avg_loss[1] > 0.0:
                        loss_avg_over_time.append(avg_loss[0] / avg_loss[1])
                    else:
                        loss_avg_over_time.append(0)

                counter_in = counter % len(val_batches)
                if args.accumulate_gradients > 1:
                    sess.run(opt_reset)
                    for _ in range(args.accumulate_gradients):
                        sess.run(opt_compute,
                                 feed_dict={context: sample_batch(counter_in)})
                    (v_loss, v_summary) = sess.run((opt_apply, summary_loss))
                else:
                    (_, v_loss, v_summary) = sess.run(
                        (opt_apply, loss, summary_loss),
                        feed_dict={context: sample_batch(counter_in)})

                summary_log.add_summary(v_summary, counter)

                #if True:
                if not np.isnan(avg_loss[0]) and not np.isnan(avg_loss[1]):
                    avg_loss = (avg_loss[0] * 0.99 + v_loss,
                                avg_loss[1] * 0.99 + 1.0)

                if counter % args.val_every == 1:
                    if float(acc) == 100.0:
                        #save()
                        print('validation accuracy 100', time.time() - start_time)
                        count_success += 1
                        count_success_with_skips += 1
                        if count_success >= 2 or count_success_with_skips >= 4:
                            #save_summary()
                            exit()
                    else:
                        count_success = 0

                print_status(acc_total_in=acc_total, size=len(val_batches),
                             v_loss_in=v_loss, shorten=True)
                counter += 1
        except KeyboardInterrupt:
            print('interrupted')
        finally:
            save()
            save_summary()
            print('save weights/summary and exit.')
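
# ---------------------------------------------------------------------------
# validation_by_sample() above decodes greedily: tf_sample_val is built with
# length=1 and top_k=1, so each sess.run() appends one argmax token, and the
# loop feeds the output back in as the next context. The same pattern in
# isolation, with a hypothetical step_fn standing in for
# sess.run(tf_sample_val, feed_dict={val_context: context_tokens}):
# ---------------------------------------------------------------------------
def greedy_decode(context_tokens, step_fn, num_steps):
    # step_fn(tokens) -> tokens extended by exactly one greedily chosen token.
    for _ in range(num_steps):
        context_tokens = step_fn(context_tokens)
    return context_tokens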
def finetune(sess,
             dataset,
             steps=-1,
             model_name='124M',
             model_dir='models',
             combine=50000,
             batch_size=1,
             learning_rate=0.0001,
             accumulate_gradients=5,
             restore_from='latest',
             run_name='run1',
             checkpoint_dir='checkpoint',
             sample_every=100,
             sample_length=1023,
             sample_num=1,
             multi_gpu=False,
             save_every=1000,
             print_every=1,
             max_checkpoints=1,
             use_memory_saving_gradients=False,
             only_train_transformer_layers=False,
             optimizer='adam',
             overwrite=False,
             val_dataset=None,
             val_batch_size=2,
             val_batch_count=40,
             val_every=0):
    """Finetunes the model on the given dataset.

    Adapted from https://github.com/nshepperd/gpt-2/blob/finetuning/train.py.
    See that file for parameter definitions.
    """

    # assert model_name not in ['774M', '1558M'] or multi_gpu, "Currently, a modern single GPU cannot finetune the 774M GPT-2 model or larger."

    SAMPLE_DIR = 'samples'

    checkpoint_path = os.path.join(checkpoint_dir, run_name)

    def maketree(path):
        try:
            os.makedirs(path)
        except:
            pass

    maketree(checkpoint_path)
    files = [f for f in os.listdir(checkpoint_path)]
    for file in ['hparams.json', 'encoder.json', 'vocab.bpe']:
        try:
            shutil.copyfile(os.path.join(model_dir, model_name, file),
                            os.path.join(checkpoint_path, file))
        except FileNotFoundError:
            print("You need to download the GPT-2 model first via download_gpt2()")
            raise

    enc = encoder.get_encoder(checkpoint_path)
    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if sample_length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" %
                         hparams.n_ctx)

    if model_name not in ['117M', '124M']:
        use_memory_saving_gradients = True
        only_train_transformer_layers = True
        accumulate_gradients = 1

    context = tf.compat.v1.placeholder(tf.int32, [batch_size, None])
    gpus = []

    if multi_gpu:
        gpus = get_available_gpus()

    output = model.model(hparams=hparams, X=context, gpus=gpus)
    loss = tf.reduce_mean(
        input_tensor=tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=context[:, 1:], logits=output['logits'][:, :-1]))

    # validation code
    if val_every > 0:
        val_context = tf.placeholder(tf.int32, [val_batch_size, None])
        val_output = model.model(hparams=hparams, X=val_context,
                                 reuse=True)  # added reuse=True
        val_loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=val_context[:, 1:],
                logits=val_output['logits'][:, :-1]))
        val_loss_summary = tf.summary.scalar('val_loss', val_loss)

    tf_sample = sample.sample_sequence(
        hparams=hparams,
        length=sample_length,
        context=context,
        batch_size=batch_size,
        temperature=1.0,
        top_k=40)

    all_vars = [v for v in tf.compat.v1.trainable_variables()
                if 'model' in v.name]
    train_vars = [v for v in all_vars if '/h' in v.name] \
        if only_train_transformer_layers else all_vars

    if optimizer == 'adam':
        opt = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
    elif optimizer == 'sgd':
        opt = tf.compat.v1.train.GradientDescentOptimizer(
            learning_rate=learning_rate)

    if accumulate_gradients > 1:
        if use_memory_saving_gradients:
            exit("Memory saving gradients are not implemented for gradient accumulation yet.")
        opt = AccumulatingOptimizer(opt=opt, var_list=train_vars)
        opt_reset = opt.reset()
        opt_compute = opt.compute_gradients(loss)
        opt_apply = opt.apply_gradients()
        summary_loss = tf.compat.v1.summary.scalar('loss', opt_apply)
    else:
        if use_memory_saving_gradients:
            opt_grads = memory_saving_gradients.gradients(loss, train_vars)
        else:
            opt_grads = tf.gradients(ys=loss, xs=train_vars)
        opt_grads = list(zip(opt_grads, train_vars))
        opt_apply = opt.apply_gradients(opt_grads)
        summary_loss = tf.compat.v1.summary.scalar('loss', loss)

    summary_log = tf.compat.v1.summary.FileWriter(checkpoint_path)

    saver = tf.compat.v1.train.Saver(var_list=all_vars,
                                     max_to_keep=max_checkpoints)
    sess.run(tf.compat.v1.global_variables_initializer())

    if restore_from == 'latest':
        ckpt = tf.train.latest_checkpoint(checkpoint_path)
        if ckpt is None:
            # Get fresh GPT weights if new run.
            ckpt = tf.train.latest_checkpoint(
                os.path.join(model_dir, model_name))
    elif restore_from == 'fresh':
        ckpt = tf.train.latest_checkpoint(os.path.join(model_dir, model_name))
    else:
        ckpt = tf.train.latest_checkpoint(restore_from)
    print('Loading checkpoint', ckpt)
    saver.restore(sess, ckpt)

    print('Loading dataset...')
    chunks = load_dataset(enc, dataset, combine)
    data_sampler = Sampler(chunks)

    # validation code
    if val_every > 0:
        if val_dataset:
            val_chunks = load_dataset(enc, val_dataset, combine)
        else:
            val_chunks = chunks
    print('dataset has', data_sampler.total_size, 'tokens')
    print('Training...')

    # validation code
    if val_every > 0:
        # Sample from validation set once with fixed seed to make
        # it deterministic during training as well as across runs.
        val_data_sampler = Sampler(val_chunks, seed=1)
        val_batches = [[val_data_sampler.sample(1024)
                        for _ in range(val_batch_size)]
                       for _ in range(val_batch_count)]

    counter = 1
    counter_path = os.path.join(checkpoint_path, 'counter')
    if os.path.exists(counter_path) and restore_from == 'latest':
        # Load the step number if we're resuming a run
        # Add 1 so we don't immediately try to save again
        with open(counter_path, 'r') as fp:
            counter = int(fp.read()) + 1
    counter_base = counter

    def save():
        maketree(checkpoint_path)
        print('Saving',
              os.path.join(checkpoint_path, 'model-{}').format(counter - 1))
        saver.save(sess,
                   os.path.join(checkpoint_path, 'model'),
                   global_step=counter - 1)
        with open(counter_path, 'w') as fp:
            fp.write(str(counter - 1) + '\n')

    def generate_samples():
        context_tokens = data_sampler.sample(1)
        all_text = []
        index = 0
        while index < sample_num:
            out = sess.run(tf_sample,
                           feed_dict={context: batch_size * [context_tokens]})
            for i in range(min(sample_num - index, batch_size)):
                text = enc.decode(out[i])
                text = '======== SAMPLE {} ========\n{}\n'.format(
                    index + 1, text)
                all_text.append(text)
                index += 1
            print(text)
        maketree(os.path.join(SAMPLE_DIR, run_name))
        with open(
                os.path.join(SAMPLE_DIR, run_name,
                             'samples-{}').format(counter), 'w') as fp:
            fp.write('\n'.join(all_text))

    # validation code
    def validation():
        print('Calculating validation loss...')
        losses = []
        for batch in tqdm(val_batches):
            losses.append(sess.run(val_loss, feed_dict={val_context: batch}))
        v_val_loss = np.mean(losses)
        v_summary = sess.run(val_loss_summary,
                             feed_dict={val_loss: v_val_loss})
        summary_log.add_summary(v_summary, counter)
        summary_log.flush()
        print('[{counter} | {time:2.2f}] validation loss = {loss:2.2f}'.format(
            counter=counter, time=time.time() - start_time, loss=v_val_loss))
        return v_val_loss

    def sample_batch():
        return [data_sampler.sample(1024) for _ in range(batch_size)]

    if overwrite and restore_from == 'latest':
        for file in files:
            if file.startswith('model') or file.startswith('events'):
                os.remove(os.path.join(checkpoint_path, file))
        save()

    avg_loss = (0.0, 0.0)
    start_time = time.time()

    # Trying out a change to finetune that saves only when validation loss
    # decreases.
    if steps:
        steps = int(steps)

    try:
        while True:
            if steps > 0 and counter == (counter_base + steps):
                #save()
                return
            # if (counter - 1) % save_every == 0 and counter > 1:
            #     save()
            if (counter - 1) % sample_every == 0 and counter > 1:
                generate_samples()
            # validation code
            if val_every > 0 and counter == 1:
                v_val_loss = validation()
                save()
            elif val_every > 0 and counter == counter_base:
                v_val_loss = validation()
            elif val_every > 0 and (counter % val_every == 0):
                new_v_val_loss = validation()
                if new_v_val_loss < v_val_loss:
                    # Checkpoint only on validation improvement.
                    v_val_loss = new_v_val_loss
                    save()

            if accumulate_gradients > 1:
                sess.run(opt_reset)
                for _ in range(accumulate_gradients):
                    sess.run(opt_compute, feed_dict={context: sample_batch()})
                (v_loss, v_summary) = sess.run((opt_apply, summary_loss))
            else:
                (_, v_loss, v_summary) = sess.run(
                    (opt_apply, loss, summary_loss),
                    feed_dict={context: sample_batch()})

            summary_log.add_summary(v_summary, counter)

            if (counter % print_every == 0) or counter == 1:
                avg_loss = (avg_loss[0] * 0.99 + v_loss,
                            avg_loss[1] * 0.99 + 1.0)
                print(
                    '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                    .format(counter=counter,
                            time=time.time() - start_time,
                            loss=v_loss,
                            avg=avg_loss[0] / avg_loss[1]))

            counter += 1
    except KeyboardInterrupt:
        print('interrupted')
        save()
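
# ---------------------------------------------------------------------------
# The loop above checkpoints only when validation loss improves. The same
# gating pattern in isolation; validate() and save_checkpoint() are
# hypothetical helpers, not functions from this file:
# ---------------------------------------------------------------------------
def train_with_best_checkpointing(validate, save_checkpoint, val_every, max_steps):
    best_val_loss = float('inf')
    for step in range(1, max_steps + 1):
        # ... one optimization step would go here ...
        if val_every > 0 and step % val_every == 0:
            val_loss = validate()
            if val_loss < best_val_loss:
                # Only persist weights that beat the best validation loss so far.
                best_val_loss = val_loss
                save_checkpoint()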