def __init__(self, path=None):
    super(GAN.Discriminator, self).__init__()
    settings.write_debug("Initializing discriminator object")
    if path is not None:
        settings.write_debug("Getting model from file")
        self.map1 = BertForSequenceClassification.from_pretrained(path)
    else:
        # Load BertForSequenceClassification, the pretrained BERT model with a single
        # linear classification layer on top.
        self.map1 = BertForSequenceClassification.from_pretrained(
            settings.get_model_type(),  # Use the 12-layer BERT model with an uncased vocab.
            # num_labels=settings.get_num_labels(),  # The number of output labels -- 2 for binary classification.
            num_labels=1,  # Single logit for real/fake; increase for multi-class tasks.
            output_attentions=False,  # Whether the model returns attention weights.
            output_hidden_states=False,  # Whether the model returns all hidden states.
        )
    # Loss function and optimizer are attributes of this object so that they can
    # differ from the Generator's if necessary.
    self.loss_f = nn.BCEWithLogitsLoss()
    self.optimizer = AdamW(self.parameters(), lr=2e-5, eps=1e-8)
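# Illustrative aside (not part of the original module): with num_labels=1 the
# classification head produces one logit per sample, so nn.BCEWithLogitsLoss expects
# float targets of the same (batch_size, 1) shape -- which is why train_gan below
# builds its smoothed labels with unsqueeze(-1). A minimal, self-contained sketch;
# the tensors are made-up example values, not data from this repo:
#
#   import torch
#   import torch.nn as nn
#   logits = torch.randn(4, 1)                      # e.g. discriminator output for a batch of 4
#   targets = torch.full((4, 1), 0.95)              # smoothed "real" labels in [0.9, 1]
#   loss = nn.BCEWithLogitsLoss()(logits, targets)  # scalar, differentiable loss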
def train_discriminator(self, batch, labels):
    settings.write_debug("Begin Train Discriminator")
    output, loss = self(batch, labels)
    # Backward pass to accumulate the gradients
    loss.backward()
    # Clips gradient norms that rise above 1.0 to prevent exploding gradients
    torch.nn.utils.clip_grad_norm_(self.parameters(), 1.0)
    settings.write_debug("Ending Train Discriminator")
    return output.mean().item(), loss
def train_generator(self, batch, labels, discriminator):
    settings.write_debug("Begin training generator")
    output, loss = self(batch, labels, discriminator)
    generator_loss = loss.item()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(self.parameters(), 1.0)
    settings.write_debug("Finish training generator")
    return output, generator_loss
def __generate_step(self,
                    out: torch.Tensor,
                    gen_idx,
                    temperature=None,
                    top_k=0,
                    sample=False,
                    return_list=True):
    """ Generate a word from out[gen_idx]

    args:
        - out (torch.Tensor): tensor of logits of size batch_size x seq_len x vocab_size
        - gen_idx (int): location for which to generate
        - temperature (float): controls the uniformity of the probability distribution
        - top_k (int): if > 0, only sample from the top k most probable words
        - sample (bool): if True, sample from the full distribution. Overridden by top_k
    """
    logits = out[:, gen_idx]
    if temperature is not None:
        logits = logits / temperature
    gumbel_reps = None
    if top_k > 0:
        # Gumbel-Softmax (hard/straight-through) allows sampling while keeping the
        # computation graph differentiable. Caution: top_k == 1 collapses to argmax
        # and appears to break the computation graph.
        kth_vals, kth_idx = torch.topk(logits, top_k, dim=-1)
        gumbel_reps = torch.nn.functional.gumbel_softmax(logits=kth_vals, hard=True)
        idx = torch.sum(torch.mul(gumbel_reps, kth_idx), dim=-1)
    elif sample:
        gumbel_reps = torch.nn.functional.gumbel_softmax(logits=logits, hard=True)
        idx = torch.sum(
            torch.mul(gumbel_reps,
                      torch.tensor(range(0, len(tokenizer.vocab))).to(device)),
            dim=-1)
    else:
        # Argmax is non-differentiable, so this branch cannot be used for backpropagation
        idx = torch.argmax(logits, dim=-1)
        settings.write_debug("Warning: argmax generation step is non-differentiable")
    # Free intermediate tensors to save GPU memory
    del logits
    if gumbel_reps is not None:
        del gumbel_reps
    gc.collect()
    return idx.tolist() if return_list else idx
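# Illustrative aside (not part of the original module): a minimal, standalone sketch
# of the straight-through Gumbel-Softmax top-k selection used in __generate_step above.
# The names example_logits and k are made up for the example; only torch.topk and
# F.gumbel_softmax are real PyTorch calls.
#
#   import torch
#   import torch.nn.functional as F
#   example_logits = torch.randn(2, 30522, requires_grad=True)  # (batch, vocab)
#   k = 10
#   kth_vals, kth_idx = torch.topk(example_logits, k, dim=-1)   # restrict to top-k
#   one_hot = F.gumbel_softmax(kth_vals, hard=True)             # hard one-hot, soft gradient
#   chosen = (one_hot * kth_idx).sum(dim=-1)                    # sampled vocab ids (float)
#   chosen.sum().backward()                                     # gradients reach example_logits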
def __init__(self, path=None):
    super(GAN.Generator, self).__init__()
    settings.write_debug("Initializing generator")
    if path is not None:
        settings.write_debug("Getting model from file")
        self.map1 = BertForMaskedLM.from_pretrained(path)
    else:
        self.map1 = BertForMaskedLM.from_pretrained(
            settings.get_model_type(),
            output_attentions=False,
            output_hidden_states=True)
    # Loss function and optimizer used for generation -- kept as attributes because
    # they could be changed independently of the discriminator's.
    self.loss_fct = nn.BCEWithLogitsLoss()
    self.optimizer = AdamW(self.parameters(), lr=2e-6, eps=10e-4)
def generate(self,
             n_samples,
             seed_text="[CLS]",
             batch_size=10,
             max_len=25,
             sample_lens=None,
             generation_mode="parallel-sequential",
             sample=True,
             top_k=100,
             temperature=1.0,
             burnin=200,
             max_iter=500,
             print_every=1):
    # Main generation function to call
    settings.write_debug("Generating Sentences")
    sentences = []
    n_batches = math.ceil(n_samples / batch_size)
    start_time = time.time()
    for batch_n in range(n_batches):
        if generation_mode == "evaluation":
            batch = self.__evaluation_generation(seed_text,
                                                 batch_size=batch_size,
                                                 max_len=max_len,
                                                 top_k=top_k,
                                                 temperature=temperature,
                                                 burnin=burnin,
                                                 max_iter=max_iter,
                                                 sample_lens=sample_lens,
                                                 verbose=False)
        elif generation_mode == "training":
            batch = self.__training_generation(seed_text,
                                               batch_size=batch_size,
                                               max_len=max_len,
                                               top_k=top_k,
                                               temperature=temperature,
                                               burnin=burnin,
                                               max_iter=max_iter,
                                               sample_lens=sample_lens,
                                               verbose=False)
        else:
            # Guard against unsupported modes so 'batch' is never referenced unbound
            raise ValueError("Unsupported generation_mode: %s" % generation_mode)

        if (batch_n + 1) % print_every == 0:
            settings.write_debug("Finished generating batch %d in %.3fs" %
                                 (batch_n + 1, time.time() - start_time))
            start_time = time.time()

        sentences += batch
        del batch
        torch.cuda.empty_cache()

    settings.write_debug("Returning Generated Sentences")
    return torch.stack(sentences)
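# Illustrative aside (not part of the original module): how generate() is invoked
# during training, shown in isolation. The values mirror the "training" call in
# train_gan below; 'gen' and max_len=23 are made-up examples (in train_gan, max_len
# is the real batch length minus the [CLS] and final [SEP] tokens).
#
#   gen = GAN.Generator()
#   fake_ids = gen.generate(settings.get_batch_size(),
#                           seed_text="[CLS]".split(),
#                           batch_size=settings.get_batch_size(),
#                           max_len=23,
#                           sample_lens=[23] * settings.get_batch_size(),
#                           generation_mode="training",
#                           sample=True,
#                           top_k=10,       # avoid top_k=1, it breaks the computation graph
#                           temperature=1.0,
#                           burnin=0,
#                           max_iter=1)
#   # fake_ids is a stacked tensor of generated input_ids, one row per sample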
def __init__(self, gen_file_path: str = None, disc_file_path: str = None):
    super(GAN, self).__init__()
    settings.write_debug("Initializing GAN")
    self.generator = GAN.Generator(gen_file_path)
    self.discriminator = GAN.Discriminator(disc_file_path)
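# Illustrative aside (not part of the original module): train_gan below relies on
# self.epochs, self.discriminator_scheduler and self.generator_scheduler, which are
# presumably set up by initialize_training(). This is only a guess at what that helper
# might look like; the real implementation is not shown in this section, and
# settings.get_epochs() is a hypothetical accessor.
#
#   from transformers import get_linear_schedule_with_warmup
#
#   def initialize_training(self, training_dataloader):
#       self.epochs = settings.get_epochs()  # hypothetical settings accessor
#       total_steps = len(training_dataloader) * self.epochs
#       self.discriminator_scheduler = get_linear_schedule_with_warmup(
#           self.discriminator.optimizer, num_warmup_steps=0, num_training_steps=total_steps)
#       self.generator_scheduler = get_linear_schedule_with_warmup(
#           self.generator.optimizer, num_warmup_steps=0, num_training_steps=total_steps)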
def train_gan(self, training_dataloader: DataLoader, validation_dataloader: DataLoader):
    settings.write_debug("Begin Training")
    self.initialize_training(training_dataloader)

    settings.write_debug("========================================")
    settings.write_debug("               Training                 ")
    settings.write_debug("========================================")

    # Training loop -- iterates over the full training dataset
    for epoch_i in range(0, self.epochs):
        settings.write_debug("")
        settings.write_debug('======== Epoch {:} / {:} ========'.format(epoch_i + 1, self.epochs))
        settings.write_debug('Training...')

        # Initialize time variable to measure duration of training
        t0 = time.time()

        # Reset the loss variables for each epoch
        discriminator_total_loss = 0
        generator_total_loss = 0
        gen_update_step = 0

        ########################################################
        # Get one batch of real samples from the dataset at a time
        ########################################################
        for step, batch in enumerate(training_dataloader):

            ########################
            # Generate Labels
            ########################
            # Set labels using random float values for label smoothing
            # - Real labels are in the range [.9, 1] instead of just 1
            # - Fake labels are in the range [0, .1] instead of just 0
            # - Helps to prevent mode collapse by keeping a moving target
            self.real_labels = torch.tensor(
                [random.uniform(.9, 1)] * settings.get_batch_size(),
                requires_grad=False).unsqueeze(-1).to(device)
            self.false_labels = torch.tensor(
                [random.uniform(0, .1)] * settings.get_batch_size(),
                requires_grad=False).unsqueeze(-1).to(device)

            settings.write_debug("batch:" + str(step) + "/" + str(len(training_dataloader)))

            # Progress update every 40 batches.
            if step % 40 == 0 and not step == 0:
                # Calculate elapsed time in minutes.
                elapsed = format_time(time.time() - t0)
                # Report progress.
                settings.write_debug('  Batch {:>5,} of {:>5,}.    Elapsed: {:}.'.format(
                    step, len(training_dataloader), elapsed))

            ###########################
            # Discriminator Network Training
            ###########################
            # - Real samples first, to maximize log(D(x)) + log(1 - D(G(z)))
            self.discriminator.train()
            self.generator.train()
            # self.discriminator.requires_grad = True

            # Clear previously accumulated gradients by setting them to zero
            self.discriminator.zero_grad()

            # Attach batch to GPU
            batch = batch.to(device)

            ## Train with a batch of samples from the dataset -- "real"
            D_x, discriminator_real_loss = self.discriminator.train_discriminator(
                batch, self.real_labels)  # Adds the correct list of labels to the batch

            ############################################################################
            ## Generate a "fake" batch of samples similar in size and length to the "real"
            ############################################################################
            n_samples = batch_size = settings.get_batch_size()
            sample_lens = ((batch != 0).sum(dim=-1) - 2).tolist()  # Subtract 2 to account for CLS and final SEP
            max_len = len(batch[0]) - 2  # settings.get_sample_size()
            top_k = 10  # *** Don't use 1 here because it breaks computation ***
            temperature = 1.0
            generation_mode = "training"
            burnin = 0
            sample = True
            max_iter = 1
            # Choose the prefix context
            seed_text = "[CLS]".split()

            generated_input_ids = self.generator.generate(
                n_samples,
                seed_text=seed_text,
                batch_size=batch_size,
                max_len=max_len,
                sample_lens=sample_lens,
                generation_mode=generation_mode,
                sample=sample,
                top_k=top_k,
                temperature=temperature,
                burnin=burnin,
                max_iter=max_iter)

            bert_sents = [
                self.generator.detokenize(
                    tokenizer.convert_ids_to_tokens(sent.tolist())).split()
                for sent in generated_input_ids
            ]
            with open(settings.get_bert_train_out_path(), 'a+') as f:
                [
                    f.write('[%d/%d][%d/%d]\t' %
                            (epoch_i, self.epochs, index, len(bert_sents)) +
                            " ".join(sent) + "\n")
                    for index, sent in enumerate(bert_sents)
                ]

            ## Train with the generated batch -- "fake"
            D_G_z1, discriminator_fake_loss = self.discriminator.train_discriminator(
                generated_input_ids,  # .clone().detach(),
                self.false_labels)  # Adds the correct list of labels to the batch

            discriminator_combined_loss = discriminator_fake_loss + discriminator_real_loss
            discriminator_total_loss += discriminator_combined_loss

            self.discriminator.optimizer.step()
            # Update the learning rate.
            self.discriminator_scheduler.step()

            ###########################
            # Generator Network Training
            ###########################
            # - Generated samples only, to maximize log(D(G(z)))
            # Save GPU memory by not tracking discriminator gradients
            # self.discriminator.requires_grad = False

            ## Train with the generated batch
            D_G_z2, generator_loss = self.generator.train_generator(
                generated_input_ids, self.real_labels, self.discriminator)

            # Call step on the optimizer to update the weights;
            # stepping the scheduler modifies the learning rate
            self.generator.optimizer.step()
            self.generator_scheduler.step()

            generator_total_loss += generator_loss
            # Counter of generator updates
            gen_update_step += 1

            # Clear gradients after the update and free memory
            self.generator.zero_grad()
            del generated_input_ids
            torch.cuda.empty_cache()

            # Output training stats
            settings.write_train_stat(
                '[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f\n'
                % (epoch_i, self.epochs, step, len(training_dataloader),
                   discriminator_combined_loss, generator_loss, D_x, D_G_z1, D_G_z2))

        # Calculate the average loss over the training data.
        discriminator_avg_train_loss = discriminator_total_loss / len(training_dataloader)
        generator_avg_train_loss = generator_total_loss / gen_update_step

        ### Output statistics to file
        settings.write_train_stat("\n")
        settings.write_train_stat("  Average Discriminator training loss: {0:.2f}".format(
            discriminator_avg_train_loss))
        settings.write_train_stat("  Average Generator training loss: {0:.2f}".format(
            generator_avg_train_loss))
        settings.write_train_stat("  Training epoch took: {:}\n".format(
            format_time(time.time() - t0)))

        settings.write_debug("========================================")
        settings.write_debug("               Validation               ")
        settings.write_debug("========================================")
        # After each training epoch, measure accuracy on a mixture of the real
        # validation set and freshly generated samples of matching size
        settings.write_debug("")
        settings.write_debug("Running Validation...")
        t0 = time.time()

        # eval() changes the behavior of dropout layers
        self.generator.eval()
        self.discriminator.eval()

        # Reset evaluation variables
        eval_loss, eval_accuracy = 0., 0.
        nb_eval_steps, nb_eval_examples = 0, 0

        # Evaluate data for one epoch
        for batch in validation_dataloader:
            with torch.no_grad():
                ## Generate an all-fake batch
                sample_lens = ((batch != 0).sum(dim=-1) - 2).tolist()  # Subtract 2 to account for CLS and final SEP
                n_samples = batch_size = settings.get_batch_size()
                max_len = len(batch[0]) - 2  # settings.get_sample_size()
                top_k = 100  # Using 1 here seems to break computation for some reason
                temperature = 1.0
                generation_mode = "evaluation"
                burnin = 250
                sample = True
                max_iter = 500
                seed_text = "[CLS]".split()

                generated_input_ids = self.generator.generate(
                    n_samples,
                    seed_text=seed_text,
                    batch_size=batch_size,
                    max_len=max_len,
                    generation_mode=generation_mode,
                    sample=sample,
                    top_k=top_k,
                    temperature=temperature,
                    burnin=burnin,
                    max_iter=max_iter,
                    sample_lens=sample_lens)

                validation_sents = [
                    self.generator.detokenize(
                        tokenizer.convert_ids_to_tokens(sent.tolist())).split()
                    for sent in generated_input_ids
                ]
                with open(settings.get_bert_valid_out_path(), 'a+') as f:
                    [
                        f.write('[%d/%d][%d/%d]\t' %
                                (epoch_i, self.epochs, index, len(validation_sents)) +
                                " ".join(sent) + "\n")
                        for index, sent in enumerate(validation_sents)
                    ]

                batch = torch.cat((batch, generated_input_ids)).to(device)
                labels = torch.cat((self.real_labels, self.false_labels))

                # Match labels with samples and shuffle them together
                batch = list(zip(batch, labels))
                random.shuffle(batch)
                batch, labels = zip(*batch)

                logits, _ = self.discriminator(torch.stack(batch))
                print("Validation logits", flush=True)
                print(logits, flush=True)

                # Calculate the accuracy for this batch of mixed real and generated validation samples
                tmp_eval_accuracy = _flat_accuracy(preds=logits, labels=torch.stack(labels))
                # Accumulate total accuracy
                eval_accuracy += tmp_eval_accuracy
                # Total number of validation batches run
                nb_eval_steps += 1
                print(eval_accuracy / nb_eval_steps)

        # Output accuracy
        settings.write_train_stat("\n")
        settings.write_train_stat("  Validation Accuracy: {0:.2f}".format(
            eval_accuracy / nb_eval_steps))
        settings.write_train_stat("  Validation took: {:}\n".format(
            format_time(time.time() - t0)))

        # Put both models back into training mode now that validation has ended
        self.generator.train()
        self.discriminator.train()

    settings.write_debug("")
    settings.write_debug("Training complete!")
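# Illustrative aside (not part of the original module): _flat_accuracy is used in the
# validation loop above but not defined in this section. A plausible sketch, assuming
# single-logit predictions thresholded at 0.5 against the smoothed labels; the real
# helper may differ.
#
#   import numpy as np
#   import torch
#
#   def _flat_accuracy(preds, labels):
#       preds = (torch.sigmoid(preds) > 0.5).long().cpu().numpy().flatten()
#       labels = (labels > 0.5).long().cpu().numpy().flatten()
#       return np.sum(preds == labels) / len(labels)
#
# End-to-end usage, assuming the DataLoaders are built elsewhere in the repo:
#
#   gan = GAN(gen_file_path=None, disc_file_path=None)
#   gan.train_gan(training_dataloader, validation_dataloader)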