def main(): args = parser.parse_args() enc = encoder.get_encoder(args.model_name) hparams = model.default_hparams() with open(os.path.join('models', args.model_name, 'hparams.json')) as f: hparams.override_from_dict(json.load(f)) if args.sample_length > hparams.n_ctx: raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx) if args.model_name == '355M': args.memory_saving_gradients = True if args.optimizer == 'adam': args.only_train_transformer_layers = True config = tf.ConfigProto() config.gpu_options.allow_growth = True config.graph_options.rewrite_options.layout_optimizer = rewriter_config_pb2.RewriterConfig.OFF with tf.Session(config=config) as sess: context = tf.placeholder(tf.int32, [args.batch_size, None]) context_in = randomize(context, hparams, args.noise) output = model.model(hparams=hparams, X=context_in) loss = tf.reduce_mean( tf.nn.sparse_softmax_cross_entropy_with_logits( labels=context[:, 1:], logits=output['logits'][:, :-1])) if args.val_every > 0: val_context = tf.placeholder(tf.int32, [args.val_batch_size, None]) val_output = model.model(hparams=hparams, X=val_context) val_loss = tf.reduce_mean( tf.nn.sparse_softmax_cross_entropy_with_logits( labels=val_context[:, 1:], logits=val_output['logits'][:, :-1])) val_loss_summary = tf.summary.scalar('val_loss', val_loss) tf_sample = sample.sample_sequence(hparams=hparams, length=args.sample_length, context=context, batch_size=args.batch_size, temperature=1.0, top_k=args.top_k, top_p=args.top_p) all_vars = [v for v in tf.trainable_variables() if 'model' in v.name] train_vars = [v for v in all_vars if '/h' in v.name ] if args.only_train_transformer_layers else all_vars if args.optimizer == 'adam': opt = tf.train.AdamOptimizer(learning_rate=args.learning_rate) elif args.optimizer == 'sgd': opt = tf.train.GradientDescentOptimizer( learning_rate=args.learning_rate) else: exit('Bad optimizer:', args.optimizer) if args.accumulate_gradients > 1: if args.memory_saving_gradients: exit( "Memory saving gradients are not implemented for gradient accumulation yet." ) opt = AccumulatingOptimizer(opt=opt, var_list=train_vars) opt_reset = opt.reset() opt_compute = opt.compute_gradients(loss) opt_apply = opt.apply_gradients() summary_loss = tf.summary.scalar('loss', opt_apply) else: if args.memory_saving_gradients: opt_grads = memory_saving_gradients.gradients(loss, train_vars) else: opt_grads = tf.gradients(loss, train_vars) opt_grads = list(zip(opt_grads, train_vars)) opt_apply = opt.apply_gradients(opt_grads) summary_loss = tf.summary.scalar('loss', loss) summary_lr = tf.summary.scalar('learning_rate', args.learning_rate) summaries = tf.summary.merge([summary_lr, summary_loss]) summary_log = tf.summary.FileWriter( os.path.join(CHECKPOINT_DIR, args.run_name)) saver = tf.train.Saver(var_list=all_vars, max_to_keep=5, keep_checkpoint_every_n_hours=2) sess.run(tf.global_variables_initializer()) if args.restore_from == 'latest': ckpt = tf.train.latest_checkpoint( os.path.join(CHECKPOINT_DIR, args.run_name)) if ckpt is None: # Get fresh GPT weights if new run. 
ckpt = tf.train.latest_checkpoint( os.path.join('models', args.model_name)) elif args.restore_from == 'fresh': ckpt = tf.train.latest_checkpoint( os.path.join('models', args.model_name)) else: ckpt = tf.train.latest_checkpoint(args.restore_from) print('Loading checkpoint', ckpt) saver.restore(sess, ckpt) print('Loading dataset...') chunks = load_dataset(enc, args.dataset, args.combine, encoding=args.encoding) data_sampler = Sampler(chunks) if args.val_every > 0: if args.val_dataset: val_chunks = load_dataset(enc, args.val_dataset, args.combine, encoding=args.encoding) else: val_chunks = chunks print('dataset has', data_sampler.total_size, 'tokens') print('Training...') if args.val_every > 0: # Sample from validation set once with fixed seed to make # it deterministic during training as well as across runs. val_data_sampler = Sampler(val_chunks, seed=1) val_batches = [[ val_data_sampler.sample(1024) for _ in range(args.val_batch_size) ] for _ in range(args.val_batch_count)] counter = 1 counter_path = os.path.join(CHECKPOINT_DIR, args.run_name, 'counter') if os.path.exists(counter_path): # Load the step number if we're resuming a run # Add 1 so we don't immediately try to save again with open(counter_path, 'r') as fp: counter = int(fp.read()) + 1 def save(): maketree(os.path.join(CHECKPOINT_DIR, args.run_name)) print( 'Saving', os.path.join(CHECKPOINT_DIR, args.run_name, 'model-{}').format(counter)) saver.save(sess, os.path.join(CHECKPOINT_DIR, args.run_name, 'model'), global_step=counter) with open(counter_path, 'w') as fp: fp.write(str(counter) + '\n') def generate_samples(): print('Generating samples...') context_tokens = data_sampler.sample(1) all_text = [] index = 0 while index < args.sample_num: out = sess.run( tf_sample, feed_dict={context: args.batch_size * [context_tokens]}) for i in range(min(args.sample_num - index, args.batch_size)): text = enc.decode(out[i]) text = '======== SAMPLE {} ========\n{}\n'.format( index + 1, text) all_text.append(text) index += 1 print(text.encode('utf8')) maketree(os.path.join(SAMPLE_DIR, args.run_name)) with open(os.path.join(SAMPLE_DIR, args.run_name, 'samples-{}').format(counter), 'w', encoding=args.encoding) as fp: fp.write('\n'.join(all_text)) def validation(): print('Calculating validation loss...') losses = [] for batch in tqdm.tqdm(val_batches): losses.append( sess.run(val_loss, feed_dict={val_context: batch})) v_val_loss = np.mean(losses) v_summary = sess.run(val_loss_summary, feed_dict={val_loss: v_val_loss}) summary_log.add_summary(v_summary, counter) summary_log.flush() print('[{counter} | {time:2.2f}] validation loss = {loss:2.2f}'. 
format(counter=counter, time=time.time() - start_time, loss=v_val_loss)) def sample_batch(): return [data_sampler.sample(1024) for _ in range(args.batch_size)] avg_loss = (0.0, 0.0) start_time = time.time() try: while True: if counter % args.save_every == 0: save() if counter % args.sample_every == 0: generate_samples() if args.val_every > 0 and (counter % args.val_every == 0 or counter == 1): validation() if args.accumulate_gradients > 1: sess.run(opt_reset) for _ in range(args.accumulate_gradients): sess.run(opt_compute, feed_dict={context: sample_batch()}) (v_loss, v_summary) = sess.run((opt_apply, summaries)) else: (_, v_loss, v_summary) = sess.run( (opt_apply, loss, summaries), feed_dict={context: sample_batch()}) summary_log.add_summary(v_summary, counter) avg_loss = (avg_loss[0] * 0.99 + v_loss, avg_loss[1] * 0.99 + 1.0) print( '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}' .format(counter=counter, time=time.time() - start_time, loss=v_loss, avg=avg_loss[0] / avg_loss[1])) counter += 1 except KeyboardInterrupt: print('interrupted') save()
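# The main() above drives gradient accumulation through three ops built from
# AccumulatingOptimizer: opt.reset(), opt.compute_gradients(loss) and
# opt.apply_gradients(). A minimal TF 1.x sketch of that pattern follows. This
# is NOT the repo's AccumulatingOptimizer class; the accumulator variables and
# the averaging by step count are illustrative assumptions, and every variable
# is assumed to receive a gradient (no None handling).
import tensorflow as tf


class GradientAccumulator:
    """Accumulate gradients over several mini-batches, then apply them once."""

    def __init__(self, opt, var_list):
        self.opt = opt
        self.var_list = var_list
        # One non-trainable accumulator per trainable variable.
        self.accum = [tf.Variable(tf.zeros_like(v.initialized_value()),
                                  trainable=False) for v in var_list]
        self.count = tf.Variable(0.0, trainable=False)

    def reset(self):
        # Zero the accumulators at the start of each accumulation cycle.
        ops = [a.assign(tf.zeros_like(a)) for a in self.accum]
        return tf.group(*(ops + [self.count.assign(0.0)]))

    def compute_gradients(self, loss):
        # Add this mini-batch's gradients into the accumulators.
        grads = tf.gradients(loss, self.var_list)
        ops = [a.assign_add(g) for a, g in zip(self.accum, grads)]
        return tf.group(*(ops + [self.count.assign_add(1.0)]))

    def apply_gradients(self):
        # Apply the averaged gradients with the wrapped optimizer.
        avg = [(a / self.count, v) for a, v in zip(self.accum, self.var_list)]
        return self.opt.apply_gradients(avg)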
def train_main(dataset, model_name='1250M', seed=None, msg=True, batch_size=16, learning_rate=0.00002, sample_length=512, sample_num=1, sample_every=100, run_name='run1', restore_from='latest', save_every=1000, combine=50000): enc = encoder.get_encoder(model_name) hparams = model.default_hparams() with open(os.path.join('models', model_name, 'hparams.json')) as f: hparams.override_from_dict(json.load(f)) print('n_ctx: ', hparams.n_ctx, 'n_head: ', hparams.n_head, 'n_embd: ', hparams.n_embd, 'n_layer: ', hparams.n_layer) if sample_length is None: sample_length = hparams.n_ctx elif sample_length > hparams.n_ctx: raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx) # TF config config = tf.ConfigProto() #device_map = { 0:2, 0:3, 1:2, 1:3 } #config.gpu_options.visible_device_list = str(device_map[hvd.rank()]) config.gpu_options.visible_device_list = str(hvd.local_rank()) config.gpu_options.allow_growth = True global_step = tf.Variable(0, trainable=False) with tf.Session(config=config) as sess: context = tf.placeholder(tf.int32, [batch_size, None]) np.random.seed(seed) tf.set_random_seed(seed) output = model.model(hparams=hparams, X=context) loss = tf.reduce_mean( tf.nn.sparse_softmax_cross_entropy_with_logits( labels=context[:, 1:], logits=output['logits'][:, :-1])) tf_sample = sample.sample_sequence(hparams=hparams, length=sample_length, context=context, batch_size=batch_size, temperature=0.9, top_k=40) #global_step = tf.Variable(0, trainable=False) counter = 1 train_vars = [v for v in tf.trainable_variables() if 'model' in v.name] #opt = tf.train.AdamOptimizer(learning_rate=learning_rate) # l4rz 11/10/2019 decayed_lr = tf.train.exponential_decay(learning_rate, global_step, 200, 0.999, staircase=True) opt = tf.train.AdamOptimizer(decayed_lr) #opt = tf.train.GradientDescentOptimizer(decayed_lr) opt = hvd.DistributedOptimizer(opt) # this is original horovod #train_op = opt.minimize(loss, var_list=train_vars) # this is ours if (msg): print('Using memory saving gradients') opt_grads = memory_saving_gradients.gradients(loss, train_vars) opt_grads = list(zip(opt_grads, train_vars)) train_op = opt.apply_gradients(opt_grads, global_step=global_step) else: print('Not using memory saving gradients') #train_op = opt.minimize(loss, var_list=train_vars) # l4rz 11/10 train_op = opt.minimize(loss, var_list=train_vars, global_step=global_step) # [1,2]<stderr>:TypeError: apply_gradients() missing 1 required positional argument: 'grads_and_vars' #summary_loss = tf.summary.scalar('loss', train_op) #_, lv = sess.run((train_op, loss), feed_dict={context: batch}) # Horovod: broadcast initial variable states from rank 0 to all other processes. # This is necessary to ensure consistent initialization of all workers when # training is started with random weights or restored from a checkpoint. print('Running hvd.broadcast_global_variables') bcast = hvd.broadcast_global_variables(0) print('Done') saver = tf.train.Saver(var_list=train_vars, max_to_keep=5, keep_checkpoint_every_n_hours=2) print('Running global_variables_initializer') sess.run(tf.global_variables_initializer()) print('Done') if restore_from == 'latest': ckpt = tf.train.latest_checkpoint( os.path.join(CHECKPOINT_DIR, run_name)) if ckpt is None: # Get fresh GPT weights if new run. 
ckpt = tf.train.latest_checkpoint( os.path.join('models', model_name)) elif restore_from == 'fresh': ckpt = tf.train.latest_checkpoint( os.path.join('models', model_name)) # comment out when running for 1st time else: ckpt = tf.train.latest_checkpoint(restore_from) print(str(hvd.local_rank()), 'Loading checkpoint', ckpt) saver.restore(sess, ckpt) # uncomment when running for first time INIT THE MODEL #print('tf.global_variables_initializer()') #sess.run(tf.global_variables_initializer()) bcast.run() print(str(hvd.local_rank()), 'Loading dataset...') chunks = load_dataset(enc, dataset, combine) data_sampler = Sampler(chunks) print(str(hvd.local_rank()), 'dataset has', data_sampler.total_size, 'tokens') print(str(hvd.local_rank()), 'Training...') counter = 1 if os.path.exists(os.path.join(CHECKPOINT_DIR, run_name, 'counter')): # Load the step number if we're resuming a run # Add 1 so we don't immediately try to save again with open(os.path.join(CHECKPOINT_DIR, run_name, 'counter'), 'r') as fp: counter = int(fp.read()) + 1 def save(): maketree(os.path.join(CHECKPOINT_DIR, run_name)) print( 'Saving', os.path.join(CHECKPOINT_DIR, run_name, 'model-{}').format(counter)) saver.save(sess, os.path.join(CHECKPOINT_DIR, run_name, 'model'), global_step=counter) with open(os.path.join(CHECKPOINT_DIR, run_name, 'counter'), 'w') as fp: fp.write(str(counter) + '\n') def generate_samples(): context_tokens = data_sampler.sample(1) all_text = [] index = 0 while index < sample_num: out = sess.run( tf_sample, feed_dict={context: batch_size * [context_tokens]}) for i in range(min(sample_num - index, batch_size)): text = enc.decode(out[i]) text = '======== SAMPLE {} ========\n{}\n'.format( index + 1, text) all_text.append(text) index += 1 print(text) maketree(os.path.join(SAMPLE_DIR, run_name)) with open( os.path.join(SAMPLE_DIR, run_name, 'samples-{}').format(counter), 'w') as fp: fp.write('\n'.join(all_text)) avg_loss = (0.0, 0.0) start_time = time.time() try: while True: batch = [data_sampler.sample(1024) for _ in range(batch_size)] _, lv = sess.run((train_op, loss), feed_dict={context: batch}) avg_loss = (avg_loss[0] * 0.99 + lv, avg_loss[1] * 0.99 + 1.0) if hvd.rank() == 0: if counter % save_every == 0: save() if counter % sample_every == 0: generate_samples() print( '[{counter} | {time:2.2f}] loss={loss:2.4f} avg={avg:2.4f} lr={lr:.2e}' .format(counter=counter, time=time.time() - start_time, loss=lv, avg=avg_loss[0] / avg_loss[1], lr=decayed_lr.eval())) counter += 1 except KeyboardInterrupt: print('interrupted') if hvd.rank() == 0: save()
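# train_main() above decays the Adam learning rate with
# tf.train.exponential_decay(learning_rate, global_step, 200, 0.999,
# staircase=True). With staircase=True that schedule is simply
# lr * 0.999 ** (step // 200). A plain-Python replica, using the 2e-5 default
# from the signature above, so the decay can be inspected without building a
# graph (the sample steps printed below are arbitrary):
def staircase_lr(step, base_lr=2e-5, decay_steps=200, decay_rate=0.999):
    """Learning rate after `step` optimizer updates under the staircase schedule."""
    return base_lr * decay_rate ** (step // decay_steps)


for s in (0, 200, 2000, 20000, 200000):
    print(s, staircase_lr(s))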
def main(): args = parser.parse_args() try: logdir = os.path.join(CHECKPOINT_DIR, args.run_name) with open('logdir.txt', 'w') as z: z.write(logdir) except: pass enc = get_encoder(model_name) hparams = model.default_hparams() with open(os.path.join(model_name, 'hparams.json')) as f: hparams.override_from_dict(json.load(f)) if args.sample_length > hparams.n_ctx: raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx) if args.model_name == '345M': args.memory_saving_gradients = True args.only_train_transformer_layers = True config = tf.ConfigProto() config.gpu_options.allow_growth = True config.graph_options.rewrite_options.layout_optimizer = rewriter_config_pb2.RewriterConfig.OFF with tf.Session(config=config) as sess: context = tf.placeholder(tf.int32, [args.batch_size, None]) output = model.model(hparams=hparams, X=context) loss = tf.reduce_mean( tf.nn.sparse_softmax_cross_entropy_with_logits( labels=context[:, 1:], logits=output['logits'][:, :-1])) if args.val_every > 0: val_context = tf.placeholder(tf.int32, [args.val_batch_size, None]) val_output = model.model(hparams=hparams, X=val_context) val_loss = tf.reduce_mean( tf.nn.sparse_softmax_cross_entropy_with_logits( labels=val_context[:, 1:], logits=val_output['logits'][:, :-1])) val_loss_summary = tf.summary.scalar('val_loss', val_loss) tf_sample = sample.sample_sequence(hparams=hparams, length=args.sample_length, context=context, batch_size=args.batch_size, temperature=1.0, top_k=40) all_vars = [v for v in tf.trainable_variables() if 'model' in v.name] train_vars = [v for v in all_vars if '/h' in v.name ] if args.only_train_transformer_layers else all_vars if args.accumulate_gradients > 1: if args.memory_saving_gradients: exit( "Memory saving gradients are not implemented for gradient accumulation yet." ) opt = AccumulatingOptimizer( opt=tf.train.AdamOptimizer(learning_rate=args.learning_rate), var_list=train_vars) opt_reset = opt.reset() opt_compute = opt.compute_gradients(loss) opt_apply = opt.apply_gradients() summary_loss = tf.summary.scalar('loss', opt_apply) else: opt = tf.train.AdamOptimizer(learning_rate=args.learning_rate) if args.memory_saving_gradients: opt_grads = memory_saving_gradients.gradients(loss, train_vars) else: opt_grads = tf.gradients(loss, train_vars) opt_grads = list(zip(opt_grads, train_vars)) opt_apply = opt.apply_gradients(opt_grads) summary_loss = tf.summary.scalar('loss', loss) summary_log = tf.summary.FileWriter( os.path.join(CHECKPOINT_DIR, args.run_name)) saver = tf.train.Saver(var_list=all_vars, max_to_keep=5, keep_checkpoint_every_n_hours=2) sess.run(tf.global_variables_initializer()) if args.restore_from == 'latest': ckpt = tf.train.latest_checkpoint( os.path.join(CHECKPOINT_DIR, args.run_name)) if ckpt is None: # Get fresh GPT weights if new run. 
ckpt = tf.train.latest_checkpoint(os.path.join(model_name)) elif args.restore_from == 'fresh': ckpt = tf.train.latest_checkpoint(os.path.join(model_name)) else: ckpt = tf.train.latest_checkpoint(args.restore_from) print('Loading checkpoint', ckpt) saver.restore(sess, ckpt) #print('Loading dataset...') #chunks = load_dataset(enc, args.dataset, args.combine) #data_sampler = Sampler(chunks) print('Loading train dataset...') from_name, ques_name, to_name = name_parts(args.dataset) trn_chunks_from = load_dataset( enc, from_name, args.combine) #if args.dataset else chunks #trn_chunks_ques = load_dataset(enc, ques_name, args.combine) if args.dataset else chunks trn_chunks_to = load_dataset( enc, to_name, args.combine) #if args.dataset else chunks skip_delimeter = True char = '\t' trn_data_sampler_from = SamplerVal(trn_chunks_from, enc, char=char, skip_delimeter=skip_delimeter) #trn_data_sampler_ques = SamplerVal(trn_chunks_ques, enc, char=char, skip_delimeter=skip_delimeter) trn_data_sampler_to = SamplerVal(trn_chunks_to, enc, char=char, skip_delimeter=skip_delimeter) len_v = 0 data_sampler = [] for i in range(trn_data_sampler_from.total_size): v = ( #enc.encode('\nQ: ') + trn_data_sampler_from.get(i) + #enc.encode('. \nA: ') + trn_data_sampler_to.get(i) # + #enc.encode('. ') ) v = v[:HIDDEN_SIZE - 1] len_v += len(v) #data_sampler.extend(v) ## data_sampler.append(v) pass if len_v < HIDDEN_SIZE: mult = HIDDEN_SIZE // len_v + 1 for i in range(mult): x = data_sampler[:] data_sampler.extend(x) data_sampler = Sampler([np.array(data_sampler)]) #if not args.train_special and len_v >= HIDDEN_SIZE: # data_sampler = Sampler([np.array(data_sampler)]) if args.val_every > 0 and False: val_chunks = load_dataset( enc, args.val_dataset, args.combine) if args.val_dataset else chunks if not isinstance(data_sampler, list): print('dataset has', data_sampler.total_size, 'tokens') print('Training...') if args.val_every > 0: # Sample from validation set once with fixed seed to make # it deterministic during training as well as across runs. 
val_data_sampler = Sampler(val_chunks, seed=1) val_batches = [[ val_data_sampler.sample(1024) for _ in range(args.val_batch_size) ] for _ in range(args.val_batch_count)] counter = 1 counter_path = os.path.join(CHECKPOINT_DIR, args.run_name, 'counter') if os.path.exists(counter_path): # Load the step number if we're resuming a run # Add 1 so we don't immediately try to save again with open(counter_path, 'r') as fp: counter = int(fp.read()) + 1 def save(): maketree(os.path.join(CHECKPOINT_DIR, args.run_name)) print( 'Saving', os.path.join(CHECKPOINT_DIR, args.run_name, 'model-{}').format(counter)) saver.save(sess, os.path.join(CHECKPOINT_DIR, args.run_name, 'model'), global_step=counter) with open(counter_path, 'w') as fp: fp.write(str(counter) + '\n') #print(model_name, 'mn') GPT2_DIR_X = model_name cd = CHECKPOINT_DIR + "/" + args.run_name if not os.path.isfile(cd + '/' + 'encoder.json'): os.system("cp " + GPT2_DIR_X + '/' + 'encoder.json ' + cd + '/.') os.system('cp ' + GPT2_DIR_X + "/" + 'vocab.bpe ' + cd + '/.') def generate_samples(): print('Generating samples...') #context_tokens = data_sampler.sample(1) #context_tokens = data_sampler[0] context_tokens = trn_data_sampler_from.get( random.randint(0, trn_data_sampler_from.total_size)) #print(enc.decode(context_tokens), len(context_tokens)) #print(args.batch_size * [context_tokens]) all_text = [] index = 0 while index < args.sample_num: out = sess.run( tf_sample, feed_dict={context: args.batch_size * [context_tokens]}) for i in range(min(args.sample_num - index, args.batch_size)): text = enc.decode(out[i]) text = '======== SAMPLE {} ========\n{}\n'.format( index + 1, text) all_text.append(text) index += 1 print(text) maketree(os.path.join(SAMPLE_DIR, args.run_name)) with open( os.path.join(SAMPLE_DIR, args.run_name, 'samples-{}').format(counter), 'w') as fp: fp.write('\n'.join(all_text)) def validation(): print('Calculating validation loss...') losses = [] for batch in tqdm.tqdm(val_batches): losses.append( sess.run(val_loss, feed_dict={val_context: batch})) v_val_loss = np.mean(losses) v_summary = sess.run(val_loss_summary, feed_dict={val_loss: v_val_loss}) summary_log.add_summary(v_summary, counter) summary_log.flush() print('[{counter} | {time:2.2f}] validation loss = {loss:2.2f}'. 
format(counter=counter, time=time.time() - start_time, loss=v_val_loss)) def sample_batch(): #z = [data_sampler.sample(1024) for _ in range(args.batch_size)] #print(len(data_sampler)) #print(len(data_sampler[0])) z = [data_sampler[random.randint(0, args.batch_size)]] #print(enc.decode(z[0])) #print(z[1],'\n1' ,z[2],'\n2' ,z[3] ,len(data_sampler[0])) #exit() return z avg_loss = (0.0, 0.0) start_time = time.time() try: while counter != args.stop_after: if counter % args.save_every == 0: save() if counter % args.sample_every == 0: generate_samples() pass if args.val_every > 0 and (counter % args.val_every == 0 or counter == 1): validation() if args.accumulate_gradients > 1: sess.run(opt_reset) for _ in range(args.accumulate_gradients): sess.run(opt_compute, feed_dict={context: sample_batch()}) (v_loss, v_summary) = sess.run((opt_apply, summary_loss)) else: (_, v_loss, v_summary) = sess.run( (opt_apply, loss, summary_loss), feed_dict={context: sample_batch()}) summary_log.add_summary(v_summary, counter) avg_loss = (avg_loss[0] * 0.99 + v_loss, avg_loss[1] * 0.99 + 1.0) print( '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}' .format(counter=counter, time=time.time() - start_time, loss=v_loss, avg=avg_loss[0] / avg_loss[1])) counter += 1 except KeyboardInterrupt: print('\ninterrupted') finally: save()
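# The main() above pairs a "from" file with a "to" file, concatenates each
# question/answer token pair and clips it to the context window
# (v = v[:HIDDEN_SIZE - 1]) before handing the result to Sampler. A hedged
# pure-Python sketch of that packing step; pack_pairs(), the 1024-token window
# and the toy token ids are illustrative assumptions, not values taken from the
# script (HIDDEN_SIZE itself is defined elsewhere):
def pack_pairs(from_token_lists, to_token_lists, window=1024):
    """Concatenate each (prompt, reply) token pair and clip to the window."""
    return [(list(src) + list(dst))[:window - 1]
            for src, dst in zip(from_token_lists, to_token_lists)]


assert pack_pairs([[10, 11, 12]], [[20, 21]]) == [[10, 11, 12, 20, 21]]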
def train_main(dataset,
               model_name='117M',
               seed=None,
               batch_size=2,
               sample_length=1023,
               sample_num=1,
               sample_every=4500,
               run_name='run1',
               restore_from='latest',
               save_every=2000,
               combine=50000):
    enc = encoder.get_encoder(model_name)
    hparams = model.default_hparams()
    with open(
            os.path.join('chatbot_model', 'trained_models', model_name,
                         'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if sample_length is None:
        sample_length = hparams.n_ctx // 2
    elif sample_length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" %
                         hparams.n_ctx)

    # TF config
    config = tf.ConfigProto()
    config.gpu_options.visible_device_list = str(hvd.local_rank())
    config.gpu_options.allow_growth = True

    with tf.Session(config=config) as sess:
        context = tf.placeholder(tf.int32, [batch_size, None])
        np.random.seed(seed)
        tf.set_random_seed(seed)
        output = model.model(hparams=hparams, X=context)
        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=context[:, 1:], logits=output['logits'][:, :-1]))

        tf_sample = sample.sample_sequence(hparams=hparams,
                                           length=sample_length,
                                           context=context,
                                           batch_size=batch_size,
                                           temperature=0.8,
                                           top_k=40)

        train_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
        opt = tf.train.AdamOptimizer()
        opt = hvd.DistributedOptimizer(opt)
        train_op = opt.minimize(loss, var_list=train_vars)

        # Horovod: broadcast initial variable states from rank 0 to all other processes.
        # This is necessary to ensure consistent initialization of all workers when
        # training is started with random weights or restored from a checkpoint.
        bcast = hvd.broadcast_global_variables(0)

        saver = tf.train.Saver(var_list=train_vars,
                               max_to_keep=5,
                               keep_checkpoint_every_n_hours=2)
        sess.run(tf.global_variables_initializer())

        if restore_from == 'latest':
            ckpt = tf.train.latest_checkpoint(
                os.path.join(CHECKPOINT_DIR, run_name))
            if ckpt is None:
                # Get fresh GPT weights if new run.
                ckpt = tf.train.latest_checkpoint(
                    os.path.join('chatbot_model', 'trained_models', model_name))
        elif restore_from == 'fresh':
            ckpt = tf.train.latest_checkpoint(
                os.path.join('chatbot_model', 'trained_models', model_name))
        else:
            ckpt = tf.train.latest_checkpoint(restore_from)
        print(str(hvd.local_rank()), 'Loading checkpoint', ckpt)
        saver.restore(sess, ckpt)
        bcast.run()

        print(str(hvd.local_rank()), 'Loading dataset...')
        chunks = load_dataset(enc, dataset, combine)
        data_sampler = Sampler(chunks)
        print(str(hvd.local_rank()), 'dataset has', data_sampler.total_size,
              'tokens')
        print(str(hvd.local_rank()), 'Training...')

        counter = 1
        if os.path.exists(os.path.join(CHECKPOINT_DIR, run_name, 'counter')):
            # Load the step number if we're resuming a run
            # Add 1 so we don't immediately try to save again
            with open(os.path.join(CHECKPOINT_DIR, run_name, 'counter'),
                      'r') as fp:
                counter = int(fp.read()) + 1

        def save():
            maketree(os.path.join(CHECKPOINT_DIR, run_name))
            print(
                'Saving',
                os.path.join(CHECKPOINT_DIR, run_name,
                             'model-{}').format(counter))
            saver.save(sess,
                       os.path.join(CHECKPOINT_DIR, run_name, 'model'),
                       global_step=counter)
            with open(os.path.join(CHECKPOINT_DIR, run_name, 'counter'),
                      'w') as fp:
                fp.write(str(counter) + '\n')

        def generate_samples():
            context_tokens = data_sampler.sample(1)
            all_text = []
            index = 0
            while index < sample_num:
                out = sess.run(
                    tf_sample,
                    feed_dict={context: batch_size * [context_tokens]})
                for i in range(min(sample_num - index, batch_size)):
                    text = enc.decode(out[i])
                    text = '======== SAMPLE {} ========\n{}\n'.format(
                        index + 1, text)
                    all_text.append(text)
                    index += 1
                    print(text)
            maketree(os.path.join(SAMPLE_DIR, run_name))
            with open(
                    os.path.join(SAMPLE_DIR, run_name,
                                 'samples-{}').format(counter), 'w') as fp:
                fp.write('\n'.join(all_text))

        avg_loss = (0.0, 0.0)
        start_time = time.time()

        try:
            while True:
                batch = [data_sampler.sample(1024) for _ in range(batch_size)]
                _, lv = sess.run((train_op, loss), feed_dict={context: batch})
                avg_loss = (avg_loss[0] * 0.99 + lv, avg_loss[1] * 0.99 + 1.0)

                if hvd.rank() == 0:
                    if counter % save_every == 0:
                        save()
                    if counter % sample_every == 0:
                        generate_samples()

                    print(
                        '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                        .format(counter=counter,
                                time=time.time() - start_time,
                                loss=lv,
                                avg=avg_loss[0] / avg_loss[1]))

                counter += 1
        except KeyboardInterrupt:
            print('interrupted')
            if hvd.rank() == 0:
                save()
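# Every trainer in this file persists the global step in a plain text file
# named 'counter' next to the checkpoints, reading it back (+1) when a run
# resumes and rewriting it on every save. A small self-contained sketch of that
# bookkeeping; read_counter()/write_counter() are illustrative names, and
# os.makedirs stands in for the scripts' maketree() helper:
import os


def read_counter(run_dir):
    """Return the next step number, resuming from run_dir/counter if present."""
    path = os.path.join(run_dir, 'counter')
    if os.path.exists(path):
        with open(path, 'r') as fp:
            # Add 1 so a resumed run does not immediately save again.
            return int(fp.read()) + 1
    return 1


def write_counter(run_dir, counter):
    os.makedirs(run_dir, exist_ok=True)
    with open(os.path.join(run_dir, 'counter'), 'w') as fp:
        fp.write(str(counter) + '\n')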
def main(): enc = encoder.get_encoder(args.model_name) hparams = model.default_hparams() hparams.batch_size=args.batch_size hparams.seq_len=args.seq_len ##data_path args.train_data_path=args.data_dir+args.dataset+'/train.txt' args.eval_data_path=args.data_dir+args.dataset+'/dev.txt' args.test_data_path=args.data_dir+args.dataset+'/test.txt' args.eval_data_path=args.test_data_path ###Test mode only! args.gpt_save_path=args.gpt_save_dir+args.dataset+'/' args.dis_save_path=args.dis_save_dir+args.dataset+'/' args.gpt_sample_dir2=args.gpt_sample_dir+args.dataset+'/' args.dis_sample_dir2=args.dis_sample_dir+args.dataset+'/' args.log_path=args.log_dir+args.dataset+'/' maketree(args.gpt_save_dir) maketree(args.dis_save_dir) maketree(args.gpt_save_path) maketree(args.dis_save_path) maketree(args.gpt_sample_dir) maketree(args.dis_sample_dir) maketree(args.gpt_sample_dir2) maketree(args.dis_sample_dir2) maketree(args.log_dir) maketree(args.log_path) with open(os.path.join('models', args.model_name, 'hparams.json')) as f: hparams.override_from_dict(json.load(f)) if args.sample_length > hparams.n_ctx: raise ValueError( "Can't get samples longer than window size: %s" % hparams.n_ctx) if args.model_name == '345M': args.memory_saving_gradients = True if args.optimizer == 'adam': args.only_train_transformer_layers = True config = tf.ConfigProto() config.gpu_options.allow_growth = True config.graph_options.rewrite_options.layout_optimizer = rewriter_config_pb2.RewriterConfig.OFF with tf.Session(config=config) as sess: scope_discri='distri' def get_dis_logit_and_prob_single_step(context, scope): with tf.variable_scope(scope, reuse=tf.AUTO_REUSE): context=tf.reshape(context, [-1, args.seq_len]) emb=tf.get_variable(name='emb', initializer=tf.random.normal([hparams.n_vocab, 32], 0, 0.02)) context_emb=tf.nn.embedding_lookup(emb, context) logit=dis(context_emb, scope=scope_discri) prob=tf.sigmoid(logit+1e-7) return logit, prob def get_dis_logit_and_prob(context, context_len, scope): ##Pay attention to context_len here. temporary changes!!!!!!!!!!!!!!!!!!! 
context_mask=(1-tf.sequence_mask(context_len-1, args.seq_len-1, dtype=tf.float32))*1e3 context_mask2=tf.sequence_mask(context_len-1, args.seq_len-1, dtype=tf.float32) ones=tf.ones(shape=[tf.shape(context_len)[0], args.seq_len], dtype=tf.int32)*enc.encoder['<|endoftext|>'] input_tensor_list=[] for i in range(1, args.seq_len): input_tensor_list.append(tf.concat([context[:, :i+1], ones[:,i+1:]], axis=1)) input_tensor=tf.concat(input_tensor_list, axis=0) log_prob, _=get_dis_logit_and_prob_single_step(input_tensor, scope=scope) log_prob=tf.transpose(tf.reshape(log_prob, [args.seq_len-1, -1])) log_prob+=tf.cast(context_mask, tf.float32) log_prob_min=tf.reduce_min(log_prob, axis=1) prob_min=tf.exp(log_prob_min) return log_prob_min, prob_min, log_prob ##Build discriminator def build_dis_layer(scope): context_pos_discri = tf.placeholder(tf.int32, [None, args.seq_len]) context_pos_discri_len = tf.placeholder(tf.int32, [None]) context_neg_discri = tf.placeholder(tf.int32, [None, args.seq_len]) context_neg_discri_len = tf.placeholder(tf.int32, [None]) label_pos_discri=tf.ones([tf.shape(context_pos_discri_len)[0]], dtype=tf.float32) label_neg_discri=tf.zeros([tf.shape(context_neg_discri_len)[0]], dtype=tf.float32) logit_pos_discri, prob_pos_discri, mask=get_dis_logit_and_prob(context_pos_discri, context_pos_discri_len, scope=scope) logit_neg_discri, _, _=get_dis_logit_and_prob(context_neg_discri, context_neg_discri_len, scope=scope) loss_pre_pos_discri=tf.nn.sigmoid_cross_entropy_with_logits(labels=label_pos_discri, logits=logit_pos_discri) loss_pos_discri=tf.reduce_mean(loss_pre_pos_discri) loss_pre_neg_discri=tf.nn.sigmoid_cross_entropy_with_logits(labels=label_neg_discri, logits=logit_neg_discri) loss_neg_discri=tf.reduce_mean(loss_pre_neg_discri) loss_discri=(loss_pos_discri*args.pos_loss_weight+loss_neg_discri)/(1+args.pos_loss_weight) train_var_list_discri=[x for x in tf.global_variables() if scope in x.name] train_op_discri=tf.train.AdamOptimizer().minimize(loss_discri, var_list=train_var_list_discri) var_list_discri=[x for x in tf.global_variables() if scope in x.name] initializer_discri=tf.variables_initializer(var_list_discri) saver_discri=tf.train.Saver(var_list=var_list_discri, max_to_keep=1) print('discri: {} build succeed!'.format(scope)) return context_pos_discri,context_pos_discri_len, context_neg_discri,context_neg_discri_len, loss_pos_discri, loss_neg_discri, loss_discri, train_op_discri, initializer_discri, saver_discri, prob_pos_discri, mask, logit_pos_discri class dis_class: def __init__(self, layer_num=1, scope=scope_discri): self.model=[] self.dis=np.zeros([layer_num], dtype=np.float32) print(layer_num) for i in range(layer_num): layer={'scope': scope+str(i)} layer['context_pos_discri'],layer['context_pos_discri_len'], layer['context_neg_discri'],layer['context_neg_discri_len'], layer['loss_pos_discri'], layer['loss_neg_discri'], layer['loss_discri'], layer['train_op_discri'], layer['initializer_discri'], layer['saver_discri'], layer['prob_pos_discri'], layer['mask'], layer['logit_pos_discri'] = build_dis_layer(scope+str(i)) self.model.append(layer) def prob(self, context, context_len, layer=-1): if layer==-1: layer=len(self.model) prob_final=tf.ones(tf.shape(context)[0], dtype=tf.float32) for i in range(layer): item=self.model[i] scope=item['scope'] _, prob, _=get_dis_logit_and_prob(context, context_len, scope=scope) prob_final*=prob return prob_final def log_prob_step(self, context, layer=-1): if layer==-1: layer=len(self.model) prob_final=tf.ones(tf.shape(context)[0], 
dtype=tf.float32) log_prob_list=[] for i in range(layer): item=self.model[i] scope=item['scope'] log_prob, prob=get_dis_logit_and_prob_single_step(context, scope=scope) log_prob_list.append(tf.expand_dims(log_prob, 1)) log_prob_final=tf.concat(log_prob_list, axis=1) return log_prob_final Dis=dis_class(layer_num=args.layer_num) context = tf.placeholder(tf.int32, [None, None]) context_len=tf.placeholder(tf.int32, [None]) context_mask=tf.sequence_mask(context_len-1, args.seq_len-1, dtype=tf.float32) context_in=context output = model.model(hparams=hparams, X=context_in) loss_tensor = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=context[:, 1:], logits=output['logits'][:, :-1])*context_mask loss=tf.reduce_sum(loss_tensor, axis=1)/(tf.reduce_sum(context_mask, axis=1)+1e-7) loss_sen=tf.reduce_sum(loss) loss=tf.reduce_mean(loss) if args.val_every > 0: def transform_np(x, lift=args.exponential_param): x=x-0.5 x=x+np.abs(x) return lift*x**2 def transform(x, lift=args.exponential_param): x=x-0.5 x=x+tf.abs(x) return lift*x**2 val_context = tf.placeholder(tf.int32, [args.val_batch_size, args.seq_len]) val_context_len=tf.placeholder(tf.int32, [args.batch_size]) NLL_bias=tf.placeholder(tf.float32, []) val_context_mask=tf.sequence_mask(val_context_len-1, args.seq_len-1, dtype=tf.float32) val_output = model.model(hparams=hparams, X=val_context) val_loss_tensor =tf.nn.sparse_softmax_cross_entropy_with_logits(labels=val_context[:, 1:], logits=val_output['logits'][:, :-1])*val_context_mask val_context_prob_cut=Dis.prob(val_context, val_context_len) val_NLL_cut=tf.log(val_context_prob_cut+1e-7) val_loss=tf.reduce_sum(val_loss_tensor, axis=1)/(tf.reduce_sum(val_context_mask, axis=1)+1e-7) val_loss_cut=(tf.reduce_sum(val_loss_tensor, axis=1)+NLL_bias)/(tf.reduce_sum(val_context_mask, axis=1)+1e-7)-val_NLL_cut/tf.cast(val_context_len, tf.float32) val_loss_sum=tf.reduce_sum(val_loss_tensor, axis=1) val_loss_cut_sum=(tf.reduce_sum(val_loss_tensor, axis=1)+NLL_bias)-val_NLL_cut val_loss_mean=tf.reduce_mean(val_loss) val_loss_cut_mean=tf.reduce_mean(val_loss_cut) val_loss_summary = tf.summary.scalar('val_loss', val_loss_mean) tf_sample = sample.sample_sequence( hparams=hparams, length=args.seq_len, context=context, batch_size=args.batch_size, temperature=1.0, top_k=args.top_k, top_p=args.top_p, start_token=enc.encoder['<|endoftext|>']) start_token=enc.encoder['<|endoftext|>'] all_vars = [v for v in tf.trainable_variables() if 'model' in v.name] train_vars = [v for v in all_vars if '/h' in v.name] if args.only_train_transformer_layers else all_vars if args.optimizer == 'adam': opt = tf.train.AdamOptimizer(learning_rate=args.learning_rate) elif args.optimizer == 'sgd': opt = tf.train.GradientDescentOptimizer(learning_rate=args.learning_rate) else: exit('Bad optimizer:', args.optimizer) if args.accumulate_gradients > 1: if args.memory_saving_gradients: exit("Memory saving gradients are not implemented for gradient accumulation yet.") opt = AccumulatingOptimizer( opt=opt, var_list=train_vars) opt_reset = opt.reset() opt_compute = opt.compute_gradients(loss) opt_apply = opt.apply_gradients() summary_loss = tf.summary.scalar('loss', opt_apply) else: if args.memory_saving_gradients: opt_grads = memory_saving_gradients.gradients(loss, train_vars) else: opt_grads = tf.gradients(loss, train_vars) opt_grads = list(zip(opt_grads, train_vars)) opt_apply = opt.apply_gradients(opt_grads) summary_loss = tf.summary.scalar('loss', loss) summary_lr = tf.summary.scalar('learning_rate', args.learning_rate) summaries = 
tf.summary.merge([summary_lr, summary_loss]) summary_log = tf.summary.FileWriter( os.path.join(CHECKPOINT_DIR, args.run_name)) saver = tf.train.Saver(var_list=all_vars, max_to_keep=1) sess.run(tf.global_variables_initializer()) if args.restore_from == 'latest': ckpt = tf.train.latest_checkpoint( os.path.join(CHECKPOINT_DIR, args.run_name)) if ckpt is None: # Get fresh GPT weights if new run. ckpt = tf.train.latest_checkpoint( os.path.join('models', args.model_name)) elif args.restore_from == 'fresh': ckpt = tf.train.latest_checkpoint( os.path.join('models', args.model_name)) else: ckpt = tf.train.latest_checkpoint(args.restore_from) print('Loading checkpoint', ckpt) saver.restore(sess, ckpt) print('Loading dataset...') data_list, data_len = load_dataset(enc, args.train_data_path, args.seq_len) data_sampler = Sampler(data_list, data_len ) if args.val_every > 0: val_data_list, val_data_len = load_dataset(enc, args.eval_data_path, args.seq_len) print('dataset has', data_sampler.total_size, 'tokens') print('Training...') if args.val_every > 0: # Sample from validation set once with fixed seed to make # it deterministic during training as well as across runs. val_data_sampler = Sampler(val_data_list, val_data_len, seed=1) val_batches = [val_data_sampler.sample(args.batch_size) for _ in range(args.val_batch_count)] counter = 0 counter_path = os.path.join(CHECKPOINT_DIR, args.run_name, 'counter') if os.path.exists(counter_path): # Load the step number if we're resuming a run # Add 1 so we don't immediately try to save again with open(counter_path, 'r') as fp: counter = int(fp.read()) + 1 def save(): maketree(os.path.join(CHECKPOINT_DIR, args.run_name)) print( 'Saving', os.path.join(CHECKPOINT_DIR, args.run_name, 'model-{}').format(counter)) saver.save( sess, os.path.join(CHECKPOINT_DIR, args.run_name, 'model'), global_step=counter) with open(counter_path, 'w') as fp: fp.write(str(counter) + '\n') def train_step_discri(layer_id=0, mask_train_epoch=0): pos_samples, pos_samples_len=data_sampler.sample(args.batch_size) neg_samples=generate_negative_sample(layer_id=layer_id) neg_samples_len=get_array_len(neg_samples) _, loss=sess.run([Dis.model[layer_id]['train_op_discri'], Dis.model[layer_id]['loss_discri']], feed_dict={Dis.model[layer_id]['context_pos_discri']: pos_samples,Dis.model[layer_id]['context_pos_discri_len']: pos_samples_len, Dis.model[layer_id]['context_neg_discri']: neg_samples, Dis.model[layer_id]['context_neg_discri_len']: neg_samples_len}) return loss def generate_negative_samples(layer_id, generate_num=args.batch_size): result_list=[] generate_num_now=0 samples_mem=[] while generate_num_now<generate_num: t=time.time() sample_id=generate_negative_sample(layer_id=layer_id) samples=[] t1=time.time() selected_id_list=np.arange(len(sample_id)) t2=time.time() result_list.append(sample_id[selected_id_list]) generate_num_now+=len(selected_id_list) return np.concatenate(result_list, axis=0)[:generate_num] def get_array_len(sample_array): lens=[] for item in sample_array: for i in range(1, len(item)): if item[i]==enc.encoder['<|endoftext|>']: break lens.append(i) return np.array(lens).astype(np.int32) def generate_discri_sample3(layer_id=-1, sample_size=10000, save_path='/mnt/cephfs_new_wj/mlnlp/miaoning/Experiment/gpt-2-sep/samples/discri/sample2.txt'): samples=[] while len(samples)<sample_size: sample_id=generate_negative_sample(layer_id) for i in range(len(sample_id)): sample_tem=enc.decode(sample_id[i]).split('<|endoftext|>')[1].split('\n')[0] samples.append(sample_tem) print(len(samples)) 
with open(save_path, 'w') as g: g.write('\n'.join(samples)) def eval_discri_NLL(layer_id=0): losses_pos=[] losses_neg=[] for batch in tqdm.tqdm(val_batches): pos_samples, pos_samples_len=batch neg_samples=generate_negative_sample(layer_id=layer_id) neg_samples_len=get_array_len(neg_samples) loss_pos, mask=sess.run([Dis.model[layer_id]['loss_pos_discri'], Dis.model[layer_id]['mask']], feed_dict={Dis.model[layer_id]['context_pos_discri']: pos_samples, Dis.model[layer_id]['context_pos_discri_len']: pos_samples_len}) #print(mask) loss_neg=sess.run(Dis.model[layer_id]['loss_neg_discri'], feed_dict={Dis.model[layer_id]['context_neg_discri']: neg_samples, Dis.model[layer_id]['context_neg_discri_len']: neg_samples_len}) losses_pos.append(loss_pos) losses_neg.append(loss_neg) return np.mean(losses_pos), np.mean(losses_neg) def get_discri_quantile(layer_id=0, quantile=0.85): logits_list=[] for batch in tqdm.tqdm(val_batches): pos_samples, pos_samples_len=batch logits, mask=sess.run([Dis.model[layer_id]['logit_pos_discri'], Dis.model[layer_id]['mask']], feed_dict={Dis.model[layer_id]['context_pos_discri']: pos_samples, Dis.model[layer_id]['context_pos_discri_len']: pos_samples_len}) print(np.min(mask, axis=1)[:10]) print(logits[:10]) with open('mask.pkl', 'wb') as g: pkl.dump(mask, g) logits_list.extend(list(logits)) break with open('logits.pkl', 'wb') as g: pkl.dump(sorted(logits_list), g) #print(sorted(logits_list)) print('finish') return sorted(logits_list)[int(len(logits_list)*(1-quantile))] def train_discri(train_step, eval_every, train_layer_list=list(range(len(Dis.model)))): #sess.run(initializer_discri) print('Start Discri training') train_losses=[] for layer_id in train_layer_list: flag=0 for epoch in range(train_step): if epoch % eval_every==0: train_losses=np.mean(train_losses) train_losses=[] eval_NLL_pos, eval_NLL_neg=eval_discri_NLL(layer_id) eval_loss=(eval_NLL_pos*args.pos_loss_weight+eval_NLL_neg)/(args.pos_loss_weight+1) print('layer_id:{} discri eval loss:{}'.format(layer_id, eval_loss)) print('layer_id:{} discri NLL pos: {}, discri NLL neg: {}'.format(layer_id, eval_NLL_pos, eval_NLL_neg)) print(epoch) if epoch==0: eval_loss_old=eval_loss else: print(eval_loss, eval_loss_old) if eval_loss<eval_loss_old: eval_loss_old=eval_loss save_path=args.dis_save_path+str(layer_id)+'/' if not os.path.isdir(save_path): os.mkdir(save_path) Dis.model[layer_id]['saver_discri'].save(sess, save_path+'a') print('model discri saved!') flag=0 else: if epoch>=200: flag+=1 if flag>=4: break train_loss=train_step_discri(layer_id) print('layer_id:{} discri train loss:{}'.format(layer_id, train_loss)) train_losses.append(train_loss) return eval_loss_old tf_sample_0 = sample_link.sample_sequence( hparams=hparams, length=args.seq_len, context=context, batch_size=args.batch_size, temperature=1.0, top_k=args.top_k, top_p=args.top_p, start_token=enc.encoder['<|endoftext|>']) tf_sample_dict={} def generate_negative_sample(layer_id=0): ##output the filtered result of layer layer_id-1 if layer_id==0: tf_sample=tf_sample_0 sample = data_sampler.sample(args.batch_size)[0][:,0:1] out = sess.run( tf_sample, feed_dict={context: sample})[:,:args.seq_len] for i in range(len(out)): flag=0 for j in range(len(out[i])): if flag==2: out[i][j]=start_token continue if out[i][j]==start_token: flag+=1 return out else: if layer_id==-1: layer_id=len(Dis.model) if layer_id in tf_sample_dict: tf_sample=tf_sample_dict[layer_id] else: tf_sample = sample_link.sample_sequence_ISMC_threshold( Dis=Dis, layer=layer_id, hparams=hparams, 
length=args.seq_len, context=context, batch_size=args.batch_size, temperature=1.0, top_k=args.top_k, top_p=args.top_p, start_token=enc.encoder['<|endoftext|>']) tf_sample_dict[layer_id]=tf_sample sample = data_sampler.sample(args.batch_size)[0][:,0:1] out = sess.run( tf_sample, feed_dict={context: sample})[:,:args.seq_len] for i in range(len(out)): flag=0 for j in range(len(out[i])): if flag==2: out[i][j]=start_token continue if out[i][j]==start_token: flag+=1 return out def validation(): print('Calculating validation loss...') start_time=time.time() losses = [] rates=[] for batch in tqdm.tqdm(val_batches): losses.append(sess.run(val_loss_mean, feed_dict={val_context: batch[0], val_context_len: batch[1]})) v_val_loss = np.mean(losses) v_summary = sess.run(val_loss_summary, feed_dict={val_loss_mean: v_val_loss}) summary_log.add_summary(v_summary, counter) summary_log.flush() print( '[{counter} | {time:2.2f}] validation loss = {loss:2.2f}' .format( counter=counter, time=time.time() - start_time, loss=v_val_loss)) return v_val_loss def validation_cut(NLL_bias_0=0): print('Calculating validation loss...') losses = [] rates=[] for batch in tqdm.tqdm(val_batches): losses.append(sess.run(val_loss_cut_mean, feed_dict={val_context: batch[0], val_context_len: batch[1], NLL_bias:NLL_bias_0})) v_val_loss = np.mean(losses) print( '[{counter} | {time:2.2f}] validation cut loss = {loss:2.2f}' .format( counter=counter, time=time.time() - start_time, loss=v_val_loss)) return v_val_loss def sample_batch(): return [data_sampler.sample(1024) for _ in range(args.batch_size)] def train_gpt(): val_loss_old=10000.0 avg_loss = (0.0, 0.0) start_time = time.time() counter=0 while True: #pretraining if counter % args.save_every == 0: pass #save() if counter % args.sample_every == 0: pass #generate_samples() if args.val_every > 0 and (counter % args.val_every == 0 or counter == 1): val_loss_1=validation() print(str(counter //args.val_every)) if val_loss_1>=val_loss_old: print('pre-training ends!') break else: val_loss_old=val_loss_1 saver.save(sess, args.gpt_save_path+'a') print('save succeed!') if args.accumulate_gradients > 1: sess.run(opt_reset) for _ in range(args.accumulate_gradients): batch, batch_len=data_sampler.sample(args.batch_size) sess.run( opt_compute, feed_dict={context: batch, context_len:batch_len}) (v_loss, v_summary) = sess.run((opt_apply, summaries)) else: batch, batch_len=data_sampler.sample(args.batch_size) (_, v_loss, v_summary) = sess.run( (opt_apply, loss, summaries), feed_dict={context: batch, context_len:batch_len}) summary_log.add_summary(v_summary, counter) avg_loss = (avg_loss[0] * 0.9 + v_loss, avg_loss[1] * 0.9 + 1.0) print( '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}' .format( counter=counter, time=time.time() - start_time, loss=v_loss, avg=avg_loss[0] / avg_loss[1])) counter += 1 class log_writer: def __init__(self, path): self.path=path with open(path, 'w') as g: g.write('') def __call__(self, string, verbose=False): with open(self.path, 'a') as g: g.write(string+'\n') if verbose: print(string) try: if args.finetune: #Finetune GPT-2 train_gpt() if True: #Restore Finetuned model save_path=tf.train.latest_checkpoint(args.gpt_save_path) saver.restore(sess, save_path) print('Load gpt2 succeeded!') if args.evaluate_finetune: #Evaluate finetuning baseline print(validation()) if args.evaluate_finetune: #Calculate reverse-ppl for finetuning baseline sample_path=args.gpt_sample_dir2+'sample.txt' generate_discri_sample3(layer_id=0, sample_size=3000, save_path=sample_path) 
rev_ppl=train.file_f(train_data_path=sample_path, val_data_path=args.eval_data_path) Log_writer=log_writer(args.log_path+'finetune') Log_writer('finetuning_rev_ppl: {}'.format(rev_ppl), verbose=True) ##Begin tailoring if True: Log_writer=log_writer(args.log_path+'discri') for layer in range(args.layer_num): print(layer) if args.train_tailor: #Train ratio estimator train_discri(500, 10, [layer]) if True: #Restore ratio estimator for layer_id in range(layer+1): save_path=args.dis_save_path+str(layer_id)+'/' print(save_path) save_path=tf.train.latest_checkpoint(save_path) print(save_path) Dis.model[layer_id]['saver_discri'].restore(sess, save_path) if False: #Save quantile for analysis with open(args.dis_sample_dir2+'quantile.pkl', 'rb') as f: pkl.load(f) print('Load dis model succeeded!') if True: if layer==0: quantile=0.85 else: quantile=0.9 Dis.dis[layer]=get_discri_quantile(layer, quantile) with open(args.dis_sample_dir2+'quantile.pkl', 'wb') as g: pkl.dump(Dis.dis, g) print(Dis.dis) if args.evaluate_tailor: #Generate sample for ERS and calculate reverse-ppl sample_path=args.dis_sample_dir2+'_sample_layer_'+str(layer) generate_discri_sample3(layer_id=layer+1, sample_size=3000, save_path=sample_path) rev_ppl=train.file_f(train_data_path=sample_path, val_data_path=args.eval_data_path) Log_writer('layer: {}, dis_rev_ppl: {}'.format(layer, rev_ppl), verbose=True) except KeyboardInterrupt: print('interrupted')
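# The tailoring script above masks the per-token cross-entropy with
# tf.sequence_mask(context_len - 1, seq_len - 1) and averages only over real
# tokens, guarding the division with 1e-7. A NumPy sketch of that masked mean;
# the mask construction and the epsilon mirror the TF code, while the toy
# numbers below are made up for illustration:
import numpy as np


def masked_mean_loss(per_token_loss, lengths):
    """per_token_loss: [batch, seq_len - 1]; lengths: true token counts per row."""
    batch, steps = per_token_loss.shape
    # mask[i, j] = 1 while j < lengths[i] - 1, matching tf.sequence_mask(len - 1, steps).
    mask = (np.arange(steps)[None, :] < (lengths - 1)[:, None]).astype(np.float32)
    per_seq = (per_token_loss * mask).sum(axis=1) / (mask.sum(axis=1) + 1e-7)
    return per_seq.mean()


print(masked_mean_loss(np.ones((2, 5), np.float32), np.array([3, 6])))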
def main(): args = parser.parse_args() folder_id = get_id(args.gdir) #xmpp = SendMsgBot(jid, password, to, "Starting GPT-2") #xmpp.register_plugin('xep_0030') # Service Discovery #xmpp.register_plugin('xep_0199') # XMPP Ping #xmpp.connect() #threading = Thread(target=xmpp.process, daemon=True).start() download_checkpoint(folder_id) #send_m('checkpoint downloaded') enc = encoder.get_encoder(args.model_name) hparams = model.default_hparams() with open(os.path.join('models', args.model_name, 'hparams.json')) as f: hparams.override_from_dict(json.load(f)) if args.sample_length > hparams.n_ctx: raise ValueError( "Can't get samples longer than window size: %s" % hparams.n_ctx) if args.model_name == '345M': args.memory_saving_gradients = True # if args.optimizer == 'adam': # args.only_train_transformer_layers = True config = tf.ConfigProto() config.gpu_options.allow_growth = True config.graph_options.rewrite_options.layout_optimizer = rewriter_config_pb2.RewriterConfig.OFF with tf.Session(config=config) as sess: context = tf.placeholder(tf.int32, [args.batch_size, None]) context_in = randomize(context, hparams, args.noise) output = model.model(hparams=hparams, X=context_in) loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=context[:, 1:], logits=output['logits'][:, :-1])) if args.val_every > 0: val_context = tf.placeholder(tf.int32, [args.val_batch_size, None]) val_output = model.model(hparams=hparams, X=val_context) val_loss = tf.reduce_mean( tf.nn.sparse_softmax_cross_entropy_with_logits( labels=val_context[:, 1:], logits=val_output['logits'][:, :-1])) val_loss_summary = tf.summary.scalar('val_loss', val_loss) tf_sample = sample.sample_sequence( hparams=hparams, length=args.sample_length, context=context, batch_size=args.batch_size, temperature=1.0, top_k=args.top_k, top_p=args.top_p) all_vars = [v for v in tf.trainable_variables() if 'model' in v.name] train_vars = [v for v in all_vars if '/h' in v.name] if args.only_train_transformer_layers else all_vars if args.optimizer == 'adam': opt = tf.train.AdamOptimizer(learning_rate=args.learning_rate) elif args.optimizer == 'sgd': opt = tf.train.GradientDescentOptimizer(learning_rate=args.learning_rate) else: exit('Bad optimizer:', args.optimizer) if args.accumulate_gradients > 1: if args.memory_saving_gradients: exit("Memory saving gradients are not implemented for gradient accumulation yet.") opt = AccumulatingOptimizer( opt=opt, var_list=train_vars) opt_reset = opt.reset() opt_compute = opt.compute_gradients(loss) opt_apply = opt.apply_gradients() summary_loss = tf.summary.scalar('loss', opt_apply) else: if args.memory_saving_gradients: opt_grads = memory_saving_gradients.gradients(loss, train_vars) else: opt_grads = tf.gradients(loss, train_vars) opt_grads = list(zip(opt_grads, train_vars)) opt_apply = opt.apply_gradients(opt_grads) summary_loss = tf.summary.scalar('loss', loss) summary_lr = tf.summary.scalar('learning_rate', args.learning_rate) summaries = tf.summary.merge([summary_lr, summary_loss]) summary_log = tf.summary.FileWriter(os.path.join(CHECKPOINT_DIR, args.run_name)) saver = tf.train.Saver(var_list=all_vars, max_to_keep=5, keep_checkpoint_every_n_hours=2) sess.run(tf.global_variables_initializer()) if args.restore_from == 'latest': ckpt = tf.train.latest_checkpoint(os.path.join(CHECKPOINT_DIR, args.run_name)) if ckpt is None: # Get fresh GPT weights if new run. 
ckpt = tf.train.latest_checkpoint(os.path.join('models', args.model_name)) elif args.restore_from == 'fresh': ckpt = tf.train.latest_checkpoint(os.path.join('models', args.model_name)) else: ckpt = tf.train.latest_checkpoint(args.restore_from) print('Loading checkpoint', ckpt) #send_m('Loading ' + str(ckpt)) saver.restore(sess, ckpt) print('Loading dataset...') #send_m('Loading dataset...') #chunks = load_dataset(enc, args.dataset, args.combine) ds_path = f'{CHECKPOINT_DIR}//run1//{args.dataset}' chunks = load_dataset(enc, ds_path, args.combine) data_sampler = Sampler(chunks) print(f'{ds_path} has', data_sampler.total_size, 'tokens') if args.val_every > 0: val_chunks = load_dataset(enc, args.val_dataset, args.combine) if args.val_dataset else chunks if args.enc: print(colored(f'Trying writing Data.npz encoded from this dataset to {args.enc}', 'red')) np.savez_compressed(args.enc, *chunks) upload_npz(args.enc, folder_id) #send_m(f'{args.dataset} has ' + str(data_sampler.total_size) + ' tokens' + ' Start training...') print('Training...') if args.val_every > 0: # Sample from validation set once with fixed seed to make # it deterministic during training as well as across runs. val_data_sampler = Sampler(val_chunks, seed=1) val_batches = [[val_data_sampler.sample(1024) for _ in range(args.val_batch_size)] for _ in range(args.val_batch_count)] counter = 1 counter_path = os.path.join(CHECKPOINT_DIR, args.run_name, 'counter') if os.path.exists(counter_path): # Load the step number if we're resuming a run # Add 1 so we don't immediately try to save again with open(counter_path, 'r') as fp: counter = int(fp.read()) + 1 def save(): maketree(os.path.join(CHECKPOINT_DIR, args.run_name)) print( 'Saving', os.path.join(CHECKPOINT_DIR, args.run_name, 'model-{}').format(counter)) saver.save( sess, os.path.join(CHECKPOINT_DIR, args.run_name, 'model'), global_step=counter) with open(counter_path, 'w') as fp: fp.write(str(counter) + '\n') save_gdisk(counter, folder_id) def generate_samples(): print('Generating samples...') #send_m('Generating samples...') context_tokens = data_sampler.sample(1) all_text = [] index = 0 while index < args.sample_num: out = sess.run( tf_sample, feed_dict={context: args.batch_size * [context_tokens]}) for i in range(min(args.sample_num - index, args.batch_size)): text = enc.decode(out[i]) text = '======== SAMPLE {} ========\n{}\n'.format( index + 1, text) all_text.append(text) index += 1 print(text) #send_m(text) maketree(os.path.join(SAMPLE_DIR, args.run_name)) with open( os.path.join(SAMPLE_DIR, args.run_name, 'samples-{}').format(counter), 'w') as fp: fp.write('\n'.join(all_text)) def validation(): print('Calculating validation loss...') losses = [] for batch in tqdm.tqdm(val_batches): losses.append(sess.run(val_loss, feed_dict={val_context: batch})) v_val_loss = np.mean(losses) v_summary = sess.run(val_loss_summary, feed_dict={val_loss: v_val_loss}) summary_log.add_summary(v_summary, counter) summary_log.flush() print( '[{counter} | {time:2.2f}] validation loss = {loss:2.2f}' .format( counter=counter, time=time.time() - start_time, loss=v_val_loss)) def sample_batch(): return [data_sampler.sample(1024) for _ in range(args.batch_size)] avg_loss = (0.0, 0.0) start_time = time.time() last_time = time.time() cur_counter, min_loss = 1, 2.0 print(colored(f'Model >>> {args.gdir}\nLearning rate is {args.learning_rate}', 'blue')) print(colored(f'model optimizer >>> {args.optimizer}\nRestricted to train only transformer layer={args.only_train_transformer_layers}', 'blue')) 
#send_m(f'Model >>> {args.model_name}\nLearning rate is {args.learning_rate}') try: while True: if counter % args.save_every == 0: save() if check_quota(): return() # exit train if counter % args.sample_every == 0: generate_samples() if args.val_every > 0 and (counter % args.val_every == 0 or counter == 1): validation() if args.accumulate_gradients > 1: sess.run(opt_reset) for _ in range(args.accumulate_gradients): sess.run(opt_compute, feed_dict={context: sample_batch()}) (v_loss, v_summary) = sess.run((opt_apply, summaries)) else: (_, v_loss, v_summary) = sess.run((opt_apply, loss, summaries), feed_dict={context: sample_batch()}) summary_log.add_summary(v_summary, counter) avg_loss = (avg_loss[0] * 0.99 + v_loss, avg_loss[1] * 0.99 + 1.0) a_loss = avg_loss[0] / avg_loss[1] time_all = int((time.time() - start_time) / 60) time_iter = time.time() - last_time stats = f'[{counter} | {cur_counter} | {time_all}m | {time_iter:2.2f}s] loss={v_loss:2.2f} avg={a_loss:2.2f}' if not(cur_counter % 50): print(colored(stats, 'red' if a_loss > min_loss else 'yellow')) if a_loss < min_loss: min_loss = a_loss #send_m(stats) last_time = time.time() counter += 1 cur_counter += 1 except Exception as e: #send_m('Stoped ' + str(e.__class__)) print('Stoped', e.__class__)
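# The Drive-backed trainer above prints a coloured stats line every 50 steps:
# red when the running average is above the best average seen so far, yellow
# once it improves, and it keeps min_loss updated. A small sketch of that
# reporting helper, assuming colored() comes from termcolor as is conventional
# (report() and its parameter names are illustrative, not from the script):
from termcolor import colored


def report(counter, cur_counter, minutes, iter_secs, loss, avg, best_avg):
    """Print the stats line every 50 steps and return the updated best average."""
    stats = (f'[{counter} | {cur_counter} | {minutes}m | {iter_secs:2.2f}s] '
             f'loss={loss:2.2f} avg={avg:2.2f}')
    if cur_counter % 50 == 0:
        print(colored(stats, 'red' if avg > best_avg else 'yellow'))
    return min(best_avg, avg)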
def main():
    args = parser.parse_args()
    enc = encoder.get_encoder(args.model_name, models_dir=args.models_dir)
    hparams = model.default_hparams()
    with open(os.path.join('models', args.model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if args.sample_length > hparams.n_ctx:
        raise ValueError(
            "Can't get samples longer than window size: %s" % hparams.n_ctx)

    with tf.Session() as sess:
        # Fully static shape required to make memory accounting in
        # twremat accurate.
        train_context = tf.placeholder(tf.int32, [args.batch_size, 1024])
        train_context_in = randomize(train_context, hparams, args.noise)
        train_output = model.model(hparams=hparams, X=train_context_in)
        train_loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=train_context[:, 1:],
                logits=train_output['logits'][:, :-1]))

        if args.val_every > 0:
            val_context = tf.placeholder(tf.int32, [args.val_batch_size, None])
            val_output = model.model(hparams=hparams, X=val_context)
            val_loss = tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=val_context[:, 1:],
                    logits=val_output['logits'][:, :-1]))
            val_loss_summary = tf.summary.scalar('val_loss', val_loss)

        sample_context = tf.placeholder(tf.int32, [args.batch_size, None])
        tf_sample = sample.sample_sequence(hparams=hparams,
                                           length=args.sample_length,
                                           context=sample_context,
                                           batch_size=args.batch_size,
                                           temperature=1.0,
                                           top_k=args.top_k,
                                           top_p=args.top_p)

        all_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
        train_vars = [v for v in all_vars if '/h' in v.name
                      ] if args.only_train_transformer_layers else all_vars

        if args.optimizer == 'adam':
            print('Using Adam optimizer', file=sys.stderr)
            opt = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
        elif args.optimizer == 'sgd':
            print('Using SGD optimizer', file=sys.stderr)
            opt = tf.train.GradientDescentOptimizer(
                learning_rate=args.learning_rate)
        else:
            exit('Bad optimizer: ' + args.optimizer)

        if args.memory_saving_gradients:
            if tf.VERSION >= '2':
                exit('Memory saving gradients are not supported in tensorflow 2.x')
            import memory_saving_gradients
            opt_grads = memory_saving_gradients.gradients(train_loss, train_vars)
        elif args.twremat:
            import tfremat
            opt_grads = tf.gradients(train_loss, train_vars)
            (train_loss, opt_grads) = tfremat.tf_remat(
                (train_loss, opt_grads), memlimit=args.twremat_memlimit)
        else:
            opt_grads = tf.gradients(train_loss, train_vars)
        opt_grads = list(zip(opt_grads, train_vars))
        opt_apply = opt.apply_gradients(opt_grads)
        summary_loss = tf.summary.scalar('loss', train_loss)

        # if args.twremat:
        #     import tfremat
        #     # Applying tfremat to opt_apply has more accurate
        #     # accounting but is a bit iffier since side effecting ops
        #     # have more restrictions for correctness. If in doubt
        #     # revert back to version using opt_grads above.
        #     (opt_apply, train_loss, summary_loss) = (
        #         tfremat.tf_remat((opt_apply, train_loss, summary_loss),
        #                          memlimit=args.twremat_memlimit))

        summary_lr = tf.summary.scalar('learning_rate', args.learning_rate)
        summaries = tf.summary.merge([summary_lr, summary_loss])

        summary_log = tf.summary.FileWriter(
            os.path.join(CHECKPOINT_DIR, args.run_name))

        saver = tf.train.Saver(var_list=all_vars,
                               max_to_keep=5,
                               keep_checkpoint_every_n_hours=2)
        sess.run(tf.global_variables_initializer())

        if args.restore_from == 'latest':
            ckpt = tf.train.latest_checkpoint(
                os.path.join(CHECKPOINT_DIR, args.run_name))
            if ckpt is None:
                # Get fresh GPT weights if new run.
                ckpt = tf.train.latest_checkpoint(
                    os.path.join('models', args.model_name))
        elif args.restore_from == 'fresh':
            ckpt = tf.train.latest_checkpoint(
                os.path.join('models', args.model_name))
        else:
            ckpt = tf.train.latest_checkpoint(args.restore_from)
        print('Loading checkpoint', ckpt)
        saver.restore(sess, ckpt)

        print('Loading dataset...')
        chunks = load_dataset(enc, args.dataset, args.combine,
                              encoding=args.encoding)
        data_sampler = Sampler(chunks)
        if args.val_every > 0:
            if args.val_dataset:
                val_chunks = load_dataset(enc, args.val_dataset, args.combine,
                                          encoding=args.encoding)
            else:
                val_chunks = chunks
        print('dataset has', data_sampler.total_size, 'tokens')
        print('Training...')

        if args.val_every > 0:
            # Sample from validation set once with fixed seed to make
            # it deterministic during training as well as across runs.
            val_data_sampler = Sampler(val_chunks, seed=1)
            val_batches = [[val_data_sampler.sample(1024)
                            for _ in range(args.val_batch_size)]
                           for _ in range(args.val_batch_count)]

        counter = 1
        counter_path = os.path.join(CHECKPOINT_DIR, args.run_name, 'counter')
        if os.path.exists(counter_path):
            # Load the step number if we're resuming a run
            # Add 1 so we don't immediately try to save again
            with open(counter_path, 'r') as fp:
                counter = int(fp.read()) + 1

        def save():
            maketree(os.path.join(CHECKPOINT_DIR, args.run_name))
            print(
                'Saving',
                os.path.join(CHECKPOINT_DIR, args.run_name,
                             'model-{}').format(counter))
            saver.save(sess,
                       os.path.join(CHECKPOINT_DIR, args.run_name, 'model'),
                       global_step=counter)
            with open(counter_path, 'w') as fp:
                fp.write(str(counter) + '\n')

        def generate_samples():
            print('Generating samples...')
            context_tokens = data_sampler.sample(1)
            all_text = []
            index = 0
            while index < args.sample_num:
                out = sess.run(
                    tf_sample,
                    feed_dict={sample_context: args.batch_size * [context_tokens]})
                for i in range(min(args.sample_num - index, args.batch_size)):
                    text = enc.decode(out[i])
                    text = '======== SAMPLE {} ========\n{}\n'.format(
                        index + 1, text)
                    all_text.append(text)
                    index += 1
                    print(text)
            maketree(os.path.join(SAMPLE_DIR, args.run_name))
            with open(os.path.join(SAMPLE_DIR, args.run_name,
                                   'samples-{}').format(counter),
                      'w', encoding=args.encoding) as fp:
                fp.write('\n'.join(all_text))

        def validation():
            print('Calculating validation loss...')
            losses = []
            for batch in tqdm.tqdm(val_batches):
                losses.append(sess.run(val_loss, feed_dict={val_context: batch}))
            v_val_loss = np.mean(losses)
            v_summary = sess.run(val_loss_summary,
                                 feed_dict={val_loss: v_val_loss})
            summary_log.add_summary(v_summary, counter)
            summary_log.flush()
            print('[{counter} | {time:2.2f}] validation loss = {loss:2.2f}'
                  .format(counter=counter,
                          time=time.time() - start_time,
                          loss=v_val_loss))

        def sample_batch():
            return [data_sampler.sample(1024) for _ in range(args.batch_size)]

        avg_loss = (0.0, 0.0)
        start_time = time.time()

        # print('Evaluating grads..')
        # tf2.profiler.experimental.start('logdir')
        # sess.run((opt_apply, train_loss, summaries), feed_dict={train_context: sample_batch()})
        # tf2.profiler.experimental.stop()
        # print('Succeeded')
        # exit()

        try:
            while True:
                if counter % args.save_every == 0:
                    save()
                if counter % args.sample_every == 0:
                    generate_samples()
                if args.val_every > 0 and (counter % args.val_every == 0
                                           or counter == 1):
                    validation()

                (_, v_loss, v_summary) = sess.run(
                    (opt_apply, train_loss, summaries),
                    feed_dict={train_context: sample_batch()})

                summary_log.add_summary(v_summary, counter)

                avg_loss = (avg_loss[0] * 0.99 + v_loss, avg_loss[1] * 0.99 + 1.0)

                print('[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                      .format(counter=counter,
                              time=time.time() - start_time,
                              loss=v_loss,
                              avg=avg_loss[0] / avg_loss[1]))

                counter += 1
        except KeyboardInterrupt:
            print('interrupted')
            save()
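# NOTE (illustrative sketch, not part of the training script above): every
# variant in this file builds its loss from labels=tokens[:, 1:] and
# logits[:, :-1], i.e. the logit at position t is scored against the token at
# position t + 1. The tiny NumPy example below reproduces that alignment for a
# single sequence so the off-by-one slicing is explicit; the names toy_logits
# and next_token_nll are made up for the example.
import numpy as np

def next_token_nll(toy_logits, tokens):
    """Mean negative log-likelihood of each token given the previous ones."""
    logits = toy_logits[:-1]   # predictions made at positions 0..T-2
    targets = tokens[1:]       # the tokens those positions should predict
    # softmax over the vocabulary axis
    probs = np.exp(logits - logits.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)
    return -np.mean(np.log(probs[np.arange(len(targets)), targets]))

tokens = np.array([5, 2, 7, 1])          # a 4-token sequence
toy_logits = np.random.randn(4, 10)      # vocab of 10, one logit row per position
print(next_token_nll(toy_logits, tokens))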
def finetune(sess,
             dataset,
             steps=-1,
             model_name='124M',
             model_dir='models',
             combine=50000,
             batch_size=1,
             learning_rate=0.0001,
             accumulate_gradients=5,
             restore_from='latest',
             run_name='run1',
             checkpoint_dir='checkpoint',
             sample_every=100,
             sample_length=1023,
             sample_num=1,
             multi_gpu=False,
             save_every=1000,
             print_every=1,
             max_checkpoints=1,
             use_memory_saving_gradients=False,
             only_train_transformer_layers=False,
             optimizer='adam',
             overwrite=False,
             val_dataset=None,
             val_batch_size=2,
             val_batch_count=40,
             val_every=0):
    """Finetunes the model on the given dataset.

    Adapted from https://github.com/nshepperd/gpt-2/blob/finetuning/train.py.
    See that file for parameter definitions.
    """
    # assert model_name not in ['774M', '1558M'] or multi_gpu, "Currently, a modern single GPU cannot finetune the 774M GPT-2 model or larger."

    SAMPLE_DIR = 'samples'

    checkpoint_path = os.path.join(checkpoint_dir, run_name)

    def maketree(path):
        try:
            os.makedirs(path)
        except:
            pass

    maketree(checkpoint_path)
    files = [f for f in os.listdir(checkpoint_path)]
    for file in ['hparams.json', 'encoder.json', 'vocab.bpe']:
        try:
            shutil.copyfile(os.path.join(model_dir, model_name, file),
                            os.path.join(checkpoint_path, file))
        except FileNotFoundError as fnf_error:
            print("You need to download the GPT-2 model first via download_gpt2()")
            raise fnf_error

    enc = encoder.get_encoder(checkpoint_path)
    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if sample_length > hparams.n_ctx:
        raise ValueError(
            "Can't get samples longer than window size: %s" % hparams.n_ctx)

    if model_name not in ['117M', '124M']:
        use_memory_saving_gradients = True
        only_train_transformer_layers = True
        accumulate_gradients = 1

    context = tf.compat.v1.placeholder(tf.int32, [batch_size, None])
    gpus = []

    if multi_gpu:
        gpus = get_available_gpus()

    output = model.model(hparams=hparams, X=context, gpus=gpus)
    loss = tf.reduce_mean(
        input_tensor=tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=context[:, 1:], logits=output['logits'][:, :-1]))

    # validation code
    if val_every > 0:
        val_context = tf.placeholder(tf.int32, [val_batch_size, None])
        val_output = model.model(hparams=hparams, X=val_context,
                                 reuse=True)  # added reuse=True
        val_loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=val_context[:, 1:],
                logits=val_output['logits'][:, :-1]))
        val_loss_summary = tf.summary.scalar('val_loss', val_loss)

    tf_sample = sample.sample_sequence(hparams=hparams,
                                       length=sample_length,
                                       context=context,
                                       batch_size=batch_size,
                                       temperature=1.0,
                                       top_k=40)

    all_vars = [v for v in tf.compat.v1.trainable_variables() if 'model' in v.name]
    train_vars = [v for v in all_vars if '/h' in v.name
                  ] if only_train_transformer_layers else all_vars

    if optimizer == 'adam':
        opt = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
    elif optimizer == 'sgd':
        opt = tf.compat.v1.train.GradientDescentOptimizer(
            learning_rate=learning_rate)

    if accumulate_gradients > 1:
        if use_memory_saving_gradients:
            exit("Memory saving gradients are not implemented for gradient accumulation yet.")
        opt = AccumulatingOptimizer(opt=opt, var_list=train_vars)
        opt_reset = opt.reset()
        opt_compute = opt.compute_gradients(loss)
        opt_apply = opt.apply_gradients()
        summary_loss = tf.compat.v1.summary.scalar('loss', opt_apply)
    else:
        if use_memory_saving_gradients:
            opt_grads = memory_saving_gradients.gradients(loss, train_vars)
        else:
            opt_grads = tf.gradients(ys=loss, xs=train_vars)
        opt_grads = list(zip(opt_grads, train_vars))
        opt_apply = opt.apply_gradients(opt_grads)
        summary_loss = tf.compat.v1.summary.scalar('loss', loss)

    summary_log = tf.compat.v1.summary.FileWriter(checkpoint_path)

    saver = tf.compat.v1.train.Saver(var_list=all_vars,
                                     max_to_keep=max_checkpoints)
    sess.run(tf.compat.v1.global_variables_initializer())

    if restore_from == 'latest':
        ckpt = tf.train.latest_checkpoint(checkpoint_path)
        if ckpt is None:
            # Get fresh GPT weights if new run.
            ckpt = tf.train.latest_checkpoint(os.path.join(model_dir, model_name))
    elif restore_from == 'fresh':
        ckpt = tf.train.latest_checkpoint(os.path.join(model_dir, model_name))
    else:
        ckpt = tf.train.latest_checkpoint(restore_from)
    print('Loading checkpoint', ckpt)
    saver.restore(sess, ckpt)

    print('Loading dataset...')
    chunks = load_dataset(enc, dataset, combine)
    data_sampler = Sampler(chunks)
    # validation code
    if val_every > 0:
        if val_dataset:
            val_chunks = load_dataset(enc, val_dataset, combine)
        else:
            val_chunks = chunks
    print('dataset has', data_sampler.total_size, 'tokens')
    print('Training...')

    # validation code
    if val_every > 0:
        # Sample from validation set once with fixed seed to make
        # it deterministic during training as well as across runs.
        val_data_sampler = Sampler(val_chunks, seed=1)
        val_batches = [[val_data_sampler.sample(1024)
                        for _ in range(val_batch_size)]
                       for _ in range(val_batch_count)]

    counter = 1
    counter_path = os.path.join(checkpoint_path, 'counter')
    if os.path.exists(counter_path) and restore_from == 'latest':
        # Load the step number if we're resuming a run
        # Add 1 so we don't immediately try to save again
        with open(counter_path, 'r') as fp:
            counter = int(fp.read()) + 1
    counter_base = counter

    def save():
        maketree(checkpoint_path)
        print('Saving',
              os.path.join(checkpoint_path, 'model-{}').format(counter - 1))
        saver.save(sess,
                   os.path.join(checkpoint_path, 'model'),
                   global_step=counter - 1)
        with open(counter_path, 'w') as fp:
            fp.write(str(counter - 1) + '\n')

    def generate_samples():
        context_tokens = data_sampler.sample(1)
        all_text = []
        index = 0
        while index < sample_num:
            out = sess.run(tf_sample,
                           feed_dict={context: batch_size * [context_tokens]})
            for i in range(min(sample_num - index, batch_size)):
                text = enc.decode(out[i])
                text = '======== SAMPLE {} ========\n{}\n'.format(index + 1, text)
                all_text.append(text)
                index += 1
                print(text)
        maketree(os.path.join(SAMPLE_DIR, run_name))
        with open(os.path.join(SAMPLE_DIR, run_name,
                               'samples-{}').format(counter), 'w') as fp:
            fp.write('\n'.join(all_text))

    # validation code
    def validation():
        print('Calculating validation loss...')
        losses = []
        for batch in tqdm(val_batches):
            losses.append(sess.run(val_loss, feed_dict={val_context: batch}))
        v_val_loss = np.mean(losses)
        v_summary = sess.run(val_loss_summary, feed_dict={val_loss: v_val_loss})
        summary_log.add_summary(v_summary, counter)
        summary_log.flush()
        print('[{counter} | {time:2.2f}] validation loss = {loss:2.2f}'.format(
            counter=counter, time=time.time() - start_time, loss=v_val_loss))
        return v_val_loss

    def sample_batch():
        return [data_sampler.sample(1024) for _ in range(batch_size)]

    if overwrite and restore_from == 'latest':
        for file in files:
            if file.startswith('model') or file.startswith('events'):
                os.remove(os.path.join(checkpoint_path, file))
        save()

    avg_loss = (0.0, 0.0)
    start_time = time.time()

    # Trying out a change to finetune that saves only when validation loss decreases
    if steps:
        steps = int(steps)

    try:
        while True:
            if steps > 0 and counter == (counter_base + steps):
                # save()
                return
            # if (counter - 1) % save_every == 0 and counter > 1:
            #     save()
            if (counter - 1) % sample_every == 0 and counter > 1:
                generate_samples()
            # validation code
            if val_every > 0 and counter == 1:
                v_val_loss = validation()
                save()
            elif val_every > 0 and counter == counter_base:
                v_val_loss = validation()
            elif val_every > 0 and (counter % val_every == 0):
                new_v_val_loss = validation()
                if new_v_val_loss < v_val_loss:
                    v_val_loss = new_v_val_loss
                    save()

            if accumulate_gradients > 1:
                sess.run(opt_reset)
                for _ in range(accumulate_gradients):
                    sess.run(opt_compute, feed_dict={context: sample_batch()})
                (v_loss, v_summary) = sess.run((opt_apply, summary_loss))
            else:
                (_, v_loss, v_summary) = sess.run(
                    (opt_apply, loss, summary_loss),
                    feed_dict={context: sample_batch()})

            summary_log.add_summary(v_summary, counter)

            if (counter % print_every == 0) or counter == 1:
                avg_loss = (avg_loss[0] * 0.99 + v_loss, avg_loss[1] * 0.99 + 1.0)
                print('[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                      .format(counter=counter,
                              time=time.time() - start_time,
                              loss=v_loss,
                              avg=avg_loss[0] / avg_loss[1]))

            counter += 1
    except KeyboardInterrupt:
        print('interrupted')
        save()
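# NOTE (sketch of the idea only, not the author's code): the finetune() variant
# above replaces "save every N steps" with "save whenever the validation loss
# improves", tracked inline via v_val_loss / new_v_val_loss. The helper below
# isolates just that control flow; the names BestCheckpointSaver and save_fn
# are hypothetical.
class BestCheckpointSaver:
    def __init__(self, save_fn):
        self.save_fn = save_fn
        self.best = float('inf')

    def maybe_save(self, val_loss):
        """Save only when the validation loss beats the best seen so far."""
        if val_loss < self.best:
            self.best = val_loss
            self.save_fn()
            return True
        return False

# Usage inside a training loop (assuming save() and validation() as above):
#   best_saver = BestCheckpointSaver(save)
#   if val_every > 0 and counter % val_every == 0:
#       best_saver.maybe_save(validation())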
def main():
    args = parser.parse_args()
    enc = encoder.get_encoder(args.model_name)
    hparams = model.default_hparams()
    with open(os.path.join('models', args.model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if args.sample_length > hparams.n_ctx:
        raise ValueError(
            "Can't get samples longer than window size: %s" % hparams.n_ctx)

    if args.model_name == '774M':
        args.memory_saving_gradients = True
        if args.optimizer == 'adam':
            args.only_train_transformer_layers = True

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.graph_options.rewrite_options.layout_optimizer = rewriter_config_pb2.RewriterConfig.OFF
    with tf.Session(config=config) as sess:
        context = tf.placeholder(tf.int32, [args.batch_size, None])
        context_in = randomize(context, hparams, args.noise)
        output = model.model(hparams=hparams, X=context_in)
        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=context[:, 1:], logits=output['logits'][:, :-1]))

        tf_sample = sample.sample_sequence(
            hparams=hparams,
            length=args.sample_length,
            context=context,
            batch_size=args.batch_size,
            temperature=1.0,
            top_k=args.top_k,
            top_p=args.top_p)

        all_vars = [v for v in tf.trainable_variables() if 'model' in v.name]

        # this line is to hopefully reduce memory usage (found on Twitter:
        # https://twitter.com/BasedBlue/status/1169601983046672385?s=20)
        edgeindex = -1 * args.layers_to_train
        train_vars = all_vars[edgeindex:]
        print("Training", args.layers_to_train, "raw layers out of", len(all_vars))
        train_vars = [v for v in train_vars if '/h' in v.name
                      ] if args.only_train_transformer_layers else train_vars
        print("Training", len(train_vars), "net layers out of", len(all_vars))

        if args.optimizer == 'adam':
            opt = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
        elif args.optimizer == 'sgd':
            opt = tf.train.GradientDescentOptimizer(learning_rate=args.learning_rate)
        elif args.optimizer == 'adafactor':
            opt = AdafactorOptimizer(learning_rate=args.learning_rate)
        else:
            exit('Bad optimizer: ' + args.optimizer)

        if args.accumulate_gradients > 1:
            if args.memory_saving_gradients:
                exit("Memory saving gradients are not implemented for gradient accumulation yet.")
            opt = AccumulatingOptimizer(opt=opt, var_list=train_vars)
            opt_reset = opt.reset()
            opt_compute = opt.compute_gradients(loss)
            opt_apply = opt.apply_gradients()
            summary_loss = tf.summary.scalar('loss', opt_apply)
        else:
            if args.memory_saving_gradients:
                opt_grads = memory_saving_gradients.gradients(loss, train_vars)
            else:
                opt_grads = tf.gradients(loss, train_vars)
            opt_grads = list(zip(opt_grads, train_vars))
            opt_apply = opt.apply_gradients(opt_grads)
            summary_loss = tf.summary.scalar('loss', loss)

        summary_lr = tf.summary.scalar('learning_rate', args.learning_rate)
        summaries = tf.summary.merge([summary_lr, summary_loss])

        summary_log = tf.summary.FileWriter(
            os.path.join(CHECKPOINT_DIR, args.run_name))

        saver = tf.train.Saver(
            var_list=all_vars,
            max_to_keep=5,
            keep_checkpoint_every_n_hours=2)
        sess.run(tf.global_variables_initializer())

        if args.restore_from == 'latest':
            ckpt = tf.train.latest_checkpoint(
                os.path.join(CHECKPOINT_DIR, args.run_name))
            if ckpt is None:
                # Get fresh GPT weights if new run.
                ckpt = tf.train.latest_checkpoint(
                    os.path.join('models', args.model_name))
        elif args.restore_from == 'fresh':
            ckpt = tf.train.latest_checkpoint(
                os.path.join('models', args.model_name))
        else:
            ckpt = tf.train.latest_checkpoint(args.restore_from)
        print('Loading checkpoint', ckpt)
        saver.restore(sess, ckpt)

        print('Loading dataset...')
        chunks = load_dataset(enc, args.dataset, args.combine,
                              encoding=args.encoding)
        data_sampler = Sampler(chunks)
        if args.val_every > 0:
            if args.val_dataset:
                val_chunks = load_dataset(enc, args.val_dataset, args.combine,
                                          encoding=args.encoding)
            else:
                val_chunks = chunks
        print('dataset has', data_sampler.total_size, 'tokens')
        print('Training...')

        if args.val_every > 0:
            # Sample from validation set once with fixed seed to make
            # it deterministic during training as well as across runs.
            val_data_sampler = Sampler(val_chunks, seed=1)
            val_batches = [[val_data_sampler.sample(1024)
                            for _ in range(args.val_batch_size)]
                           for _ in range(args.val_batch_count)]

        counter = 1
        counter_path = os.path.join(CHECKPOINT_DIR, args.run_name, 'counter')
        if os.path.exists(counter_path):
            # Load the step number if we're resuming a run
            # Add 1 so we don't immediately try to save again
            with open(counter_path, 'r') as fp:
                counter = int(fp.read()) + 1

        def save():
            maketree(os.path.join(CHECKPOINT_DIR, args.run_name))
            print(
                'Saving',
                os.path.join(CHECKPOINT_DIR, args.run_name,
                             'model-{}').format(counter))
            saver.save(
                sess,
                os.path.join(CHECKPOINT_DIR, args.run_name, 'model'),
                global_step=counter)
            with open(counter_path, 'w') as fp:
                fp.write(str(counter) + '\n')

        def generate_samples():
            print('Generating samples...')
            context_tokens = data_sampler.sample(1)
            all_text = []
            index = 0
            while index < args.sample_num:
                out = sess.run(
                    tf_sample,
                    feed_dict={context: args.batch_size * [context_tokens]})
                for i in range(min(args.sample_num - index, args.batch_size)):
                    text = enc.decode(out[i])
                    text = '======== SAMPLE {} ========\n{}\n'.format(
                        index + 1, text)
                    all_text.append(text)
                    index += 1
                    print(text)
            maketree(os.path.join(SAMPLE_DIR, args.run_name))
            with open(os.path.join(SAMPLE_DIR, args.run_name,
                                   'samples-{}').format(counter),
                      'w', encoding=args.encoding) as fp:
                fp.write('\n'.join(all_text))

        def sample_batch():
            ret = [data_sampler.sample(1024) for _ in range(args.batch_size)]
            # print(enc.decode(ret[0]))
            return ret

        avg_loss = (0.0, 0.0)
        bval_loss = (0.0, 0.0)
        start_time = time.time()
        best_val_loss = 99
        missed_val_checkpoints = 0

        try:
            while counter < args.stop_after:
                if counter % args.sample_every == 0:
                    generate_samples()

                if args.accumulate_gradients > 1:
                    sess.run(opt_reset)
                    for _ in range(args.accumulate_gradients):
                        sess.run(opt_compute, feed_dict={context: sample_batch()})
                    (v_loss, v_summary) = sess.run((opt_apply, summaries))
                else:
                    (_, v_loss, v_summary) = sess.run(
                        (opt_apply, loss, summaries),
                        feed_dict={context: sample_batch()})

                summary_log.add_summary(v_summary, counter)

                avg_loss = (avg_loss[0] * 0.98 + v_loss, avg_loss[1] * 0.98 + 1.0)

                print('[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                      .format(counter=counter,
                              time=time.time() - start_time,
                              loss=v_loss,
                              avg=avg_loss[0] / avg_loss[1]))

                if args.val_every > 0 and counter % args.val_every == 0:
                    valbatch = [val_data_sampler.sample(1024)
                                for _ in range(args.batch_size)]
                    valacc = sess.run(loss, feed_dict={context: valbatch})
                    bval_loss = (bval_loss[0] * 0.9 + valacc,
                                 bval_loss[1] * 0.9 + 1.0)
                    av_val_loss = bval_loss[0] / bval_loss[1]
                    av_train_loss = avg_loss[0] / avg_loss[1]
                    print('[{counter} | {time:2.2f}] VAL_loss={loss:2.4f} VAL_avg={avg:2.4f} best={best:2.4f}'
                          .format(counter=counter,
                                  time=time.time() - start_time,
                                  loss=valacc,
                                  avg=av_val_loss,
                                  best=best_val_loss))

                    if counter >= args.save_every and counter % args.save_every == 0:
                        # check for validation checkpoints every save_every iterations.
                        if av_val_loss < best_val_loss and av_val_loss > av_train_loss:
                            # got a good one from validation, save a checkpoint
                            # (every save_every) -- but don't save before val loss
                            # goes above train loss
                            save()
                            best_val_loss = av_val_loss
                            missed_val_checkpoints = 0
                        else:
                            # missed a validation checkpoint. tolerate like 10 of these.
                            if av_val_loss > av_train_loss:
                                # don't count a missed checkpoint while val loss is
                                # under training loss
                                missed_val_checkpoints += 1
                            if missed_val_checkpoints > 19:
                                # missed too many save opportunities, stop training
                                counter = args.stop_after + 1
                                print('stopping training due to val loss not improving.')

                counter += 1
        except KeyboardInterrupt:
            print('interrupted')
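# NOTE (illustrative only, not from the script above): the last variant freezes
# most of the network by slicing the flat list of trainable variables from the
# end (all_vars[-args.layers_to_train:]), so layers_to_train counts raw
# tf.Variables ("raw layers" in its print), not transformer blocks; each block
# contributes several variables (attention, MLP, layer norms). The sketch below
# shows the same selection on made-up variable names so the distinction is
# visible.
all_var_names = [
    'model/h0/attn/w', 'model/h0/mlp/w',
    'model/h1/attn/w', 'model/h1/mlp/w',
    'model/ln_f/g',
]

layers_to_train = 3  # counts variables, not transformer blocks
train_var_names = all_var_names[-layers_to_train:]
print(train_var_names)
# -> ['model/h1/attn/w', 'model/h1/mlp/w', 'model/ln_f/g']

# The optional '/h' filter in the script then restricts training to variables
# inside transformer blocks, dropping 'model/ln_f/g' here:
print([n for n in train_var_names if '/h' in n])
# -> ['model/h1/attn/w', 'model/h1/mlp/w']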