def get_data():
    wiki103_file_path = settings.get_raw_wiki_path()
    tbc_file_path = settings.get_raw_tbc_path()

    # Data cleaning and preparation according to previous work
    wiki_data = prepare_wiki(wiki103_file_path)
    tbc_data = prepare_tbc(tbc_file_path)
    input_sents = wiki_data + tbc_data

    # Group the data into batches of similar BERT-token length ("smart batching")
    smart_batches = create_smart_batches(input_sents,
                                         batch_size=settings.get_batch_size())

    # Divide the data into training and validation sets
    train_inputs, validation_inputs = train_test_split(
        smart_batches[:settings.get_num_batches()],
        random_state=settings.get_random_state(),
        test_size=settings.get_test_size_ratio())

    # Write the prepared data to disk
    with open(settings.get_train_inputs_path(), 'w+') as f:
        for batch in train_inputs:
            for sent in batch:
                f.write(" ".join(tokenizer.convert_ids_to_tokens(sent)[1:-1]) + "\n")

    with open(settings.get_validation_inputs_path(), 'w+') as f:
        for batch in validation_inputs:
            for sent in batch:
                f.write(" ".join(tokenizer.convert_ids_to_tokens(sent)[1:-1]) + "\n")

    with open(settings.get_proc_wiki_path(), 'wb') as f:
        pickle.dump(wiki_data, f)
    with open(settings.get_proc_tbc_path(), 'wb') as f:
        pickle.dump(tbc_data, f)
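
# `create_smart_batches` is defined elsewhere in the project; the sketch below is only
# meant to illustrate the smart-batching idea assumed above: sentences are tokenized,
# sorted by length so that similarly sized sequences land in the same batch, and each
# batch is padded to its own longest sequence rather than a single global maximum.
# The function name and details here are illustrative assumptions, reusing the
# module-level `tokenizer` and `torch` already used in this file.
def _create_smart_batches_sketch(sentences, batch_size):
    encoded = [tokenizer.encode(sent, add_special_tokens=True) for sent in sentences]
    encoded.sort(key=len)  # similar lengths end up in the same batch
    batches = []
    for i in range(0, len(encoded), batch_size):
        chunk = encoded[i:i + batch_size]
        max_len = max(len(ids) for ids in chunk)
        # Pad with 0, the [PAD] id in standard BERT vocabularies
        padded = [ids + [0] * (max_len - len(ids)) for ids in chunk]
        batches.append(torch.tensor(padded))
    return batches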
def evaluate_bert(generator, output_file_path: str):
    settings.write_debug("Entering Bleu EVAL")
    with torch.no_grad():
        ### Don't use top_k=1 for evaluation: it is essentially argmax and non-differentiable,
        ### which breaks the computation
        # Choose the prefix context
        generated_input_ids = generator.generate(
            settings.get_num_eval_samples(),
            seed_text=settings.get_eval_seed_text().split(),
            batch_size=settings.get_eval_batch_size(),
            max_len=settings.get_sample_size(),
            generation_mode=settings.get_eval_gen_mode_key(),
            sample=settings.get_eval_sample(),
            top_k=settings.get_eval_top_k(),
            temperature=settings.get_eval_temp(),
            burnin=settings.get_eval_burnin(),
            max_iter=settings.get_eval_max_iter())

        bert_sents = [
            generator.detokenize(tokenizer.convert_ids_to_tokens(sent.tolist())).split()
            for sent in generated_input_ids
        ]

        with open(output_file_path, 'a+') as f:
            for index, sent in enumerate(bert_sents):
                f.write('[%d/%d]\t' % (index, len(bert_sents)) + " ".join(sent) + "\n")

        avg_p = np.average([
            (perplexity_model.score(" ".join(sent))['positional_scores'].mean().neg().exp()).item()
            for sent in bert_sents
        ])
        settings.write_result("BERT Perplexity: %.2f" % avg_p)
        settings.write_result("BERT self-BLEU: %.2f" % (100 * self_bleu(bert_sents)))

        max_n = settings.get_bleu_max_n()
        with open(settings.get_proc_wiki_path(), 'rb') as proc_wiki_f, \
                open(settings.get_proc_tbc_path(), 'rb') as proc_tbc_f:
            wiki_data = pickle.load(proc_wiki_f)
            tbc_data = pickle.load(proc_tbc_f)

        settings.write_result("BERT-TBC BLEU: %.2f" % (100 * corpus_bleu(bert_sents, tbc_data)))
        settings.write_result("BERT-Wiki103 BLEU: %.2f" % (100 * corpus_bleu(bert_sents, wiki_data)))
        settings.write_result("BERT-{TBC + Wiki103} BLEU: %.2f" %
                              (100 * corpus_bleu(bert_sents, tbc_data[:2500] + wiki_data[:2500])))

        pct_uniques = ref_unique_ngrams(bert_sents, wiki_data, max_n)
        for i in range(1, max_n + 1):
            settings.write_result("BERT unique %d-grams relative to Wiki: %.2f" % (i, 100 * pct_uniques[i]))
        pct_uniques = ref_unique_ngrams(bert_sents, tbc_data, max_n)
        for i in range(1, max_n + 1):
            settings.write_result("BERT unique %d-grams relative to TBC: %.2f" % (i, 100 * pct_uniques[i]))
        pct_uniques = self_unique_ngrams(bert_sents, max_n)
        for i in range(1, max_n + 1):
            settings.write_result("BERT unique %d-grams relative to self: %.2f" % (i, 100 * pct_uniques[i]))
        settings.write_result("")
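
# `self_bleu`, `corpus_bleu`, and the n-gram uniqueness helpers are imported from the
# project's metrics module; the sketch below only illustrates the self-BLEU computation
# assumed above (each generated sentence scored against all the other generated sentences
# as references, then averaged), written with NLTK. It is an illustrative approximation,
# not the project's own implementation.
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

def _self_bleu_sketch(sentences):
    """Average BLEU of each tokenized sentence against all the others."""
    smooth = SmoothingFunction().method1
    scores = []
    for i, hyp in enumerate(sentences):
        refs = sentences[:i] + sentences[i + 1:]
        scores.append(sentence_bleu(refs, hyp, smoothing_function=smooth))
    return sum(scores) / len(scores)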
def untokenize_batch(batch, tokenizer: BertTokenizer):
    return [tokenizer.convert_ids_to_tokens(sent) for sent in batch]
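
# The generator's `detokenize` helper is used throughout this file but defined elsewhere;
# the sketch below shows the behavior assumed here: BERT WordPiece continuation pieces
# start with "##", so joining tokens on spaces and collapsing " ##" recovers
# whitespace-separated words. The name and details are illustrative assumptions.
def _detokenize_sketch(tokens):
    return " ".join(tokens).replace(" ##", "")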
def train_gan(self, training_dataloader: DataLoader, validation_dataloader: DataLoader):
    settings.write_debug("Begin Training")
    self.initialize_training(training_dataloader)

    settings.write_debug("========================================")
    settings.write_debug(" Training ")
    settings.write_debug("========================================")

    # Training loop -- iterates over the full training dataset once per epoch
    for epoch_i in range(0, self.epochs):
        settings.write_debug("")
        settings.write_debug('======== Epoch {:} / {:} ========'.format(epoch_i + 1, self.epochs))
        settings.write_debug('Training...')

        # Start time, used to measure the duration of this training epoch
        t0 = time.time()

        # Reset the loss variables for each epoch
        discriminator_total_loss = 0
        generator_total_loss = 0
        gen_update_step = 0

        ########################################################
        # Get one batch of real samples from the dataset at a time
        ########################################################
        for step, batch in enumerate(training_dataloader):

            ########################
            # Generate Labels
            ########################
            # Set labels using random float values for label smoothing
            # - Real labels are in the range [.9, 1] instead of exactly 1
            # - Fake labels are in the range [0, .1] instead of exactly 0
            # - Helps to prevent mode collapse by keeping a moving target
            self.real_labels = torch.tensor(
                [random.uniform(.9, 1)] * settings.get_batch_size(),
                requires_grad=False).unsqueeze(-1).to(device)
            self.false_labels = torch.tensor(
                [random.uniform(0, .1)] * settings.get_batch_size(),
                requires_grad=False).unsqueeze(-1).to(device)

            settings.write_debug("batch:" + str(step) + "/" + str(len(training_dataloader)))

            # Progress update every 40 batches
            if step % 40 == 0 and not step == 0:
                # Calculate elapsed time in minutes and report progress
                elapsed = format_time(time.time() - t0)
                settings.write_debug(' Batch {:>5,} of {:>5,}. Elapsed: {:}.'.format(
                    step, len(training_dataloader), elapsed))

            ###########################
            # Discriminator Network Training
            ###########################
            # - Maximize log(D(x)) + log(1 - D(G(z))) on real and generated samples
            self.discriminator.train()
            self.generator.train()
            #self.discriminator.requires_grad = True

            # Clear previously accumulated gradients
            self.discriminator.zero_grad()

            # Attach the batch to the GPU
            batch = batch.to(device)

            ## Train with a batch of samples from the dataset -- "real"
            D_x, discriminator_real_loss = self.discriminator.train_discriminator(
                batch, self.real_labels)  ### Pairs the "real" labels with the batch

            ############################################################################
            ## Generate a "fake" batch of samples similar in size and length to the "real" one
            ############################################################################
            n_samples = batch_size = settings.get_batch_size()
            sample_lens = ((batch != 0).sum(dim=-1) - 2).tolist()  # Subtract 2 to account for CLS and final SEP
            max_len = len(batch[0]) - 2  # settings.get_sample_size()
            top_k = 10  ### *** Don't use 1 here because it breaks computation ***
            temperature = 1.0
            generation_mode = "training"
            burnin = 0
            sample = True
            max_iter = 1

            # Choose the prefix context
            seed_text = "[CLS]".split()
            generated_input_ids = self.generator.generate(
                n_samples, seed_text=seed_text, batch_size=batch_size,
                max_len=max_len, sample_lens=sample_lens,
                generation_mode=generation_mode, sample=sample, top_k=top_k,
                temperature=temperature, burnin=burnin, max_iter=max_iter)

            bert_sents = [
                self.generator.detokenize(tokenizer.convert_ids_to_tokens(sent.tolist())).split()
                for sent in generated_input_ids
            ]
            with open(settings.get_bert_train_out_path(), 'a+') as f:
                for index, sent in enumerate(bert_sents):
                    f.write('[%d/%d][%d/%d]\t' % (epoch_i, self.epochs, index, len(bert_sents))
                            + " ".join(sent) + "\n")

            ## Train with the generated batch -- "fake"
            D_G_z1, discriminator_fake_loss = self.discriminator.train_discriminator(
                generated_input_ids,  # clone().detach(),
                self.false_labels)  ### Pairs the "fake" labels with the batch

            discriminator_combined_loss = discriminator_fake_loss + discriminator_real_loss
            discriminator_total_loss += discriminator_combined_loss

            self.discriminator.optimizer.step()
            # Update the learning rate
            self.discriminator_scheduler.step()

            ###########################
            # Generator Network Training
            ###########################
            # - Generated samples only, to maximize log(D(G(z)))
            # Save GPU memory by not tracking discriminator gradients
            #self.discriminator.requires_grad = False

            ## Train with the generated batch
            D_G_z2, generator_loss = self.generator.train_generator(
                generated_input_ids, self.real_labels, self.discriminator)

            # Step the optimizer to update weights; step the scheduler to adjust the learning rate
            self.generator.optimizer.step()
            self.generator_scheduler.step()

            generator_total_loss += generator_loss
            # counter
            gen_update_step += 1

            # Clear gradients after the update; detach and empty the cache to save memory
            self.generator.zero_grad()
            generated_input_ids.detach()
            del generated_input_ids
            torch.cuda.empty_cache()

            # Output training stats
            settings.write_train_stat(
                '[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f\n'
                % (epoch_i, self.epochs, step, len(training_dataloader),
                   discriminator_combined_loss, generator_loss, D_x, D_G_z1, D_G_z2))

        # Calculate the average loss over the training data
        discriminator_avg_train_loss = discriminator_total_loss / (len(training_dataloader))
        generator_avg_train_loss = generator_total_loss / gen_update_step

        ### Output statistics to file
        settings.write_train_stat("\n")
        settings.write_train_stat(" Average Discriminator training loss: {0:.2f}".format(
            discriminator_avg_train_loss))
        settings.write_train_stat(" Average Generator training loss: {0:.2f}".format(
            generator_avg_train_loss))
        settings.write_train_stat(" Training epoch took: {:}\n".format(
            format_time(time.time() - t0)))

        settings.write_debug("========================================")
        settings.write_debug(" Validation ")
        settings.write_debug("========================================")
        # After each training epoch, measure accuracy on a mixture of the real validation set
        # and an equal number of freshly generated samples
        settings.write_debug("")
        settings.write_debug("Running Validation...")

        t0 = time.time()

        # eval() changes the behavior of dropout layers
        self.generator.eval()
        self.discriminator.eval()

        # Reset evaluation variables
        eval_loss, eval_accuracy = 0., 0.
        nb_eval_steps, nb_eval_examples = 0, 0

        # Evaluate data for one epoch
        for batch in validation_dataloader:
            with torch.no_grad():
                ## Generate an all-fake batch
                sample_lens = ((batch != 0).sum(dim=-1) - 2).tolist()  # Subtract 2 to account for CLS and final SEP
                n_samples = batch_size = settings.get_batch_size()
                max_len = len(batch[0]) - 2  # settings.get_sample_size()
                top_k = 100  ### Using 1 here seems to break computation for some reason
                temperature = 1.0
                generation_mode = "evaluation"
                burnin = 250
                sample = True
                max_iter = 500
                seed_text = "[CLS]".split()

                generated_input_ids = self.generator.generate(
                    n_samples, seed_text=seed_text, batch_size=batch_size,
                    max_len=max_len, generation_mode=generation_mode,
                    sample=sample, top_k=top_k, temperature=temperature,
                    burnin=burnin, max_iter=max_iter, sample_lens=sample_lens)

                validation_sents = [
                    self.generator.detokenize(tokenizer.convert_ids_to_tokens(sent.tolist())).split()
                    for sent in generated_input_ids
                ]
                with open(settings.get_bert_valid_out_path(), 'a+') as f:
                    for index, sent in enumerate(validation_sents):
                        f.write('[%d/%d][%d/%d]\t' % (epoch_i, self.epochs, index, len(validation_sents))
                                + " ".join(sent) + "\n")

                # Mix the real validation batch with the generated one
                batch = torch.cat((batch, generated_input_ids)).to(device)
                labels = torch.cat((self.real_labels, self.false_labels))

                # Match labels with samples and shuffle them together
                batch = list(zip(batch, labels))
                random.shuffle(batch)
                batch, labels = zip(*batch)

                logits, _ = self.discriminator(torch.stack(batch))
                print("Validation logits", flush=True)
                print(logits, flush=True)

                # Calculate the accuracy for this batch of mixed real and generated samples
                tmp_eval_accuracy = _flat_accuracy(preds=logits, labels=torch.stack(labels))

                # Accumulate the total accuracy
                eval_accuracy += tmp_eval_accuracy

                # Total number of validation batches run
                nb_eval_steps += 1
                print(eval_accuracy / nb_eval_steps)

        # Output accuracy
        settings.write_train_stat("\n")
        settings.write_train_stat(" Validation Accuracy: {0:.2f}".format(
            eval_accuracy / nb_eval_steps))
        settings.write_train_stat(" Validation took: {:}\n".format(
            format_time(time.time() - t0)))

        # Put both models back into training mode now that validation has ended
        self.generator.train()
        self.discriminator.train()

    settings.write_debug("")
    settings.write_debug("Training complete!")
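
# `_flat_accuracy` is a helper defined elsewhere; the sketch below only shows the
# behavior assumed in the validation loop above: the discriminator emits one raw
# real/fake logit per sample and the labels are smoothed ([.9, 1] for real, [0, .1]
# for fake), so accuracy can be computed by thresholding both at 0.5. The name and
# details are illustrative assumptions, not the project's actual implementation.
def _flat_accuracy_sketch(preds, labels):
    pred_is_real = (torch.sigmoid(preds) > 0.5).flatten()
    label_is_real = (labels > 0.5).flatten()
    return (pred_is_real == label_is_real).float().mean().item()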