Exemplo n.º 1
0
    def train(self):
        """ Main training method for the Trainer class.

        Runs epochs ``self.epoch`` .. ``self.max_num_epochs - 1``. For each
        epoch it:

        * optionally boosts the training data: when ``params.boost`` is set
          and the epoch is past ``params.boost_warmup``, the hard training
          instances returned by the previous ``train_epoch`` call are turned
          into ``Example`` objects, appended to a *copy* of the training data,
          and a new iterator is built over the result;
        * trains one epoch and logs average loss / perplexity to TensorBoard;
        * validates on the dev set and logs loss / perplexity;
        * every ``decode_every_num_epochs`` epochs (when that value is
          non-zero), greedily decodes the first few items of the dev iterator
          and logs the translations as TensorBoard text;
        * steps the scheduler (if any) on the validation loss, saves a
          checkpoint under ``<model_dir>/checkpoints/`` and records the best
          validation loss seen so far.
        """
        import itertools
        import os

        def _perplexity(loss):
            # math.exp raises OverflowError for very large losses (e.g. a
            # diverged run); report an infinite perplexity instead of crashing.
            try:
                return math.exp(loss)
            except OverflowError:
                return float("inf")

        # Start from self.epoch so the loop agrees with the message below and
        # a resumed trainer does not redo finished epochs (a fresh trainer
        # presumably has self.epoch == 0 — confirm against __init__).
        start_epoch = self.epoch
        print("Starting training for {} epoch(s)".format(self.max_num_epochs -
                                                         start_epoch))

        # Hard examples mined by the previous epoch. Always initialised so the
        # boosting branch can never reference an unbound name.
        hard_training_instances = []

        # Trailing "" keeps the trailing separator the original path had.
        checkpoint_dir = os.path.join(self.params.model_dir, "checkpoints", "")

        for epoch in range(start_epoch, self.max_num_epochs):
            self.epoch = epoch
            print("Epoch {}/{}".format(epoch + 1, self.max_num_epochs))

            epoch_start_time = time.time()

            # By default train on the original iterator; boosting below swaps
            # in an iterator over original data + hard instances.
            data_iterator = self.train_iter

            # If boost==True and epochs are past warmup, perform boosting
            if self.params.boost and epoch + 1 > self.params.boost_warmup:
                print("Boosting....")

                # make `Example` objects for all hard training instances
                example_objs = self.create_example_objs(
                    hard_training_instances)

                # Copy whatever data() returns so the original training data
                # is never mutated (and so `.extend` is always available).
                boosted_data = list(self.train_iter.data())
                boosted_data.extend(example_objs)

                # Create new Dataset and iterator on the boosted data
                data_iterator = self.create_boosted_dataset(boosted_data)

            train_loss_avg, hard_training_instances = self.train_epoch(
                data_iterator)
            train_ppl = _perplexity(train_loss_avg)

            # write epoch statistics to Tensorboard
            self.summary_writer.add_scalar("train/avg_loss_per_epoch",
                                           train_loss_avg, self.epoch)
            self.summary_writer.add_scalar("train/avg_perplexity_epoch",
                                           train_ppl, self.epoch)

            epoch_end_time = time.time()
            epoch_mins, epoch_secs = self.epoch_time(epoch_start_time,
                                                     epoch_end_time)
            print(
                f'Epoch: {epoch+1:02} | Avg Train Loss: {train_loss_avg} | Perplexity: {train_ppl} | Time: {epoch_mins}m {epoch_secs}s'
            )

            # validate the model on the dev set
            val_start_time = time.time()
            val_loss_avg = self.validate()
            val_end_time = time.time()
            val_mins, val_secs = self.epoch_time(val_start_time, val_end_time)
            val_ppl = _perplexity(val_loss_avg)

            # write validation statistics to Tensorboard
            self.summary_writer.add_scalar("val/loss", val_loss_avg,
                                           self.epoch)
            self.summary_writer.add_scalar("val/perplexity", val_ppl,
                                           self.epoch)

            # Every `decode_every_num_epochs` epochs, write greedy-decoded
            # translations to Tensorboard. A value of 0 disables this instead
            # of raising ZeroDivisionError.
            if (self.decode_every_num_epochs
                    and (self.epoch + 1) % self.decode_every_num_epochs == 0):
                print("Performing Greedy Decoding...")
                num_translations = 5
                # Iterate a shallow copy of the dev iterator and take only its
                # first `num_translations` items (presumably batches) without
                # materialising the whole dev set.
                dev_iter = copy.copy(self.dev_iter)
                sample = list(itertools.islice(dev_iter, num_translations))
                decoder = Translator(
                    model=self.model,
                    dev_iter=sample,
                    params=self.params,
                    device=self.params.device)
                translations = decoder.greedy_decode(max_len=100)
                for translation in translations:
                    self.summary_writer.add_text("transformer/translation",
                                                 " ".join(translation),
                                                 self.epoch)

            print(
                f'Avg Val Loss: {val_loss_avg} | Val Perplexity: {val_ppl} | Time: {val_mins}m {val_secs}s'
            )
            print('\n')

            # use a scheduler in order to decay learning rate hasn't improved
            if self.scheduler is not None:
                self.scheduler.step(val_loss_avg)

            is_best = val_loss_avg < self.best_val_loss

            # A ScheduledOptimizer wraps the real optimizer in `_optimizer`;
            # save the wrapped optimizer's state in that case.
            optim_dict = self.optimizer._optimizer.state_dict() if isinstance(
                self.optimizer,
                ScheduledOptimizer) else self.optimizer.state_dict()

            # save checkpoint; "epoch" records the number of completed epochs
            self.save_checkpoint(
                {
                    "epoch": epoch + 1,
                    "state_dict": self.model.state_dict(),
                    "optim_dict": optim_dict
                },
                is_best=is_best,
                checkpoint=checkpoint_dir)

            if is_best:
                print("- Found new lowest loss!")
                self.best_val_loss = val_loss_avg
Exemplo n.º 2
0
def main(params, greedy, beam_size, test):
    """
    The main function for decoding a trained MT model.

    Loads the dataset and vocabularies and rebuilds the Seq2Seq model. It
    then restores the weights, either by averaging the last
    ``params.average`` checkpoints or by loading
    ``<model_dir>/checkpoints/<model_file>``. Finally it decodes and scores
    translations.

    Arguments:
        params: parameters related to the `model` that is being decoded;
            vocabulary sizes, special-token indices, ``itos`` and ``device``
            are written onto it here.
        greedy: whether or not to do greedy decoding on the dev set
        beam_size: size of beam if doing beam search (a falsy value skips
            beam search on the dev set)
        test: if True, only run beam search on the test set, write the
            outputs, and return without BLEU scoring
    """
    print("Loading dataset...")
    _, dev_iter, test_iterator, DE, EN = load_dataset(params.data_path,
                                                      params.train_batch_size,
                                                      params.dev_batch_size)
    de_size, en_size = len(DE.vocab), len(EN.vocab)
    print("[DE Vocab Size: ]: {}, [EN Vocab Size]: {}".format(
        de_size, en_size))

    # Vocabulary information recorded on params; presumably consumed by
    # make_seq2seq_model and Translator below.
    params.src_vocab_size = de_size
    params.tgt_vocab_size = en_size
    params.sos_index = EN.vocab.stoi["<s>"]
    params.pad_token = EN.vocab.stoi["<pad>"]
    params.eos_index = EN.vocab.stoi["</s>"]
    params.itos = EN.vocab.itos

    device = torch.device('cuda' if params.cuda else 'cpu')
    params.device = device

    # make the Seq2Seq model
    model = make_seq2seq_model(params)

    # load the saved model for evaluation
    if params.average > 1:
        print("Averaging the last {} checkpoints".format(params.average))
        checkpoint = {}
        checkpoint["state_dict"] = average_checkpoints(params.model_dir,
                                                       params.average)
        model = Trainer.load_checkpoint(model, checkpoint)
    else:
        # os.path.join works whether or not model_dir ends with a separator;
        # plain concatenation only worked with a trailing "/".
        model_path = os.path.join(params.model_dir, "checkpoints",
                                  params.model_file)
        print("Restoring parameters from {}".format(model_path))
        model = Trainer.load_checkpoint(model, model_path)

    def run_bleu(output_name):
        # Score <model_dir>/outputs/<output_name> with ./utils/eval.sh.
        # Assumes output_decoded_translations writes into
        # <model_dir>/outputs/ — TODO confirm against Translator.
        output_path = os.path.join(params.model_dir, "outputs", output_name)
        return_code = subprocess.call(['./utils/eval.sh', output_path])
        if return_code != 0:
            # Don't let a failed evaluation pass silently.
            print("Warning: ./utils/eval.sh exited with code {}".format(
                return_code))

    # evaluate on the test set
    if test:
        print("Doing Beam Search on the Test Set")
        test_decoder = Translator(model, test_iterator, params, device)
        test_beam_search_outputs = test_decoder.beam_decode(
            beam_width=beam_size)
        test_decoder.output_decoded_translations(
            test_beam_search_outputs,
            "beam_search_outputs_size_test={}.en".format(beam_size))
        return

    # instantiate a Translator object to translate SRC language to TRG language using Greedy/Beam Decoding
    decoder = Translator(model, dev_iter, params, device)

    if greedy:
        print("Doing Greedy Decoding...")
        greedy_file = "greedy_outputs.en"
        greedy_outputs = decoder.greedy_decode(max_len=100)
        decoder.output_decoded_translations(greedy_outputs, greedy_file)

        print("Evaluating BLEU Score on Greedy Translation...")
        run_bleu(greedy_file)

    if beam_size:
        print("Doing Beam Search...")
        beam_file = "beam_search_outputs_size={}.en".format(beam_size)
        beam_search_outputs = decoder.beam_decode(beam_width=beam_size)
        decoder.output_decoded_translations(beam_search_outputs, beam_file)

        print("Evaluating BLEU Score on Beam Search Translation")
        run_bleu(beam_file)