Example #1
def get_data():

    wiki103_file_path = settings.get_raw_wiki_path()
    tbc_file_path = settings.get_raw_tbc_path()

    # Data cleaning and preparation according to previous work
    wiki_data = prepare_wiki(wiki103_file_path)
    tbc_data = prepare_tbc(tbc_file_path)
    input_sents = wiki_data + tbc_data

    # Group the data into batches of sentences with similar BERT token lengths (smart batching; a sketch follows this function)
    smart_batches = create_smart_batches(input_sents,
                                         batch_size=settings.get_batch_size())

    # Divide data into training and validation
    train_inputs, validation_inputs = train_test_split(
        smart_batches[:settings.get_num_batches()],
        random_state=settings.get_random_state(),
        test_size=settings.get_test_size_ratio())

    # Write the prepared training and validation inputs to disk
    with open(settings.get_train_inputs_path(), 'w+') as f:
        for batch in train_inputs:
            for sent in batch:
                # Strip the leading [CLS] and trailing [SEP] tokens before writing
                f.write(" ".join(tokenizer.convert_ids_to_tokens(sent)[1:-1]) + "\n")

    with open(settings.get_validation_inputs_path(), 'w+') as f:
        for batch in validation_inputs:
            for sent in batch:
                f.write(" ".join(tokenizer.convert_ids_to_tokens(sent)[1:-1]) + "\n")

    with open(settings.get_proc_wiki_path(), 'wb') as f:
        pickle.dump(wiki_data, f)

    with open(settings.get_proc_tbc_path(), 'wb') as f:
        pickle.dump(tbc_data, f)
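
create_smart_batches itself is not shown in these examples. As a rough illustration of the smart-batching idea referenced above (sort sentences by token length, cut fixed-size batches, and pad only within each batch), a minimal sketch could look like the following. The helper name, the use of tokenizer.encode, and the pad id 0 are assumptions, not the repository's actual implementation.

# Hypothetical sketch of smart batching -- not the repository's create_smart_batches
import random

def sketch_create_smart_batches(sentences, tokenizer, batch_size):
    # Tokenize every sentence to BERT input ids (assumes a HuggingFace-style
    # BertTokenizer whose encode() adds [CLS] and [SEP])
    encoded = [tokenizer.encode(sent) for sent in sentences]

    # Sort by length so sentences of similar length land in the same batch
    encoded.sort(key=len)

    batches = []
    for start in range(0, len(encoded), batch_size):
        chunk = encoded[start:start + batch_size]
        # Pad only up to the longest sentence in this batch, not the whole corpus
        batch_max_len = max(len(ids) for ids in chunk)
        batches.append([ids + [0] * (batch_max_len - len(ids)) for ids in chunk])

    # Shuffle whole batches (not individual sentences) so training order stays random
    random.shuffle(batches)
    return batches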
Example #2
def evaluate_bert(generator, output_file_path: str):
    settings.write_debug("Entering Bleu EVAL")
    with torch.no_grad():

        ### Do not use top_k=1 for evaluation: it is essentially argmax and non-differentiable, which breaks the computation
        # Choose the prefix context
        generated_input_ids = generator.generate(settings.get_num_eval_samples(),
                                                 seed_text=settings.get_eval_seed_text().split(),
                                                 batch_size=settings.get_eval_batch_size(),
                                                 max_len=settings.get_sample_size(),
                                                 generation_mode=settings.get_eval_gen_mode_key(),
                                                 sample=settings.get_eval_sample(),
                                                 top_k=settings.get_eval_top_k(),
                                                 temperature=settings.get_eval_temp(),
                                                 burnin=settings.get_eval_burnin(),
                                                 max_iter=settings.get_eval_max_iter())

        bert_sents = [
            generator.detokenize(tokenizer.convert_ids_to_tokens(sent.tolist())).split()
            for sent in generated_input_ids
        ]
        with open(output_file_path, 'a+') as f:
            for index, sent in enumerate(bert_sents):
                f.write('[%d/%d]\t' % (index, len(bert_sents)) + " ".join(sent) + "\n")

        avg_p = np.average([
            perplexity_model.score(" ".join(sent))['positional_scores'].mean().neg().exp().item()
            for sent in bert_sents
        ])

        settings.write_result("BERT Perplexity: %.2f" % avg_p)

        settings.write_result("BERT self-BLEU: %.2f" % (100 * self_bleu(bert_sents)))

        max_n = settings.get_bleu_max_n()

        with open(settings.get_proc_wiki_path(), 'rb') as proc_wiki_f, open(settings.get_proc_tbc_path(), 'rb') as proc_tbc_f:

            wiki_data = pickle.load(proc_wiki_f)
            tbc_data = pickle.load(proc_tbc_f)

            settings.write_result("BERT-TBC BLEU: %.2f" % (100 * corpus_bleu(bert_sents, tbc_data)))
            settings.write_result("BERT-Wiki103 BLEU: %.2f" % (100 * corpus_bleu(bert_sents, wiki_data)))
            settings.write_result("BERT-{TBC + Wiki103} BLEU: %.2f" % (100 * corpus_bleu(bert_sents, tbc_data[:2500] + wiki_data[:2500])))

            pct_uniques = ref_unique_ngrams(bert_sents, wiki_data, max_n)
            for i in range(1, max_n + 1):
                settings.write_result("BERT unique %d-grams relative to Wiki: %.2f" % (i, 100 * pct_uniques[i]))
            pct_uniques = ref_unique_ngrams(bert_sents, tbc_data, max_n)
            for i in range(1, max_n + 1):
                settings.write_result("BERT unique %d-grams relative to TBC: %.2f" % (i, 100 * pct_uniques[i]))
            pct_uniques = self_unique_ngrams(bert_sents, max_n)
            for i in range(1, max_n + 1):
                settings.write_result("BERT unique %d-grams relative to self: %.2f" % (i, 100 * pct_uniques[i]))

        settings.write_result("")
Example #3
def untokenize_batch(batch, tokenizer: BertTokenizer):
    return [tokenizer.convert_ids_to_tokens(sent) for sent in batch]
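
A brief usage sketch, assuming a HuggingFace BertTokenizer; the pretrained-model name and the example sentences are illustrative only.

# Hypothetical usage of untokenize_batch
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
batch = [tokenizer.encode("The cat sat on the mat."),
         tokenizer.encode("BERT fills in masked tokens.")]
print(untokenize_batch(batch, tokenizer))
# e.g. [['[CLS]', 'the', 'cat', 'sat', 'on', 'the', 'mat', '.', '[SEP]'], ...]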
Example #4
def train_gan(self, training_dataloader: DataLoader,
              validation_dataloader: DataLoader):
    settings.write_debug("Begin Training")

    self.initialize_training(training_dataloader)

    settings.write_debug("========================================")
    settings.write_debug("               Training                 ")
    settings.write_debug("========================================")

    # Training loop -- Iterates over full training dataset
    for epoch_i in range(0, self.epochs):

        settings.write_debug("")
        settings.write_debug('======== Epoch {:} / {:} ========'.format(
            epoch_i + 1, self.epochs))
        settings.write_debug('Training...')

        # Initializes time variable to measure duration of training
        t0 = time.time()

        # Reset the loss variables for each epoch
        discriminator_total_loss = 0
        generator_total_loss = 0
        gen_update_step = 0

        ########################################################
        # Get one batch of real samples from dataset at a time
        #######################################################
        for step, batch in enumerate(training_dataloader):

            ########################
            # Generate Labels
            ########################
            #   Set labels using random float values for label smoothing
            #       - Real labels are range [.9,1] instead of just 1
            #       - Fake labels are range [0,.1] instead of just 0
            #       - Helps to prevent mode collapse by keeping a moving target
            self.real_labels = torch.tensor(
                [random.uniform(.9, 1)] * settings.get_batch_size(),
                requires_grad=False).unsqueeze(-1).to(device)
            self.false_labels = torch.tensor(
                [random.uniform(0, .1)] * settings.get_batch_size(),
                requires_grad=False).unsqueeze(-1).to(device)

            settings.write_debug("batch:" + str(step) + "/" +
                                 str(len(training_dataloader)))

            # Progress update every 40 batches.
            if step % 40 == 0 and step != 0:
                # Calculate elapsed time in minutes.
                elapsed = format_time(time.time() - t0)
                # Report progress.
                settings.write_debug(
                    '  Batch {:>5,}  of  {:>5,}.    Elapsed: {:}.'.format(
                        step, len(training_dataloader), elapsed))

            ###########################
            # Discriminator Network Training
            #############################
            # - Maximizes log(D(x)) + log(1 - D(G(z))) over a real batch and a generated batch
            self.discriminator.train()
            self.generator.train()
            #self.discriminator.requires_grad = True

            # clears previously accumulated gradients by setting to all zeros
            self.discriminator.zero_grad()

            # attaches batch to GPU
            batch = batch.to(device)

            ## Train with a batch of samples from the dataset -- "real"
            D_x, discriminator_real_loss = self.discriminator.train_discriminator(
                batch, self.real_labels
            )  # Pass the smoothed "real" labels with the real batch

            ############################################################################
            ## Generate a "fake" batch of samples similar in size and length to the "real"
            ##########################################################################
            n_samples = batch_size = settings.get_batch_size()
            sample_lens = (
                (batch != 0).sum(dim=-1) -
                2).tolist()  # Subtract 2 to account for CLS and final SEP
            max_len = len(batch[0]) - 2  # settings.get_sample_size()
            top_k = 10  ### Do not use top_k=1: it is essentially argmax and non-differentiable, which breaks the computation
            temperature = 1.0
            generation_mode = "training"
            burnin = 0
            sample = True
            max_iter = 1

            # Choose the prefix context
            seed_text = "[CLS]".split()
            generated_input_ids = self.generator.generate(
                n_samples,
                seed_text=seed_text,
                batch_size=batch_size,
                max_len=max_len,
                sample_lens=sample_lens,
                generation_mode=generation_mode,
                sample=sample,
                top_k=top_k,
                temperature=temperature,
                burnin=burnin,
                max_iter=max_iter)

            bert_sents = [
                self.generator.detokenize(
                    tokenizer.convert_ids_to_tokens(sent.tolist())).split()
                for sent in generated_input_ids
            ]
            with open(settings.get_bert_train_out_path(), 'a+') as f:
                for index, sent in enumerate(bert_sents):
                    f.write('[%d/%d][%d/%d]\t' %
                            (epoch_i, self.epochs, index, len(bert_sents)) +
                            " ".join(sent) + "\n")

            D_G_z1, discriminator_fake_loss = self.discriminator.train_discriminator(
                generated_input_ids,  # clone().detach(),
                self.false_labels
            )  # Pass the smoothed "fake" labels with the generated batch

            discriminator_combined_loss = discriminator_fake_loss + discriminator_real_loss

            discriminator_total_loss += discriminator_combined_loss

            self.discriminator.optimizer.step()
            # Update the learning rate.
            self.discriminator_scheduler.step()

            #############################
            # Generator Network Training
            #############################
            # - Generated samples only, to maximize log(D(G(z)))

            # Save gpu memory by untracking discriminator gradients

            #self.discriminator.requires_grad = False

            ## Train with generated batch
            D_G_z2, generator_loss = self.generator.train_generator(
                generated_input_ids, self.real_labels, self.discriminator)

            # Call step to optimizer to update weights
            # Step to scheduler modifies learning rate
            self.generator.optimizer.step()
            self.generator_scheduler.step()
            generator_total_loss += generator_loss

            # counter
            gen_update_step += 1

            # Clear gradients after the update; delete the generated batch and empty the
            # CUDA cache to save GPU memory
            self.generator.zero_grad()
            del generated_input_ids
            torch.cuda.empty_cache()

            # Output training stats
            settings.write_train_stat(
                '[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f\n'
                % (epoch_i, self.epochs, step, len(training_dataloader),
                   discriminator_combined_loss, generator_loss, D_x, D_G_z1,
                   D_G_z2))

        # Calculate the average loss over the training data.
        discriminator_avg_train_loss = discriminator_total_loss / (
            len(training_dataloader))
        generator_avg_train_loss = generator_total_loss / gen_update_step

        ### output statistics to file
        settings.write_train_stat("\n")
        settings.write_train_stat(
            "  Average Discriminator training loss: {0:.2f}".format(
                discriminator_avg_train_loss))
        settings.write_train_stat(
            "  Average Generator training loss: {0:.2f}".format(
                generator_avg_train_loss))
        settings.write_train_stat("  Training epcoh took: {:}\n".format(
            format_time(time.time() - t0)))

        settings.write_debug("========================================")
        settings.write_debug("              Validation                ")
        settings.write_debug("========================================")

        # After each training epoch, measure accuracy on a mixture of real validation samples and freshly generated samples

        settings.write_debug("")
        settings.write_debug("Running Validation...")

        t0 = time.time()

        # eval changes behavior of dropout layers
        self.generator.eval()
        self.discriminator.eval()

        # reset variables
        eval_loss, eval_accuracy = 0., 0.
        nb_eval_steps, nb_eval_examples = 0, 0

        # Evaluate data for one epoch
        for batch in validation_dataloader:
            with torch.no_grad():
                ## Generate an all-fake batch

                sample_lens = (
                    (batch != 0).sum(dim=-1) -
                    2).tolist()  # Subtract 2 to account for CLS and final SEP

                n_samples = batch_size = settings.get_batch_size()
                max_len = len(batch[0]) - 2  # settings.get_sample_size()
                top_k = 100  ### Do not use top_k=1 here: as noted above, it is essentially argmax and breaks the computation
                temperature = 1.0
                generation_mode = "evaluation"
                burnin = 250
                sample = True
                max_iter = 500
                seed_text = "[CLS]".split()
                generated_input_ids = self.generator.generate(
                    n_samples,
                    seed_text=seed_text,
                    batch_size=batch_size,
                    max_len=max_len,
                    generation_mode=generation_mode,
                    sample=sample,
                    top_k=top_k,
                    temperature=temperature,
                    burnin=burnin,
                    max_iter=max_iter,
                    sample_lens=sample_lens)

                validation_sents = [
                    self.generator.detokenize(
                        tokenizer.convert_ids_to_tokens(
                            sent.tolist())).split()
                    for sent in generated_input_ids
                ]
                with open(settings.get_bert_valid_out_path(), 'a+') as f:
                    for index, sent in enumerate(validation_sents):
                        f.write('[%d/%d][%d/%d]\t' %
                                (epoch_i, self.epochs, index,
                                 len(validation_sents)) + " ".join(sent) + "\n")

                batch = torch.cat((batch, generated_input_ids)).to(device)
                labels = torch.cat((self.real_labels, self.false_labels))

                # Matches labels with samples and shuffles them accordingly
                batch = list(zip(batch, labels))
                random.shuffle(batch)
                batch, labels = zip(*batch)

                logits, _ = self.discriminator(torch.stack(batch))

                print("Validation logits", flush=True)
                print(logits, flush=True)
                # Calculate the accuracy for this batch of mixed real and generated
                # validation samples (one possible _flat_accuracy is sketched after this example)

                tmp_eval_accuracy = _flat_accuracy(preds=logits,
                                                   labels=torch.stack(labels))

                # save total acc
                eval_accuracy += tmp_eval_accuracy

                # total number of validation batches run
                nb_eval_steps += 1

                print(eval_accuracy / nb_eval_steps)

        # Output acc
        settings.write_train_stat("\n")
        settings.write_train_stat("  Validation Accuracy: {0:.2f}".format(
            eval_accuracy / nb_eval_steps))
        settings.write_train_stat("  Validation took: {:}\n".format(
            format_time(time.time() - t0)))

        # Put both models back into training mode now that validation has ended
        self.generator.train()
        self.discriminator.train()

    settings.write_debug("")
    settings.write_debug("Training complete!")