def __init__(self,
             wv_file='py_files/saved_objects/poetic_embeddings.300d.txt',
             syllables_file='py_files/saved_objects/cmudict-0.7b.txt',
             postag_file='py_files/saved_objects/postag_dict_all.p',
             model_dir='py_files/models/all_combined_back'):
    self.api_url = 'https://api.datamuse.com/words'
    self.ps = nltk.stem.PorterStemmer()
    self.punct = re.compile(r'[^\w\s]')
    self.model_dir = model_dir
    # self.poetic_vectors = KeyedVectors.load_word2vec_format(wv_file, binary=False)

    self.create_syll_dict(syllables_file)

    with open(postag_file, 'rb') as f:
        postag_dict = pickle.load(f)
    self.pos_to_words = postag_dict[1]
    self.words_to_pos = postag_dict[2]
    self.create_pos_syllables()
    self.create_templates_dict(postag_dict[0])

    self.first_line_words = pickle.load(
        open('py_files/saved_objects/first_line.p', 'rb'))
    self.width = 20  # Not sure what this does, necessary for search_back function
    self.word_pools = [set([]) for n in range(4)]
    self.enc = get_encoder('117M')
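def _example_encoder_round_trip(generator):
    # Illustrative sketch, not from the original source: `generator` is assumed
    # to be an instance of the class whose __init__ appears above. Its GPT-2 BPE
    # encoder (self.enc, loaded via get_encoder('117M')) round-trips between
    # text and token ids, which is how candidate words are fed to the model.
    ids = generator.enc.encode("Shall I compare thee to a summer's day?")
    return generator.enc.decode(ids)  # recovers the original string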
def __init__(self, sess, length=80, temperature=0.9, top_k=40):
    seed = None
    batch_size = 1
    model_path = 'gpt2/models/117M'

    self.sess = sess
    self.enc = encoder.get_encoder(model_path)
    hparams = model.default_hparams()
    with open(os.path.join(model_path, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    self.context = tf.placeholder(tf.int32, [batch_size, None])
    np.random.seed(seed)
    tf.set_random_seed(seed)
    self.output = sample.sample_sequence(
        hparams=hparams,
        length=length,
        context=self.context,
        batch_size=batch_size,
        temperature=temperature,  # forward the constructor's sampling parameters
        top_k=top_k,
    )

    saver = tf.train.Saver()
    ckpt = tf.train.latest_checkpoint(model_path)
    saver.restore(self.sess, ckpt)
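def _example_sample_from_wrapper(wrapper, prompt="There once was a man from Peru"):
    # Illustrative sketch, not from the original source: `wrapper` is assumed to
    # be an instance of the GPT-2 wrapper whose __init__ appears above. The
    # sampling op built there is fed an encoded prompt through the context
    # placeholder; the returned ids include the prompt, so it is sliced off
    # before decoding.
    tokens = wrapper.enc.encode(prompt)
    out = wrapper.sess.run(wrapper.output,
                           feed_dict={wrapper.context: [tokens]})
    return wrapper.enc.decode(out[0][len(tokens):])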
def __init__(self,
             syllables_file='LimGen/saved_objects/cmudict-0.7b.txt',
             postag_file='LimGen/saved_objects/postag_dict_all.p',
             model_name='345M'):
    with open('LimGen/saved_objects/total_vocab.pickle', "rb") as f:
        self.total_vocab = pickle.load(f)
    with open('LimGen/saved_objects/clean_rhyming_dictionary.pickle', "rb") as f:
        self.rhyming_dictionary = pickle.load(f)
    self.api_url = 'https://api.datamuse.com/words'
    self.ps = nltk.stem.PorterStemmer()
    self.punct = re.compile(r'[^\w\s]')  # punctuation
    self.punctuation = {
        "second": True, "third": True, "fourth": True, "fifth": True
    }
    self.sentence_to_punctuation = {
        "second": ".", "third": ",", "fourth": ",", "fifth": "."
    }

    # gpt2 model
    self.model_name = model_name
    self.enc = get_encoder(self.model_name)

    # load spacy word embeddings
    self.spacy_nlp = spacy.load("en_core_web_lg")

    # specify parameters for look-ahead score
    self.word_embedding_alpha = 0.5
    self.word_embedding_coefficient = 0

    # for multiprocessing
    self.cpu = mp.cpu_count()

    # create variables for hard constraints: syllables, meter, rhyme, people's names
    self.create_syll_dict(syllables_file)
    with open(postag_file, 'rb') as f:
        postag_dict = pickle.load(f)
    self.pos_to_words = postag_dict[1]
    self.words_to_pos = postag_dict[2]
    self.create_pos_syllables()
    self.create_templates_dict(postag_dict[0])

    self.filtered_names_rhymes = "LimGen/saved_objects/filtered_names_rhymes.pkl"
    with open(self.filtered_names_rhymes, "rb") as hf:
        self.names_rhymes_list = pickle.load(hf)
    self.female_name_list, self.male_name_list = pickle.load(
        open("LimGen/saved_objects/name_list.p", "rb"))

    # filtering out unfavorable words
    with open("LimGen/saved_objects/filtered_nouns_verbs.txt", "r") as hf:
        self.filtered_nouns_verbs = [line.strip() for line in hf.readlines()]
    self.filtered_nouns_verbs += self.pos_to_words["IN"] + self.pos_to_words["PRP"]

    self.verb_repeat_whitelist = set([
        'be', 'is', 'am', 'are', 'was', 'were', 'being',
        'do', 'does', 'did', 'have', 'has', 'had'
    ])
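def _example_rhyme_lookup(word='grace', max_results=20):
    # Illustrative sketch, not from the original source: the Datamuse endpoint
    # stored in self.api_url is typically queried with the `rel_rhy` parameter
    # to fetch rhyme candidates (e.g. for limerick line endings). This
    # standalone version assumes the `requests` package is available.
    import requests
    resp = requests.get('https://api.datamuse.com/words',
                        params={'rel_rhy': word, 'max': max_results})
    return [entry['word'] for entry in resp.json()]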
def finetune(sess,
             dataset,
             steps=-1,
             model_name='117M',
             combine=50000,
             batch_size=1,
             learning_rate=0.0001,
             accumulate_gradients=5,
             restore_from='latest',
             run_name='run1',
             sample_every=100,
             sample_length=1023,
             sample_num=1,
             save_every=1000,
             print_every=1,
             max_checkpoints=1,
             model_load=False):
    """Finetunes the model on the given dataset.

    Adapted from https://github.com/nshepperd/gpt-2/blob/finetuning/train.py.
    See that file for parameter definitions.
    """

    CHECKPOINT_DIR = 'checkpoint'
    SAMPLE_DIR = 'samples'

    checkpoint_path = os.path.join(CHECKPOINT_DIR, run_name)

    def maketree(path):
        try:
            os.makedirs(path)
        except:
            pass

    maketree(checkpoint_path)
    if not model_load:
        for file in ['hparams.json', 'encoder.json', 'vocab.bpe']:
            shutil.copyfile(os.path.join('models', model_name, file),
                            os.path.join(checkpoint_path, file))

    enc = encoder.get_encoder(checkpoint_path)
    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if sample_length > hparams.n_ctx:
        raise ValueError(
            "Can't get samples longer than window size: %s" % hparams.n_ctx)

    context = tf.placeholder(tf.int32, [batch_size, None])
    loss_mask = tf.placeholder(tf.int8, [batch_size, None])
    output = model.model(hparams=hparams, X=context)
    loss_mask_float = tf.cast(loss_mask, tf.float32)

    # with loss mask -- reduce mean
    loss = tf.reduce_mean(
        loss_mask_float[:, :-1] *
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=context[:, 1:], logits=output['logits'][:, :-1]))
    '''
    # with loss mask -- reduce sum / reduce sum
    loss = tf.reduce_sum(
        loss_mask_float[:, :-1] *
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=context[:, 1:],
            logits=output['logits'][:, :-1])) / tf.reduce_sum(
                loss_mask_float[:, :-1])
    '''
    '''
    # without loss mask
    loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=context[:, 1:], logits=output['logits'][:, :-1]))
    '''

    tf_sample = sample.sample_sequence(
        hparams=hparams,
        length=sample_length,
        context=context,
        batch_size=batch_size,
        temperature=1.0,
        top_k=40)

    train_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
    if accumulate_gradients > 1:
        opt = AccumulatingOptimizer(
            opt=tf.train.AdamOptimizer(learning_rate=learning_rate),
            var_list=train_vars)
        opt_reset = opt.reset()
        opt_compute = opt.compute_gradients(loss)
        opt_apply = opt.apply_gradients()
        summary_loss = tf.summary.scalar('loss', opt_apply)
    else:
        opt_apply = tf.train.AdamOptimizer(
            learning_rate=learning_rate).minimize(loss, var_list=train_vars)
        summary_loss = tf.summary.scalar('loss', loss)

    summary_log = tf.summary.FileWriter(checkpoint_path)

    saver = tf.train.Saver(
        var_list=train_vars,
        max_to_keep=max_checkpoints)
    sess.run(tf.global_variables_initializer())

    if restore_from == 'latest':
        ckpt = tf.train.latest_checkpoint(checkpoint_path)
        if ckpt is None:
            # Get fresh GPT weights if new run.
            ckpt = tf.train.latest_checkpoint(
                os.path.join('models', model_name))
    elif restore_from == 'fresh':
        ckpt = tf.train.latest_checkpoint(
            os.path.join('models', model_name))
    else:
        ckpt = tf.train.latest_checkpoint(restore_from)
    print('Loading checkpoint', ckpt)
    saver.restore(sess, ckpt)

    if model_load:
        return

    print('Loading dataset...')
    chunks = load_dataset(enc, dataset, combine)
    data_sampler = Sampler(chunks)
    print('dataset has', data_sampler.total_size, 'tokens')
    print('Training...')

    counter = 1
    counter_path = os.path.join(checkpoint_path, 'counter')
    if os.path.exists(counter_path) and restore_from == 'latest':
        # Load the step number if we're resuming a run
        # Add 1 so we don't immediately try to save again
        with open(counter_path, 'r') as fp:
            counter = int(fp.read()) + 1
    counter_base = counter

    def save():
        maketree(checkpoint_path)
        print(
            'Saving',
            os.path.join(checkpoint_path, 'model-{}').format(counter - 1))
        saver.save(
            sess,
            os.path.join(checkpoint_path, 'model'),
            global_step=counter - 1)
        with open(counter_path, 'w') as fp:
            fp.write(str(counter - 1) + '\n')

    def generate_samples():
        context_tokens = data_sampler.sample(1)
        all_text = []
        index = 0
        while index < sample_num:
            out = sess.run(
                tf_sample,
                feed_dict={context: batch_size * [context_tokens]})
            for i in range(min(sample_num - index, batch_size)):
                text = enc.decode(out[i])
                text = '======== SAMPLE {} ========\n{}\n'.format(
                    index + 1, text)
                all_text.append(text)
                index += 1
            print(text)
        maketree(os.path.join(SAMPLE_DIR, run_name))
        with open(
                os.path.join(SAMPLE_DIR, run_name,
                             'samples-{}').format(counter), 'w') as fp:
            fp.write('\n'.join(all_text))

    def sample_batch():
        sampled_batch = [data_sampler.sample(1024) for _ in range(batch_size)]
        batch_len = min(1024, max([len(v) for v in sampled_batch]))
        batch_masks = np.zeros([batch_size, batch_len])
        '''
        sampled_batch = [data_sampler.sample(1024) for _ in range(batch_size)]
        batch_len = min(1024, max([len(v) for v in sampled_batch]))
        batch_masks = np.zeros([batch_size, batch_len])
        for i, v in enumerate(sampled_batch):
            if len(v) > batch_len:
                sampled_batch[i] = v[-batch_len:]
            mask_start = len(v) - list(v[::-1]).index(63) + 1
            # batch_masks[i, mask_start:len(v)] += 1  # without padding after endoftext
            batch_masks[i, mask_start:] += 1  # with padding after endoftext
        if batch_size > 1:
            sampled_batch = np.asarray([
                np.pad(v, [0, batch_len - len(v)], 'constant', constant_values=63)
                for v in sampled_batch
            ], dtype=np.int32)
        '''
        '''
        if batch_len > 1024:
            sampled_batch = sampled_batch[:, -1024:]
        '''
        return sampled_batch, batch_masks

    avg_loss = (0.0, 0.0)
    start_time = time.time()

    try:
        while True:
            if steps > 0 and counter == (counter_base + steps):
                save()
                return
            if counter % save_every == 0:
                save()
            if counter % sample_every == 0:
                generate_samples()

            if accumulate_gradients > 1:
                sess.run(opt_reset)
                for _ in range(accumulate_gradients):
                    context_t, loss_mask_t = sample_batch()
                    sess.run(
                        opt_compute,
                        feed_dict={context: context_t, loss_mask: loss_mask_t})
                (v_loss, v_summary) = sess.run((opt_apply, summary_loss))
            else:
                raise NotImplementedError()
                (_, v_loss, v_summary) = sess.run(
                    (opt_apply, loss, summary_loss),
                    feed_dict={context: sample_batch()})

            summary_log.add_summary(v_summary, counter)

            if counter % print_every == 0:
                avg_loss = (avg_loss[0] * 0.99 + v_loss,
                            avg_loss[1] * 0.99 + 1.0)

                print(
                    '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                    .format(
                        counter=counter,
                        time=time.time() - start_time,
                        loss=v_loss,
                        avg=avg_loss[0] / avg_loss[1]))

            counter += 1
    except KeyboardInterrupt:
        print('interrupted')
        save()
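def _example_finetune_usage():
    # Illustrative sketch, not from the original source: one way finetune()
    # above might be driven. Assumes a TF1-style session, the '117M' weights
    # under models/, and a plain-text corpus at 'poems.txt' (hypothetical path)
    # consumed by load_dataset().
    sess = tf.Session()
    finetune(sess,
             'poems.txt',
             model_name='117M',
             steps=1000,            # stop after 1000 optimizer steps
             batch_size=1,
             accumulate_gradients=5,
             run_name='run1',
             save_every=500,
             sample_every=250)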
def generate(sess,
             return_as_list=False,
             truncate=None,
             destination_path=None,
             sample_delim='=' * 20 + '\n',
             prefix=None,
             model_name='117M',
             seed=None,
             nsamples=1,
             batch_size=1,
             length=1023,
             temperature=0.7,
             top_k=0,
             run_name='run1',
             include_prefix=True):
    """Generates text from a model loaded into memory.

    Adapted from https://github.com/openai/gpt-2/blob/master/src/interactive_conditional_samples.py
    """

    if batch_size is None:
        batch_size = 1
    assert nsamples % batch_size == 0

    if nsamples == 1:
        sample_delim = ''

    if prefix:
        context = tf.placeholder(tf.int32, [batch_size, None])

    CHECKPOINT_DIR = 'checkpoint'
    SAMPLE_DIR = 'samples'

    checkpoint_path = os.path.join(CHECKPOINT_DIR, run_name)

    enc = encoder.get_encoder(checkpoint_path)
    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    np.random.seed(seed)
    tf.set_random_seed(seed)

    output = sample.sample_sequence(
        hparams=hparams,
        length=length,
        start_token=enc.encoder['<|endoftext|>'] if not prefix else None,
        context=context if prefix else None,
        batch_size=batch_size,
        temperature=temperature,
        top_k=top_k
    )[:, 1:]

    if destination_path:
        f = open(destination_path, 'w')
    if prefix:
        context_tokens = enc.encode(prefix)
    generated = 0
    gen_texts = []
    while generated < nsamples:
        if not prefix:
            out = sess.run(output)
        else:
            out = sess.run(output, feed_dict={
                context: batch_size * [context_tokens]
            })
        for i in range(batch_size):
            generated += 1
            gen_text = enc.decode(out[i])
            if prefix and include_prefix:
                gen_text = prefix[0] + gen_text
            if truncate:
                raise NotImplementedError
                truncate_esc = re.escape(truncate)
                if prefix and not include_prefix:
                    prefix_esc = re.escape(prefix)
                    pattern = '(?:{})(.*?)(?:{})'.format(prefix_esc,
                                                         truncate_esc)
                else:
                    pattern = '(.*?)(?:{})'.format(truncate_esc)
                trunc_text = re.search(pattern, gen_text, re.S)
                if trunc_text:
                    gen_text = trunc_text.group(1)
            if destination_path:
                f.write("{}\n{}".format(gen_text, sample_delim))
            if not return_as_list and not destination_path:
                print("{}\n{}".format(gen_text, sample_delim))
            gen_texts.append(gen_text)

    if destination_path:
        f.close()

    if return_as_list:
        return gen_texts
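def _example_generate_usage():
    # Illustrative sketch, not from the original source: restore finetuned
    # weights by calling finetune() with model_load=True (it returns right after
    # the checkpoint restore), then sample with a prefix. Assumes a previous
    # training run left a usable checkpoint under checkpoint/run1.
    sess = tf.Session()
    finetune(sess, dataset=None, run_name='run1', model_load=True)
    texts = generate(sess,
                     run_name='run1',
                     prefix='There once was a fellow named Lee',
                     length=60,
                     temperature=0.7,
                     nsamples=2,
                     batch_size=1,
                     return_as_list=True)
    for text in texts:
        print(text)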
def generate_text(raw_text):
    """
    Interactively run the model
    :model_name=117M : String, which model to use
    :seed=None : Integer seed for random number generators, fix seed to
     reproduce results
    :nsamples=1 : Number of samples to return total
    :batch_size=1 : Number of batches (only affects speed/memory). Must divide nsamples.
    :length=None : Number of tokens in generated text, if None (default), is
     determined by model hyperparameters
    :temperature=1 : Float value controlling randomness in boltzmann
     distribution. Lower temperature results in less random completions. As the
     temperature approaches zero, the model will become deterministic and
     repetitive. Higher temperature results in more random completions.
    :top_k=0 : Integer value controlling diversity. 1 means only 1 word is
     considered for each step (token), resulting in deterministic completions,
     while 40 means 40 words are considered at each step. 0 (default) is a
     special setting meaning no restrictions. 40 generally is a good value.
    :models_dir : path to parent folder containing model subfolders
     (i.e. contains the <model_name> folder)
    """
    # Hardcoded parameters
    model_name = '345M'
    seed = None
    nsamples = 1
    batch_size = 1
    length = None
    temperature = 1
    top_k = 40
    # use absolute path if this doesn't work for you
    models_dir = '/mnt/c/Users/frch/Desktop/OneWeek2019/gpt2/models'
    # models_dir = 'gpt2\\models'
    ########################

    models_dir = os.path.expanduser(os.path.expandvars(models_dir))
    if batch_size is None:
        batch_size = 1
    assert nsamples % batch_size == 0

    enc = encoder.get_encoder(model_name, models_dir)
    hparams = model.default_hparams()
    with open(os.path.join(models_dir, model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if length is None:
        length = hparams.n_ctx // 2
    elif length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" %
                         hparams.n_ctx)

    with tf.Session(graph=tf.Graph()) as sess:
        context = tf.placeholder(tf.int32, [batch_size, None])
        np.random.seed(seed)
        tf.set_random_seed(seed)
        output = sample.sample_sequence(
            hparams=hparams,
            length=length,
            context=context,
            batch_size=batch_size,
            temperature=temperature,
            top_k=top_k)

        saver = tf.train.Saver()
        ckpt = tf.train.latest_checkpoint(os.path.join(models_dir, model_name))
        saver.restore(sess, ckpt)

        text = ""
        context_tokens = enc.encode(raw_text)
        generated = 0
        for _ in range(nsamples // batch_size):
            out = sess.run(output, feed_dict={
                context: [context_tokens for _ in range(batch_size)]
            })[:, len(context_tokens):]
            for i in range(batch_size):
                generated += 1
                text += enc.decode(out[i])
        return text
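def _example_generate_text_usage():
    # Illustrative sketch, not from the original source: generate_text() builds
    # its own graph and session internally, so a single call with a prompt is
    # enough. Assumes the hardcoded models_dir path above points at a valid
    # '345M' model directory on the machine running it.
    completion = generate_text("There once was a student of rhyme,")
    print(completion)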