Example #1
    def testIteratorBatchType(self):

        current_cfg = self.data_cfg.copy()
        current_cfg["level"] = "word"
        current_cfg["lowercase"] = False

        # load toy data
        data = load_data(current_cfg)
        train_data = data["train_data"]
        dev_data = data["dev_data"]
        test_data = data["test_data"]
        vocabs = data["vocabs"]
        src_vocab = vocabs["src"]
        trg_vocab = vocabs["trg"]

        # make batches by number of sentences
        train_iter = iter(make_data_iter(
            train_data, batch_size=10, batch_type="sentence"))
        batch = next(train_iter)

        self.assertEqual(batch.src[0].shape[0], 10)
        self.assertEqual(batch.trg[0].shape[0], 10)

        # make batches by number of tokens
        train_iter = iter(make_data_iter(
            train_data, batch_size=100, batch_type="token"))
        _ = next(train_iter)  # skip a batch
        _ = next(train_iter)  # skip another batch
        batch = next(train_iter)

        self.assertEqual(batch.src[0].shape[0], 8)
        self.assertEqual(np.prod(batch.src[0].shape), 88)
        self.assertLessEqual(np.prod(batch.src[0].shape), 100)
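The token-batching assertions above are easiest to read as padding arithmetic: the third token batch holds 8 sentences padded to a common length of 11, so the padded source matrix has 8 * 11 = 88 positions, which fits the budget of 100. A minimal sketch of that check (the numbers come from the assertions; how exactly the iterator counts tokens is library-version dependent):

sentences_in_batch = 8          # batch.src[0].shape[0] in the test above
max_src_len = 11                # padded source length implied by 88 / 8
padded_positions = sentences_in_batch * max_src_len   # 8 * 11 = 88
assert padded_positions <= 100  # the batch_size budget for batch_type="token"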
Example #2
    def testIteratorBatchType(self):

        current_cfg = self.data_cfg.copy()

        # load toy data
        train_data, dev_data, test_data, src_vocab, trg_vocab = \
            load_data(current_cfg)

        # make batches by number of sentences
        train_iter = iter(
            make_data_iter(train_data, batch_size=10, batch_type="sentence"))
        batch = next(train_iter)

        self.assertEqual(batch.src[0].shape[0], 10)
        self.assertEqual(batch.trg[0].shape[0], 10)

        # make batches by number of tokens
        train_iter = iter(
            make_data_iter(train_data, batch_size=100, batch_type="token"))
        _ = next(train_iter)  # skip a batch
        _ = next(train_iter)  # skip another batch
        batch = next(train_iter)

        self.assertEqual(batch.src[0].shape[0], 8)
        self.assertEqual(np.prod(batch.src[0].shape), 88)
        self.assertLessEqual(np.prod(batch.src[0].shape), 100)
Example #3
    def fast_adapt(self, task, learner, valid=False):
        loss = 0.0
        batch_size = self.batch_size
        steps = self.adaptation_steps
        if valid:
            batch_size = self.valid_batch_size
            steps = 1  # Take only one batch during validation

        task_iter = iter(make_data_iter(task,
                                        batch_size=batch_size,
                                        batch_type=self.batch_type,
                                        train=True,
                                        shuffle=self.shuffle))
        for i in range(steps):
            batch = next(task_iter)  # consume consecutive batches, not the first one repeatedly
            train_batch = self.batch_class(batch,
                                           self.model.pad_index,
                                           use_cuda=self.use_cuda)
            batch_loss, _, _, _ = learner(return_type="loss",
                                          **vars(train_batch))

            learner.adapt(batch_loss, allow_nograd=True,
                          allow_unused=True)  # Adapt learner after every batch
            loss += batch_loss

        loss /= (batch_size * steps)
        # adapt after every task
        # learner.adapt(loss, allow_nograd=True, allow_unused=True)
        print("Loss")
        print(loss)
        return loss
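fast_adapt above only assumes a learner object whose adapt() call performs an inner-loop update. As a rough sketch of how such a learner could be obtained, here is a hypothetical outer loop built on the learn2learn MAML wrapper; model, tasks and trainer are stand-ins for the surrounding trainer's attributes, and this setup is an assumption, not code from the repository:

import torch
import learn2learn as l2l

# hypothetical objects: `model` is the NMT model, `tasks` a list of task
# datasets, `trainer` the object that defines fast_adapt above
maml = l2l.algorithms.MAML(model, lr=1e-3, allow_unused=True, allow_nograd=True)
meta_optimizer = torch.optim.Adam(maml.parameters(), lr=1e-4)

meta_optimizer.zero_grad()
for task in tasks:
    learner = maml.clone()                      # task-specific copy, keeps the graph
    task_loss = trainer.fast_adapt(task, learner)
    task_loss.backward()                        # meta-gradient through the adapted clone
meta_optimizer.step()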
Example #4
    def testBatchTrainIterator(self):

        batch_size = 4
        self.assertEqual(len(self.train_data), 27)

        # make data iterator
        # *note*: BucketIterator is replaced with Iterator
        train_iter = make_data_iter(self.train_data,
                                    train=True,
                                    shuffle=True,
                                    batch_size=batch_size)
        self.assertEqual(train_iter.batch_size, batch_size)
        self.assertTrue(train_iter.shuffle)
        self.assertTrue(train_iter.train)
        self.assertEqual(train_iter.epoch, 0)
        self.assertEqual(train_iter.iterations, 0)

        expected_src0 = torch.Tensor([
            [
                18, 8, 6, 26, 5, 4, 10, 6, 28, 8, 17, 11, 22, 5, 19, 14, 4, 12,
                25, 3
            ],
            [
                19, 11, 30, 5, 18, 23, 13, 4, 12, 5, 21, 4, 12, 7, 23, 17, 11,
                9, 3, 1
            ],
            [19, 11, 22, 5, 8, 11, 5, 29, 8, 22, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1],
            [14, 8, 6, 15, 4, 9, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
        ]).long()
        expected_src0_len = torch.Tensor([20, 19, 11, 7]).long()
        expected_trg0 = torch.Tensor(
            [[
                14, 8, 21, 12, 4, 11, 6, 12, 13, 22, 4, 14, 12, 10, 21, 8, 4,
                14, 8, 23, 3
            ],
             [
                 5, 7, 30, 4, 20, 5, 5, 19, 4, 20, 5, 14, 10, 20, 9, 3, 1, 1,
                 1, 1, 1
             ],
             [5, 7, 22, 4, 7, 6, 7, 9, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
             [
                 8, 7, 6, 10, 17, 4, 13, 5, 15, 9, 3, 1, 1, 1, 1, 1, 1, 1, 1,
                 1, 1
             ]]).long()
        expected_trg0_len = torch.Tensor([22, 17, 10, 12]).long()

        total_samples = 0
        for b in iter(train_iter):
            b = Batch(torch_batch=b, pad_index=self.pad_index)
            if total_samples == 0:
                self.assertTensorEqual(b.src, expected_src0)
                self.assertTensorEqual(b.src_length, expected_src0_len)
                self.assertTensorEqual(b.trg, expected_trg0)
                self.assertTensorEqual(b.trg_length, expected_trg0_len)
            total_samples += b.nseqs
            self.assertLessEqual(b.nseqs, batch_size)
        self.assertEqual(total_samples, len(self.train_data))
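The trailing 1s in the expected tensors are padding, consistent with a pad index of 1 in the toy vocabularies (an assumption, but it matches the numbers asserted above); the stored lengths are simply the number of non-pad positions per row. A small illustration using the last row of expected_src0:

import torch

pad_index = 1
row = torch.tensor([[14, 8, 6, 15, 4, 9, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
length = (row != pad_index).sum(dim=1)   # tensor([7]), matching expected_src0_len[-1]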
Example #5
    def testBatchDevIterator(self):

        batch_size = 3
        self.assertEqual(len(self.dev_data), 20)

        # make data iterator
        dev_iter = make_data_iter(self.dev_data, train=False, shuffle=False,
                                  batch_size=batch_size)
        self.assertEqual(dev_iter.batch_size, batch_size)
        self.assertFalse(dev_iter.shuffle)
        self.assertFalse(dev_iter.train)
        self.assertEqual(dev_iter.epoch, 0)
        self.assertEqual(dev_iter.iterations, 0)

        expected_src0 = torch.Tensor(
            [[29, 8, 5, 22, 5, 8, 16, 7, 19, 5, 22, 5, 24, 8, 7, 5, 7, 19,
              16, 16, 5, 31, 10, 19, 11, 8, 17, 15, 10, 6, 18, 5, 7, 4, 10, 6,
              5, 25, 3],
             [10, 17, 11, 5, 28, 12, 4, 23, 4, 5, 0, 10, 17, 11, 5, 22, 5, 14,
              8, 7, 7, 5, 10, 17, 11, 5, 14, 8, 5, 31, 10, 6, 5, 9, 3, 1,
              1, 1, 1],
             [29, 8, 5, 22, 5, 18, 23, 13, 4, 6, 5, 13, 8, 18, 5, 9, 3, 1,
              1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
              1, 1, 1]]).long()
        expected_src0_len = torch.Tensor([39, 35, 17]).long()
        expected_trg0 = torch.Tensor(
            [[13, 11, 12, 4, 22, 4, 12, 5, 4, 22, 4, 25, 7, 6, 8, 4, 14, 12,
              4, 24, 14, 5, 7, 6, 26, 17, 14, 10, 20, 4, 23, 3],
             [14, 0, 28, 4, 7, 6, 18, 18, 13, 4, 8, 5, 4, 24, 11, 4, 7, 11,
              16, 11, 4, 9, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1],
             [13, 11, 12, 4, 22, 4, 7, 11, 27, 27, 5, 4, 9, 3, 1, 1, 1, 1,
              1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]).long()
        expected_trg0_len = torch.Tensor([33, 24, 15]).long()

        total_samples = 0
        for b in iter(dev_iter):
            self.assertEqual(type(b), TorchTBatch)
            b = Batch(b, pad_index=self.pad_index)

            # test the sorting by src length
            self.assertEqual(type(b), Batch)
            before_sort = b.src_lengths
            b.sort_by_src_lengths()
            after_sort = b.src_lengths
            self.assertTensorEqual(torch.sort(before_sort, descending=True)[0],
                                   after_sort)
            self.assertEqual(type(b), Batch)

            if total_samples == 0:
                self.assertTensorEqual(b.src, expected_src0)
                self.assertTensorEqual(b.src_lengths, expected_src0_len)
                self.assertTensorEqual(b.trg, expected_trg0)
                self.assertTensorEqual(b.trg_lengths, expected_trg0_len)
            total_samples += b.nseqs
            self.assertLessEqual(b.nseqs, batch_size)
        self.assertEqual(total_samples, len(self.dev_data))
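Note that expected_src0_len is already in descending order ([39, 35, 17]), so sort_by_src_lengths leaves this particular batch unchanged; the interesting part is the bookkeeping needed to restore the original order afterwards. A minimal sketch of that sort/undo pattern with plain torch ops (illustrative, not the library's implementation):

import torch

src_lengths = torch.tensor([17, 39, 35])
perm = torch.argsort(src_lengths, descending=True)   # indices that sort the batch
rev = torch.argsort(perm)                            # indices that undo the sort

sorted_lengths = src_lengths[perm]    # tensor([39, 35, 17])
restored = sorted_lengths[rev]        # tensor([17, 39, 35]), the original order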
Example #6
    def testBatchTrainIterator(self):

        batch_size = 4
        self.assertEqual(len(self.train_data), 27)

        # make data iterator
        train_iter = make_data_iter(self.train_data,
                                    train=True,
                                    shuffle=True,
                                    batch_size=batch_size)
        self.assertEqual(train_iter.batch_size, batch_size)
        self.assertTrue(train_iter.shuffle)
        self.assertTrue(train_iter.train)
        self.assertEqual(train_iter.epoch, 0)
        self.assertEqual(train_iter.iterations, 0)

        expected_src0 = torch.Tensor([
            [21, 10, 4, 16, 4, 5, 21, 4, 12, 33, 6, 14, 4, 12, 23, 6, 18, 4, 6,
             9, 3],
            [20, 28, 4, 10, 28, 4, 6, 5, 14, 8, 6, 15, 4, 5, 7, 17, 11, 27, 6,
             9, 3],
            [24, 8, 7, 5, 24, 10, 12, 14, 5, 18, 4, 7, 17, 11, 4, 11, 4, 6, 25,
             3, 1]
        ]).long()
        expected_src0_len = torch.Tensor([21, 21, 20]).long()
        expected_trg0 = torch.Tensor(
            [[6, 4, 27, 5, 8, 4, 5, 31, 4, 26, 7, 6, 10, 20, 11, 9, 3],
             [8, 7, 6, 10, 17, 4, 13, 5, 15, 9, 3, 1, 1, 1, 1, 1, 1],
             [12, 5, 4, 25, 7, 6, 8, 4, 7, 6, 18, 18, 11, 10, 12, 23,
              3]]).long()
        expected_trg0_len = torch.Tensor([18, 12, 18]).long()

        total_samples = 0
        for b in iter(train_iter):
            b = Batch(torch_batch=b, pad_index=self.pad_index)
            if total_samples == 0:
                src, src_lengths, _ = b["src"]
                trg, trg_lengths, _ = b["trg"]
                self.assertTensorEqual(src, expected_src0)
                self.assertTensorEqual(src_lengths, expected_src0_len)
                self.assertTensorEqual(trg[:, 1:], expected_trg0)
                self.assertTensorEqual(trg_lengths, expected_trg0_len)
            total_samples += b.nseqs
            self.assertLessEqual(b.nseqs, batch_size)
        self.assertEqual(total_samples, len(self.train_data))
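In this variant the Batch object is indexed like a mapping, and the stored target still carries the BOS token in column 0, which is why the test compares trg[:, 1:] and why the expected lengths are one larger than the rows shown. A schematic illustration with a hypothetical BOS index of 2 (an assumption consistent with the asserted lengths):

import torch

bos_index = 2  # hypothetical
trg = torch.tensor([[2, 8, 7, 6, 10, 17, 4, 13, 5, 15, 9, 3]])  # BOS, tokens..., EOS
trg_without_bos = trg[:, 1:]   # what the test compares against expected_trg0
trg_length = trg.size(1)       # 12 = BOS + 10 tokens + EOS, matching expected_trg0_len[1]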
Example #7
    def compute_task_loss(self, train_task, learner):
        loss = 0.0
        task_train_iter = make_data_iter(train_task,
                                         batch_size=self.batch_size,
                                         batch_type=self.batch_type,
                                         train=True,
                                         shuffle=self.shuffle)
        for i, batch in enumerate(iter(task_train_iter)):
            train_batch = self.batch_class(batch,
                                           self.model.pad_index,
                                           use_cuda=self.use_cuda)
            batch_loss, _, _, _ = learner(return_type="loss",
                                          **vars(train_batch))
            loss += batch_loss
        loss /= len(train_task)
        print("Task Loss", loss)
        return loss
Example #8
    def train_and_validate(self, train_data: Dataset, valid_data: Dataset) \
            -> None:
        """
        Train the model and validate it from time to time on the validation set.

        :param train_data: training data
        :param valid_data: validation data
        """
        train_iter = make_data_iter(train_data,
                                    batch_size=self.batch_size,
                                    batch_type=self.batch_type,
                                    train=True,
                                    shuffle=self.shuffle)

        # For the last batch of the epoch, batch_multiplier needs to be
        # adjusted to fit the number of leftover training examples
        leftover_batch_size = len(train_data) % (self.batch_multiplier *
                                                 self.batch_size)

        for epoch_no in range(self.epochs):
            self.logger.info("EPOCH %d", epoch_no + 1)

            if self.scheduler is not None and self.scheduler_step_at == "epoch":
                self.scheduler.step(epoch=epoch_no)

            self.model.train()

            # Reset statistics for each epoch.
            start = time.time()
            total_valid_duration = 0
            start_tokens = self.total_tokens
            self.current_batch_multiplier = self.batch_multiplier
            self.optimizer.zero_grad()
            count = self.current_batch_multiplier - 1
            epoch_loss = 0

            for i, batch in enumerate(iter(train_iter)):
                # reactivate training
                self.model.train()
                # create a Batch object from torchtext batch
                batch = Batch(batch, self.pad_index, use_cuda=self.use_cuda)

                # only update every batch_multiplier batches
                # see https://medium.com/@davidlmorton/
                # increasing-mini-batch-size-without-increasing-
                # memory-6794e10db672

                # Set current_batch_multiplier to fit
                # number of leftover examples for last batch in epoch
                # Only works if batch_type == sentence
                if self.batch_type == "sentence":
                    if self.batch_multiplier > 1 and i == len(train_iter) - \
                            math.ceil(leftover_batch_size / self.batch_size):
                        self.current_batch_multiplier = math.ceil(
                            leftover_batch_size / self.batch_size)
                        count = self.current_batch_multiplier - 1

                update = count == 0
                # print(count, update, self.steps)
                batch_loss = self._train_batch(batch,
                                               update=update,
                                               count=count)

                # Only log the final batch_loss of the full (accumulated) batch
                if update:
                    self.tb_writer.add_scalar("train/train_batch_loss",
                                              batch_loss, self.steps)

                count = self.batch_multiplier if update else count
                count -= 1

                # Only add complete batch_loss of full mini-batch to epoch_loss
                if update:
                    epoch_loss += batch_loss.detach().cpu().numpy()

                if self.scheduler is not None and \
                        self.scheduler_step_at == "step" and update:
                    self.scheduler.step()

                # log learning progress
                if self.steps % self.logging_freq == 0 and update:
                    elapsed = time.time() - start - total_valid_duration
                    elapsed_tokens = self.total_tokens - start_tokens
                    self.logger.info(
                        "Epoch %3d Step: %8d Batch Loss: %12.6f "
                        "Tokens per Sec: %8.0f, Lr: %.6f", epoch_no + 1,
                        self.steps, batch_loss, elapsed_tokens / elapsed,
                        self.optimizer.param_groups[0]["lr"])
                    start = time.time()
                    total_valid_duration = 0
                    start_tokens = self.total_tokens

                # validate on the entire dev set
                if self.steps % self.validation_freq == 0 and update:
                    valid_start_time = time.time()

                    valid_score, valid_loss, valid_ppl, valid_sources, \
                        valid_sources_raw, valid_references, valid_hypotheses, \
                        valid_hypotheses_raw, valid_attention_scores = \
                        validate_on_data(
                            logger=self.logger,
                            batch_size=self.eval_batch_size,
                            data=valid_data,
                            eval_metric=self.eval_metric,
                            level=self.level, model=self.model,
                            use_cuda=self.use_cuda,
                            max_output_length=self.max_output_length,
                            loss_function=self.loss,
                            beam_size=1,  # greedy validations
                            batch_type=self.eval_batch_type,
                            postprocess=True # always remove BPE for validation
                        )

                    self.tb_writer.add_scalar("valid/valid_loss", valid_loss,
                                              self.steps)
                    self.tb_writer.add_scalar("valid/valid_score", valid_score,
                                              self.steps)
                    self.tb_writer.add_scalar("valid/valid_ppl", valid_ppl,
                                              self.steps)

                    if self.early_stopping_metric == "loss":
                        ckpt_score = valid_loss
                    elif self.early_stopping_metric in ["ppl", "perplexity"]:
                        ckpt_score = valid_ppl
                    else:
                        ckpt_score = valid_score

                    new_best = False
                    if self.is_best(ckpt_score):
                        self.best_ckpt_score = ckpt_score
                        self.best_ckpt_iteration = self.steps
                        self.logger.info(
                            'Hooray! New best validation result [%s]!',
                            self.early_stopping_metric)
                        if self.ckpt_queue.maxsize > 0:
                            self.logger.info("Saving new checkpoint.")
                            new_best = True
                            self._save_checkpoint()

                    if self.scheduler is not None \
                            and self.scheduler_step_at == "validation":
                        self.scheduler.step(ckpt_score)

                    # append to validation report
                    self._add_report(valid_score=valid_score,
                                     valid_loss=valid_loss,
                                     valid_ppl=valid_ppl,
                                     eval_metric=self.eval_metric,
                                     new_best=new_best)

                    self._log_examples(
                        sources_raw=[v for v in valid_sources_raw],
                        sources=valid_sources,
                        hypotheses_raw=valid_hypotheses_raw,
                        hypotheses=valid_hypotheses,
                        references=valid_references)

                    valid_duration = time.time() - valid_start_time
                    total_valid_duration += valid_duration
                    self.logger.info(
                        'Validation result (greedy) at epoch %3d, '
                        'step %8d: %s: %6.2f, loss: %8.4f, ppl: %8.4f, '
                        'duration: %.4fs', epoch_no + 1, self.steps,
                        self.eval_metric, valid_score, valid_loss, valid_ppl,
                        valid_duration)

                    # store validation set outputs
                    self._store_outputs(valid_hypotheses)

                    # store attention plots for selected valid sentences
                    if valid_attention_scores:
                        store_attention_plots(
                            attentions=valid_attention_scores,
                            targets=valid_hypotheses_raw,
                            sources=[s for s in valid_data.src],
                            indices=self.log_valid_sents,
                            output_prefix="{}/att.{}".format(
                                self.model_dir, self.steps),
                            tb_writer=self.tb_writer,
                            steps=self.steps)

                if self.stop:
                    break
            if self.stop:
                self.logger.info(
                    'Training ended since minimum lr %f was reached.',
                    self.learning_rate_min)
                break

            self.logger.info('Epoch %3d: total training loss %.2f',
                             epoch_no + 1, epoch_loss)
        else:
            self.logger.info('Training ended after %3d epochs.', epoch_no + 1)
        self.logger.info(
            'Best validation result (greedy) at step '
            '%8d: %6.2f %s.', self.best_ckpt_iteration, self.best_ckpt_score,
            self.early_stopping_metric)

        self.tb_writer.close()  # close Tensorboard writer
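The leftover-batch adjustment above is just modular arithmetic on the epoch size. A worked example with small illustrative numbers (27 happens to be the toy training-set size used in the iterator tests; the multiplier is made up):

import math

num_examples = 27        # len(train_data)
batch_size = 4
batch_multiplier = 2     # gradients accumulated over 2 batches

leftover = num_examples % (batch_multiplier * batch_size)   # 27 % 8 = 3
last_multiplier = math.ceil(leftover / batch_size)          # ceil(3 / 4) = 1
# the multiplier switches at batch index len(train_iter) - last_multiplier,
# so the final accumulation window covers exactly the leftover examples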
Example #9
    def train_and_validate(self, train_data: Dataset, valid_data: Dataset) \
            -> None:
        """
        Train the model and validate it from time to time on the validation set.

        :param train_data: training data
        :param valid_data: validation data
        """
        self.train_iter = make_data_iter(train_data,
                                         batch_size=self.batch_size,
                                         batch_type=self.batch_type,
                                         train=True,
                                         shuffle=self.shuffle)

        if self.train_iter_state is not None:
            self.train_iter.load_state_dict(self.train_iter_state)

        #################################################################
        # simplify accumulation logic:
        #################################################################
        # for epoch in range(epochs):
        #     self.model.zero_grad()
        #     epoch_loss = 0.0
        #     batch_loss = 0.0
        #     for i, batch in enumerate(iter(self.train_iter)):
        #
        #         # gradient accumulation:
        #         # loss.backward() inside _train_step()
        #         batch_loss += self._train_step(inputs)
        #
        #         if (i + 1) % self.batch_multiplier == 0:
        #             self.optimizer.step()     # update!
        #             self.model.zero_grad()    # reset gradients
        #             self.steps += 1           # increment counter
        #
        #             epoch_loss += batch_loss  # accumulate batch loss
        #             batch_loss = 0            # reset batch loss
        #
        #     # leftovers are just ignored.
        #################################################################

        logger.info(
            "Train stats:\n"
            "\tdevice: %s\n"
            "\tn_gpu: %d\n"
            "\t16-bits training: %r\n"
            "\tgradient accumulation: %d\n"
            "\tbatch size per device: %d\n"
            "\ttotal batch size (w. parallel & accumulation): %d", self.device,
            self.n_gpu, self.fp16, self.batch_multiplier, self.batch_size //
            self.n_gpu if self.n_gpu > 1 else self.batch_size,
            self.batch_size * self.batch_multiplier)

        for epoch_no in range(self.epochs):
            logger.info("EPOCH %d", epoch_no + 1)

            if self.scheduler is not None and self.scheduler_step_at == "epoch":
                self.scheduler.step(epoch=epoch_no)

            self.model.train()

            # Reset statistics for each epoch.
            start = time.time()
            total_valid_duration = 0
            start_tokens = self.stats.total_tokens
            self.model.zero_grad()
            epoch_loss = 0
            batch_loss = 0

            for i, batch in enumerate(iter(self.train_iter)):
                # create a Batch object from torchtext batch
                batch = self.batch_class(batch,
                                         self.model.pad_index,
                                         use_cuda=self.use_cuda)

                # get batch loss
                batch_loss += self._train_step(batch)

                # update!
                if (i + 1) % self.batch_multiplier == 0:
                    # clip gradients (in-place)
                    if self.clip_grad_fun is not None:
                        if self.fp16:
                            self.clip_grad_fun(
                                params=amp.master_params(self.optimizer))
                        else:
                            self.clip_grad_fun(params=self.model.parameters())

                    # make gradient step
                    self.optimizer.step()

                    # decay lr
                    if self.scheduler is not None \
                            and self.scheduler_step_at == "step":
                        self.scheduler.step()

                    # reset gradients
                    self.model.zero_grad()

                    # increment step counter
                    self.stats.steps += 1

                    # log learning progress
                    if self.stats.steps % self.logging_freq == 0:
                        self.tb_writer.add_scalar("train/train_batch_loss",
                                                  batch_loss, self.stats.steps)
                        elapsed = time.time() - start - total_valid_duration
                        elapsed_tokens = self.stats.total_tokens - start_tokens
                        logger.info(
                            "Epoch %3d, Step: %8d, Batch Loss: %12.6f, "
                            "Tokens per Sec: %8.0f, Lr: %.6f", epoch_no + 1,
                            self.stats.steps, batch_loss,
                            elapsed_tokens / elapsed,
                            self.optimizer.param_groups[0]["lr"])
                        start = time.time()
                        total_valid_duration = 0
                        start_tokens = self.stats.total_tokens

                    # Only add complete loss of full mini-batch to epoch_loss
                    epoch_loss += batch_loss  # accumulate epoch_loss
                    batch_loss = 0  # reset batch_loss

                    # validate on the entire dev set
                    if self.stats.steps % self.validation_freq == 0:
                        valid_duration = self._validate(valid_data, epoch_no)
                        total_valid_duration += valid_duration

                if self.stats.stop:
                    break
            if self.stats.stop:
                logger.info('Training ended since minimum lr %f was reached.',
                            self.learning_rate_min)
                break

            logger.info('Epoch %3d: total training loss %.2f', epoch_no + 1,
                        epoch_loss)
        else:
            logger.info('Training ended after %3d epochs.', epoch_no + 1)
        logger.info('Best validation result (greedy) at step %8d: %6.2f %s.',
                    self.stats.best_ckpt_iter, self.stats.best_ckpt_score,
                    self.early_stopping_metric)

        self.tb_writer.close()  # close Tensorboard writer
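Stripped of the trainer attributes, the update scheme in this version is the standard gradient-accumulation loop. A self-contained sketch with a dummy model (note that this sketch divides each loss by the number of accumulation steps so the accumulated gradient approximates a larger-batch average; the trainer above handles normalization elsewhere):

import torch
import torch.nn as nn

model = nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()
data = [(torch.randn(4, 8), torch.randn(4, 1)) for _ in range(8)]

accum_steps = 4                   # plays the role of batch_multiplier
optimizer.zero_grad()
for i, (x, y) in enumerate(data):
    loss = loss_fn(model(x), y) / accum_steps   # scale so accumulated grads average
    loss.backward()                             # gradients accumulate in .grad
    if (i + 1) % accum_steps == 0:
        optimizer.step()                        # update after accum_steps batches
        optimizer.zero_grad()                   # reset gradients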
Example #10
def validate_on_data(model: Model, data: Dataset,
                     batch_size: int,
                     use_cuda: bool, max_output_length: int,
                     level: str, eval_metric: Optional[str],
                     n_gpu: int,
                     batch_class: Batch = Batch,
                     compute_loss: bool = False,
                     beam_size: int = 1, beam_alpha: int = -1,
                     batch_type: str = "sentence",
                     postprocess: bool = True,
                     bpe_type: str = "subword-nmt",
                     sacrebleu: dict = None,
                     n_best: int = 1) \
        -> (float, float, float, List[str], List[List[str]], List[str],
            List[str], List[List[str]], List[np.array]):
    """
    Generate translations for the given data.
    If `compute_loss` is True and references are given,
    also compute the loss.

    :param model: model module
    :param data: dataset for validation
    :param batch_size: validation batch size
    :param batch_class: class type of batch
    :param use_cuda: if True, use CUDA
    :param max_output_length: maximum length for generated hypotheses
    :param level: segmentation level, one of "char", "bpe", "word"
    :param eval_metric: evaluation metric, e.g. "bleu"
    :param n_gpu: number of GPUs
    :param compute_loss: whether to compute a scalar loss
        for given inputs and targets
    :param beam_size: beam size for validation.
        If <2 then greedy decoding (default).
    :param beam_alpha: beam search alpha for length penalty,
        disabled if set to -1 (default).
    :param batch_type: validation batch type (sentence or token)
    :param postprocess: if True, remove BPE segmentation from translations
    :param bpe_type: bpe type, one of {"subword-nmt", "sentencepiece"}
    :param sacrebleu: sacrebleu options
    :param n_best: number of candidates to return

    :return:
        - current_valid_score: current validation score [eval_metric],
        - valid_loss: validation loss,
        - valid_ppl: validation perplexity,
        - valid_sources: validation sources,
        - valid_sources_raw: raw validation sources (before post-processing),
        - valid_references: validation references,
        - valid_hypotheses: validation_hypotheses,
        - decoded_valid: raw validation hypotheses (before post-processing),
        - valid_attention_scores: attention scores for validation hypotheses
    """
    assert batch_size >= n_gpu, "batch_size must be at least n_gpu."
    if sacrebleu is None:  # assign default value
        sacrebleu = {"remove_whitespace": True, "tokenize": "13a"}
    if batch_size > 1000 and batch_type == "sentence":
        logger.warning(
            "WARNING: Are you sure you meant to work on huge batches like "
            "this? 'batch_size' is > 1000 for sentence-batching. "
            "Consider decreasing it or switching to"
            " 'eval_batch_type: token'.")
    valid_iter = make_data_iter(dataset=data,
                                batch_size=batch_size,
                                batch_type=batch_type,
                                shuffle=False,
                                train=False)
    valid_sources_raw = data.src
    pad_index = model.src_vocab.stoi[PAD_TOKEN]
    # disable dropout
    model.eval()
    # don't track gradients during validation
    with torch.no_grad():
        all_outputs = []
        valid_attention_scores = []
        total_loss = 0
        total_ntokens = 0
        total_nseqs = 0
        for valid_batch in iter(valid_iter):
            # run as during training to get validation loss (e.g. xent)

            batch = batch_class(valid_batch, pad_index, use_cuda=use_cuda)
            # sort batch now by src length and keep track of order
            reverse_index = batch.sort_by_src_length()
            sort_reverse_index = expand_reverse_index(reverse_index, n_best)

            # run as during training with teacher forcing
            if compute_loss and batch.trg is not None:
                batch_loss, _, _, _ = model(return_type="loss", **vars(batch))
                if n_gpu > 1:
                    batch_loss = batch_loss.mean()  # average on multi-gpu
                total_loss += batch_loss
                total_ntokens += batch.ntokens
                total_nseqs += batch.nseqs

            # run as during inference to produce translations
            output, attention_scores = run_batch(
                model=model,
                batch=batch,
                beam_size=beam_size,
                beam_alpha=beam_alpha,
                max_output_length=max_output_length,
                n_best=n_best)

            # sort outputs back to original order
            all_outputs.extend(output[sort_reverse_index])
            valid_attention_scores.extend(
                attention_scores[sort_reverse_index]
                if attention_scores is not None else [])

        assert len(all_outputs) == len(data) * n_best

        if compute_loss and total_ntokens > 0:
            # total validation loss
            valid_loss = total_loss
            # exponent of token-level negative log prob
            valid_ppl = torch.exp(total_loss / total_ntokens)
        else:
            valid_loss = -1
            valid_ppl = -1

        # decode back to symbols
        decoded_valid = model.trg_vocab.arrays_to_sentences(arrays=all_outputs,
                                                            cut_at_eos=True)

        # evaluate with metric on full dataset
        join_char = " " if level in ["word", "bpe"] else ""
        valid_sources = [join_char.join(s) for s in data.src]
        valid_references = [join_char.join(t) for t in data.trg]
        valid_hypotheses = [join_char.join(t) for t in decoded_valid]

        # post-process
        if level == "bpe" and postprocess:
            valid_sources = [
                bpe_postprocess(s, bpe_type=bpe_type) for s in valid_sources
            ]
            valid_references = [
                bpe_postprocess(v, bpe_type=bpe_type) for v in valid_references
            ]
            valid_hypotheses = [
                bpe_postprocess(v, bpe_type=bpe_type) for v in valid_hypotheses
            ]

        # if references are given, evaluate against them
        if valid_references:
            assert len(valid_hypotheses) == len(valid_references)

            current_valid_score = 0
            if eval_metric.lower() == 'bleu':
                # this version does not use any tokenization
                current_valid_score = bleu(valid_hypotheses,
                                           valid_references,
                                           tokenize=sacrebleu["tokenize"])
            elif eval_metric.lower() == 'chrf':
                current_valid_score = chrf(
                    valid_hypotheses,
                    valid_references,
                    remove_whitespace=sacrebleu["remove_whitespace"])
            elif eval_metric.lower() == 'token_accuracy':
                current_valid_score = token_accuracy(  # supply List[List[str]]
                    list(decoded_valid), list(data.trg))
            elif eval_metric.lower() == 'sequence_accuracy':
                current_valid_score = sequence_accuracy(
                    valid_hypotheses, valid_references)
        else:
            current_valid_score = -1

    return current_valid_score, valid_loss, valid_ppl, valid_sources, \
        valid_sources_raw, valid_references, valid_hypotheses, \
        decoded_valid, valid_attention_scores
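expand_reverse_index is needed here because the batch is sorted per source sentence, while run_batch returns n_best consecutive hypotheses per source, so the per-sentence reverse index has to be widened to per-hypothesis positions. A plausible implementation under that assumption (illustrative, not necessarily what the helper in the repository does):

def expand_reverse_index_sketch(reverse_index, n_best):
    # each source sentence i owns the n_best consecutive hypothesis slots
    # [i * n_best, (i + 1) * n_best)
    expanded = []
    for ix in reverse_index:
        expanded.extend(range(ix * n_best, ix * n_best + n_best))
    return expanded

# sources were reordered as [2, 0, 1]; with n_best=2 the flat hypothesis list
# is restored via indices [4, 5, 0, 1, 2, 3]
print(expand_reverse_index_sketch([2, 0, 1], 2))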
Example #11
def validate_on_data(model: Model,
                     data: Dataset,
                     batch_size: int,
                     use_cuda: bool,
                     max_output_length: int,
                     trg_level: str,
                     eval_metrics: Optional[Sequence[str]],
                     loss_function: torch.nn.Module = None,
                     beam_size: int = 0,
                     force_prune_size: int = 5,
                     beam_alpha: int = 0,
                     batch_type: str = "sentence",
                     save_attention: bool = False,
                     validate_by_label: bool = False,
                     forced_sparsity: bool = False,
                     method=None,
                     max_hyps=1,
                     break_at_p: float = 1.0,
                     break_at_argmax: bool = False,
                     short_depth: int = 0):
    """
    Generate translations for the given data.
    If `loss_function` is not None and references are given,
    also compute the loss.

    :param model: model module
    :param data: dataset for validation
    :param batch_size: validation batch size
    :param use_cuda: if True, use CUDA
    :param max_output_length: maximum length for generated hypotheses
    :param trg_level: target segmentation level
    :param eval_metrics: evaluation metrics, e.g. ["bleu"]
    :param loss_function: loss function that computes a scalar loss
        for given inputs and targets
    :param beam_size: beam size for validation (default 0 is greedy)
    :param beam_alpha: beam search alpha for length penalty (default 0)
    :param batch_type: validation batch type (sentence or token)

    :return:
        - current_valid_scores: current validation score [eval_metric],
        - valid_references: validation references,
        - valid_hypotheses: validation_hypotheses,
        - decoded_valid: raw validation hypotheses (before post-processing),
        - valid_attention_scores: attention scores for validation hypotheses
    """
    if beam_size > 0:
        force_prune_size = beam_size

    if validate_by_label:
        assert isinstance(data, TSVDataset) and data.label_columns

    valid_scores = defaultdict(float)  # container for scores
    stats = defaultdict(float)

    valid_iter = make_data_iter(dataset=data,
                                batch_size=batch_size,
                                batch_type=batch_type,
                                shuffle=False,
                                train=False,
                                use_cuda=use_cuda)

    pad_index = model.trg_vocab.stoi[PAD_TOKEN]

    model.eval()  # disable dropout

    force_objectives = loss_function is not None or forced_sparsity

    # possible tasks are: force w/ gold, force w/ empty, search
    scorer = partial(len_penalty, alpha=beam_alpha) if beam_alpha > 0 else None
    confidences = []
    corrects = []
    with torch.no_grad():
        all_outputs = []
        valid_attention_scores = defaultdict(list)
        for valid_batch in iter(valid_iter):
            batch = Batch(valid_batch, pad_index)
            rev_index = batch.sort_by_src_lengths()

            encoder_output, _ = model.encode(batch)

            empty_probs = None
            if force_objectives and not isinstance(model, EnsembleModel):
                # compute all the logits.
                logits = model.force_decode(batch, encoder_output)[0]
                bsz, gold_len, vocab_size = logits.size()
                gold, gold_lengths, _ = batch["trg"]
                prediction_steps = gold_lengths.sum().item() - bsz
                assert gold.size(0) == bsz

                if loss_function is not None:
                    gold_pred = gold[:, 1:].contiguous().view(-1)
                    batch_loss = loss_function(
                        logits.view(-1, logits.size(-1)), gold_pred)
                    valid_scores["loss"] += batch_loss

                if forced_sparsity:
                    # compute probabilities
                    out = logits.view(-1, vocab_size)
                    if isinstance(model, EnsembleModel):
                        probs = out
                    else:
                        probs = model.decoder.gen_func(out, dim=-1)

                    # Compute numbers derived from the distributions.
                    # This includes support size, calibration, and pruned mass
                    non_pad = (gold[:, 1:] != pad_index).view(-1)
                    real_probs = probs[non_pad]
                    n_supported = real_probs.gt(0).sum().item()
                    pred_ps, pred_ix = real_probs.max(dim=-1)
                    real_gold = gold[:, 1:].contiguous().view(-1)[non_pad]
                    real_correct = pred_ix.eq(real_gold)
                    corrects.append(real_correct)
                    confidences.append(pred_ps)

                    beam_probs, _ = real_probs.topk(force_prune_size, dim=-1)
                    pruned_mass = 1 - beam_probs.sum(dim=-1)
                    stats["force_pruned_mass"] += pruned_mass.sum().item()

                    # compute stuff with the empty sequence
                    empty_probs = probs.view(bsz, gold_len,
                                             vocab_size)[:, 0, model.eos_index]
                    assert empty_probs.size() == gold_lengths.size()
                    empty_possible = empty_probs.gt(0).sum().item()
                    empty_mass = empty_probs.sum().item()

                    stats["eos_supported"] += empty_possible
                    stats["eos_mass"] += empty_mass
                    stats["n_supp"] += n_supported
                    stats["n_pred"] += prediction_steps

                short_scores = None
                if short_depth > 0:
                    # we call run_batch again with the short depth. We don't
                    # really care what the hypotheses are, we only want the
                    # scores
                    _, _, short_scores = model.run_batch(
                        batch=batch,
                        beam_size=beam_size,  # can this be removed?
                        scorer=scorer,  # should be none
                        max_output_length=short_depth,
                        method="dfs",
                        max_hyps=max_hyps,
                        encoder_output=encoder_output,
                        return_scores=True)

            # run as during inference to produce translations
            # todo: return_scores for greedy
            output, attention_scores, beam_scores = model.run_batch(
                batch=batch,
                beam_size=beam_size,
                scorer=scorer,
                max_output_length=max_output_length,
                method=method,
                max_hyps=max_hyps,
                encoder_output=encoder_output,
                return_scores=True,
                break_at_argmax=break_at_argmax,
                break_at_p=break_at_p)
            stats["hyp_length"] += output.ne(model.pad_index).sum().item()
            if beam_scores is not None and empty_probs is not None:
                # I need to expand this to handle stuff up to length m.
                # note that although you can compute the probability of the
                # empty sequence without any extra computation, you *do* need
                # to do extra decoding if you want to get the most likely
                # sequence with length <= m.
                empty_better = empty_probs.log().gt(beam_scores).sum().item()
                stats["empty_better"] += empty_better

                if short_scores is not None:
                    short_better = short_scores.gt(beam_scores).sum().item()
                    stats["short_better"] += short_better

            # sort outputs back to original order
            all_outputs.extend(output[rev_index])

            if save_attention and attention_scores is not None:
                # beam search currently does not support attention logging
                for k, v in attention_scores.items():
                    valid_attention_scores[k].extend(v[rev_index])

        assert len(all_outputs) == len(data)

    ref_length = sum(len(d.trg) for d in data)
    valid_scores["length_ratio"] = stats["hyp_length"] / ref_length

    assert len(corrects) == len(confidences)
    if corrects:
        valid_scores["ece"] = expected_calibration_error(corrects, confidences)

    if stats["n_pred"] > 0:
        valid_scores["ppl"] = math.exp(valid_scores["loss"] / stats["n_pred"])

    if forced_sparsity and stats["n_pred"] > 0:
        valid_scores["support"] = stats["n_supp"] / stats["n_pred"]
        valid_scores["empty_possible"] = stats["eos_supported"] / len(
            all_outputs)
        valid_scores["empty_prob"] = stats["eos_mass"] / len(all_outputs)
        valid_scores[
            "force_pruned_mass"] = stats["force_pruned_mass"] / stats["n_pred"]
        if beam_size > 0:
            valid_scores["empty_better"] = stats["empty_better"] / len(
                all_outputs)
            if short_depth > 0:
                score_name = "depth_{}_better".format(short_depth)
                valid_scores[score_name] = stats["short_better"] / len(
                    all_outputs)

    # postprocess
    raw_hyps = model.trg_vocab.arrays_to_sentences(all_outputs)
    valid_hyps = postprocess(raw_hyps, trg_level)
    valid_refs = postprocess(data.trg, trg_level)

    # evaluate
    eval_funcs = {
        "bleu": bleu,
        "chrf": chrf,
        "token_accuracy": partial(token_accuracy, level=trg_level),
        "sequence_accuracy": sequence_accuracy,
        "wer": word_error_rate,
        "cer": partial(character_error_rate, level=trg_level),
        "levenshtein_distance": partial(levenshtein_distance, level=trg_level)
    }
    selected_eval_metrics = {name: eval_funcs[name] for name in eval_metrics}
    decoding_scores, scores_by_label = evaluate_decoding(
        data, valid_refs, valid_hyps, selected_eval_metrics, validate_by_label)
    valid_scores.update(decoding_scores)

    return valid_scores, valid_refs, valid_hyps, \
        raw_hyps, valid_attention_scores, scores_by_label
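The corrects/confidences lists collected above feed an expected-calibration-error score. The repository's expected_calibration_error is not shown here; as a reference point, a standard binned ECE over token-level predictions looks roughly like this (an illustrative sketch; details such as the number of bins are assumptions):

import torch

def ece_sketch(corrects, confidences, n_bins=10):
    """Binned expected calibration error over token predictions."""
    corrects = torch.cat(corrects).float()
    confidences = torch.cat(confidences)
    edges = torch.linspace(0, 1, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        in_bin = (confidences > lo) & (confidences <= hi)
        if in_bin.any():
            acc = corrects[in_bin].mean()        # accuracy inside the bin
            conf = confidences[in_bin].mean()    # mean confidence inside the bin
            ece += (in_bin.float().mean() * (acc - conf).abs()).item()
    return ece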
Example #12
def validate_on_data(model: Model, data: Dataset,
                     logger: Logger,
                     batch_size: int,
                     use_cuda: bool, max_output_length: int,
                     level: str, eval_metric: Optional[str],
                     loss_function: torch.nn.Module = None,
                     beam_size: int = 1, beam_alpha: int = -1,
                     batch_type: str = "sentence",
                     postprocess: bool = True
                     ) \
        -> (float, float, float, List[str], List[List[str]], List[str],
            List[str], List[List[str]], List[np.array]):
    """
    Generate translations for the given data.
    If `loss_function` is not None and references are given,
    also compute the loss.

    :param model: model module
    :param logger: logger
    :param data: dataset for validation
    :param batch_size: validation batch size
    :param use_cuda: if True, use CUDA
    :param max_output_length: maximum length for generated hypotheses
    :param level: segmentation level, one of "char", "bpe", "word"
    :param eval_metric: evaluation metric, e.g. "bleu"
    :param loss_function: loss function that computes a scalar loss
        for given inputs and targets
    :param beam_size: beam size for validation.
        If <2 then greedy decoding (default).
    :param beam_alpha: beam search alpha for length penalty,
        disabled if set to -1 (default).
    :param batch_type: validation batch type (sentence or token)
    :param postprocess: if True, remove BPE segmentation from translations

    :return:
        - current_valid_score: current validation score [eval_metric],
        - valid_loss: validation loss,
        - valid_ppl: validation perplexity,
        - valid_sources: validation sources,
        - valid_sources_raw: raw validation sources (before post-processing),
        - valid_references: validation references,
        - valid_hypotheses: validation_hypotheses,
        - decoded_valid: raw validation hypotheses (before post-processing),
        - valid_attention_scores: attention scores for validation hypotheses
    """
    if batch_size > 1000 and batch_type == "sentence":
        logger.warning(
            "WARNING: Are you sure you meant to work on huge batches like "
            "this? 'batch_size' is > 1000 for sentence-batching. "
            "Consider decreasing it or switching to"
            " 'eval_batch_type: token'.")
    valid_iter = make_data_iter(dataset=data,
                                batch_size=batch_size,
                                batch_type=batch_type,
                                shuffle=False,
                                train=False)
    valid_sources_raw = data.src
    pad_index = model.src_vocab.stoi[PAD_TOKEN]
    # disable dropout
    model.eval()
    # don't track gradients during validation
    with torch.no_grad():
        all_outputs = []
        valid_attention_scores = []
        total_loss = 0
        total_ntokens = 0
        total_nseqs = 0
        for valid_batch in iter(valid_iter):
            # run as during training to get validation loss (e.g. xent)

            batch = Batch(valid_batch, pad_index, use_cuda=use_cuda)
            # sort batch now by src length and keep track of order
            sort_reverse_index = batch.sort_by_src_lengths()

            # run as during training with teacher forcing
            if loss_function is not None and batch.trg is not None:
                batch_loss = model.get_loss_for_batch(
                    batch, loss_function=loss_function)
                total_loss += batch_loss
                total_ntokens += batch.ntokens
                total_nseqs += batch.nseqs

            # run as during inference to produce translations
            output, attention_scores = model.run_batch(
                batch=batch,
                beam_size=beam_size,
                beam_alpha=beam_alpha,
                max_output_length=max_output_length)

            # sort outputs back to original order
            all_outputs.extend(output[sort_reverse_index])
            valid_attention_scores.extend(
                attention_scores[sort_reverse_index]
                if attention_scores is not None else [])

        assert len(all_outputs) == len(data)

        if loss_function is not None and total_ntokens > 0:
            # total validation loss
            valid_loss = total_loss
            # exponent of token-level negative log prob
            valid_ppl = torch.exp(total_loss / total_ntokens)
        else:
            valid_loss = -1
            valid_ppl = -1

        # decode back to symbols
        decoded_valid = model.trg_vocab.arrays_to_sentences(arrays=all_outputs,
                                                            cut_at_eos=True)

        # evaluate with metric on full dataset
        join_char = " " if level in ["word", "bpe"] else ""
        valid_sources = [join_char.join(s) for s in data.src]
        valid_references = [join_char.join(t) for t in data.trg]
        valid_hypotheses = [join_char.join(t) for t in decoded_valid]

        # post-process
        if level == "bpe" and postprocess:
            valid_sources = [bpe_postprocess(s) for s in valid_sources]
            valid_references = [bpe_postprocess(v) for v in valid_references]
            valid_hypotheses = [bpe_postprocess(v) for v in valid_hypotheses]

        # if references are given, evaluate against them
        if valid_references:
            assert len(valid_hypotheses) == len(valid_references)

            current_valid_score = 0
            if eval_metric.lower() == 'bleu':
                # this version does not use any tokenization
                current_valid_score = bleu(valid_hypotheses, valid_references)
            elif eval_metric.lower() == 'chrf':
                current_valid_score = chrf(valid_hypotheses, valid_references)
            elif eval_metric.lower() == 'token_accuracy':
                current_valid_score = token_accuracy(valid_hypotheses,
                                                     valid_references,
                                                     level=level)
            elif eval_metric.lower() == 'sequence_accuracy':
                current_valid_score = sequence_accuracy(
                    valid_hypotheses, valid_references)
        else:
            current_valid_score = -1

    return current_valid_score, valid_loss, valid_ppl, valid_sources, \
        valid_sources_raw, valid_references, valid_hypotheses, \
        decoded_valid, valid_attention_scores
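The bleu() and chrf() helpers used above are thin wrappers around corpus-level metrics. For comparison, the same corpus BLEU can be computed with the sacrebleu package directly, roughly like this (assuming detokenized hypothesis/reference strings and default settings):

import sacrebleu

hypotheses = ["the cat sat on the mat"]
references = ["the cat is on the mat"]

# one reference stream; sacrebleu expects a list of reference lists
score = sacrebleu.corpus_bleu(hypotheses, [references]).score
print(score)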
Example #13
    def train_and_validate(self, train_data: Dataset, valid_data: Dataset) \
            -> None:
        """
        Train the model and validate it from time to time on the validation set.

        :param train_data: training data
        :param valid_data: validation data
        """
        train_iter = make_data_iter(train_data,
                                    batch_size=self.batch_size,
                                    train=True,
                                    shuffle=self.shuffle)
        for epoch_no in range(self.epochs):
            self.logger.info("EPOCH %d", epoch_no + 1)

            if self.scheduler is not None and self.scheduler_step_at == "epoch":
                self.scheduler.step(epoch=epoch_no)

            self.model.train()

            start = time.time()
            total_valid_duration = 0
            processed_tokens = self.total_tokens
            count = 0
            epoch_loss = 0

            for batch in iter(train_iter):
                # reactivate training
                self.model.train()
                # create a Batch object from torchtext batch
                batch = Batch(batch, self.pad_index, use_cuda=self.use_cuda)

                # only update every batch_multiplier batches
                # see https://medium.com/@davidlmorton/
                # increasing-mini-batch-size-without-increasing-
                # memory-6794e10db672
                update = count == 0
                # print(count, update, self.steps)
                batch_loss = self._train_batch(batch, update=update)
                self.tb_writer.add_scalar("train/train_batch_loss", batch_loss,
                                          self.steps)
                count = self.batch_multiplier if update else count
                count -= 1
                epoch_loss += batch_loss.detach().cpu().numpy()

                # log learning progress
                if self.steps % self.logging_freq == 0 and update:
                    elapsed = time.time() - start - total_valid_duration
                    elapsed_tokens = self.total_tokens - processed_tokens
                    self.logger.info(
                        "Epoch %d Step: %d Batch Loss: %f Tokens per Sec: %f",
                        epoch_no + 1, self.steps, batch_loss,
                        elapsed_tokens / elapsed)
                    start = time.time()
                    total_valid_duration = 0

                # validate on the entire dev set
                if valid_data is not None and \
                    self.steps % self.validation_freq == 0 and update:

                    valid_start_time = time.time()

                    valid_score, valid_loss, valid_ppl, valid_sources, \
                        valid_sources_raw, valid_references, valid_hypotheses, \
                        valid_hypotheses_raw, valid_attention_scores, \
                        valid_logps = validate_on_data(
                            batch_size=self.batch_size, data=valid_data,
                            eval_metric=self.eval_metric,
                            level=self.level, model=self.model,
                            use_cuda=self.use_cuda,
                            max_output_length=self.max_output_length,
                            loss_function=self.loss,
                            return_logp=self.return_logp)

                    self.tb_writer.add_scalar("valid/valid_loss", valid_loss,
                                              self.steps)
                    self.tb_writer.add_scalar("valid/valid_score", valid_score,
                                              self.steps)
                    self.tb_writer.add_scalar("valid/valid_ppl", valid_ppl,
                                              self.steps)

                    if self.early_stopping_metric == "loss":
                        ckpt_score = valid_loss
                    elif self.early_stopping_metric in ["ppl", "perplexity"]:
                        ckpt_score = valid_ppl
                    else:
                        ckpt_score = valid_score

                    new_best = False
                    if self.is_best(ckpt_score):
                        self.best_ckpt_score = ckpt_score
                        self.best_ckpt_iteration = self.steps
                        self.logger.info(
                            'Hooray! New best validation result [%s]!',
                            self.early_stopping_metric)
                        if self.ckpt_queue.maxsize > 0:
                            self.logger.info("Saving new checkpoint.")
                            new_best = True
                            self._save_checkpoint()

                    if self.scheduler is not None \
                            and self.scheduler_step_at == "validation":
                        self.scheduler.step(ckpt_score)

                    # append to validation report
                    self._add_report(valid_score=valid_score,
                                     valid_loss=valid_loss,
                                     valid_ppl=valid_ppl,
                                     eval_metric=self.eval_metric,
                                     new_best=new_best)

                    self._log_examples(sources_raw=valid_sources_raw,
                                       sources=valid_sources,
                                       hypotheses_raw=valid_hypotheses_raw,
                                       hypotheses=valid_hypotheses,
                                       references=valid_references)

                    valid_duration = time.time() - valid_start_time
                    total_valid_duration += valid_duration
                    self.logger.info(
                        'Validation result at epoch %d, step %d: %s: %f, '
                        'loss: %f, ppl: %f, duration: %.4fs', epoch_no + 1,
                        self.steps, self.eval_metric, valid_score, valid_loss,
                        valid_ppl, valid_duration)

                    # store validation set outputs
                    self._store_outputs(
                        valid_hypotheses if self.post_process else
                        [" ".join(v) for v in valid_hypotheses_raw],
                        valid_logps if self.return_logp else None)

                    # store attention plots for selected valid sentences
                    store_attention_plots(attentions=valid_attention_scores,
                                          targets=valid_hypotheses_raw,
                                          sources=[s for s in valid_data.src],
                                          indices=self.log_valid_sents,
                                          output_prefix="{}/att.{}".format(
                                              self.model_dir, self.steps),
                                          tb_writer=self.tb_writer,
                                          steps=self.steps)

                if self.save_freq > 0 and self.steps % self.save_freq == 0:
                    # Drop a checkpoint every `save_freq` steps; the log below
                    # assumes `self.steps` counts optimizer updates, so the
                    # number of batches seen is batch_multiplier * steps.
                    self.logger.info(
                        "Saving new checkpoint! Batches passed: %d, "
                        "number of updates: %d",
                        self.batch_multiplier * self.steps, self.steps)
                    self._save_checkpoint()

                if self.stop:
                    break

            if self.stop:
                self.logger.info(
                    'Training ended since minimum lr %f was reached.',
                    self.learning_rate_min)
                break

            self.logger.info('Epoch %d: total training loss %.2f',
                             epoch_no + 1, epoch_loss)
        else:
            self.logger.info('Training ended after %d epochs.', epoch_no + 1)

        if valid_data is not None:
            self.logger.info('Best validation result at step %d: %f %s.',
                             self.best_ckpt_iteration, self.best_ckpt_score,
                             self.early_stopping_metric)
Пример #14
0
    def train_and_validate(self, train_data: Dataset, valid_data: Dataset):
        """
        Train the model and validate it on the validation set.

        :param train_data: training data
        :param valid_data: validation data
        """
        train_iter = make_data_iter(
            train_data,
            batch_size=self.batch_size,
            batch_type=self.batch_type,
            train=True,
            shuffle=self.shuffle)
        for epoch_no in range(1, self.epochs + 1):
            self.logger.info("EPOCH %d", epoch_no)

            if self.scheduler is not None and self.scheduler_step_at == "epoch":
                self.scheduler.step(epoch=epoch_no - 1)  # 0-based indexing

            self.model.train()

            start = time.time()
            total_valid_duration = 0
            processed_tokens = self.total_tokens
            epoch_loss = 0

            for i, batch in enumerate(iter(train_iter), 1):
                # reactivate training
                self.model.train()
                # create a Batch object from torchtext batch
                batch = Batch(batch, self.pad_index, use_cuda=self.use_cuda)

                # only update every batch_multiplier batches
                # see https://medium.com/@davidlmorton/
                # increasing-mini-batch-size-without-increasing-
                # memory-6794e10db672
                update = i % self.batch_multiplier == 0
                batch_loss = self._train_batch(batch, update=update)

                self.log_tensorboard("train", batch_loss=batch_loss)

                epoch_loss += batch_loss.detach().cpu().numpy()

                if self.scheduler is not None and \
                        self.scheduler_step_at == "step" and update:
                    self.scheduler.step()

                # log learning progress
                if self.steps % self.logging_freq == 0 and update:
                    elapsed = time.time() - start - total_valid_duration
                    elapsed_tokens = self.total_tokens - processed_tokens
                    self.logger.info(
                        "Epoch %3d Step: %8d Batch Loss: %12.6f "
                        "Tokens per Sec: %8.0f, Lr: %.6f",
                        epoch_no, self.steps, batch_loss,
                        elapsed_tokens / elapsed,
                        self.optimizer.param_groups[0]["lr"])
                    start = time.time()
                    total_valid_duration = 0
                    processed_tokens = self.total_tokens

                # validate on the entire dev set
                if self.steps % self.validation_freq == 0 and update:
                    valid_start_time = time.time()

                    # it would be nice to include loss and ppl in valid_scores
                    valid_scores, valid_sources, valid_sources_raw, \
                        valid_references, valid_hypotheses, \
                        valid_hypotheses_raw, valid_attention_scores, \
                        scores_by_lang, by_lang = validate_on_data(
                            batch_size=self.eval_batch_size,
                            data=valid_data,
                            eval_metrics=self.eval_metrics,
                            attn_metrics=self.attn_metrics,
                            src_level=self.src_level,
                            trg_level=self.trg_level,
                            model=self.model,
                            use_cuda=self.use_cuda,
                            max_output_length=self.max_output_length,
                            loss_function=self.loss,
                            beam_size=0,  # greedy validations
                            batch_type=self.eval_batch_type,
                            save_attention=self.plot_attention,
                            log_sparsity=self.log_sparsity,
                            apply_mask=self.valid_apply_mask
                        )

                    ckpt_score = valid_scores[self.early_stopping_metric]
                    self.log_tensorboard("valid", **valid_scores)

                    new_best = False
                    if self.is_best(ckpt_score):
                        self.best_ckpt_score = ckpt_score
                        self.best_ckpt_iteration = self.steps
                        self.logger.info(
                            'Hooray! New best validation result [%s]!',
                            self.early_stopping_metric)
                        if self.ckpt_queue.maxsize > 0:
                            self.logger.info("Saving new checkpoint.")
                            new_best = True
                            self._save_checkpoint()

                    if self.scheduler is not None \
                            and self.scheduler_step_at == "validation":
                        self.scheduler.step(ckpt_score)

                    # append to validation report
                    self._add_report(
                        valid_scores=valid_scores,
                        eval_metrics=self.eval_metrics,
                        new_best=new_best)

                    self._log_examples(
                        sources_raw=valid_sources_raw,
                        sources=valid_sources,
                        hypotheses_raw=valid_hypotheses_raw,
                        hypotheses=valid_hypotheses,
                        references=valid_references
                    )

                    labeled_scores = sorted(valid_scores.items())
                    eval_report = ", ".join("{}: {:.5f}".format(n, v)
                                            for n, v in labeled_scores)

                    valid_duration = time.time() - valid_start_time
                    total_valid_duration += valid_duration

                    self.logger.info(
                        'Validation result at epoch %3d, step %8d: %s, '
                        'duration: %.4fs',
                        epoch_no, self.steps, eval_report, valid_duration)

                    if scores_by_lang is not None:
                        for metric, scores in scores_by_lang.items():
                            # make a report
                            lang_report = [metric]
                            numbers = sorted(scores.items())
                            lang_report.extend(["{}: {:.5f}".format(k, v)
                                                for k, v in numbers])

                            self.logger.info("\n\t".join(lang_report))

                    # store validation set outputs
                    self._store_outputs(valid_hypotheses)

                    # store attention plots for selected valid sentences
                    if valid_attention_scores and self.plot_attention:
                        store_attention_plots(
                                attentions=valid_attention_scores,
                                sources=[s for s in valid_data.src],
                                targets=valid_hypotheses_raw,
                                indices=self.log_valid_sents,
                                model_dir=self.model_dir,
                                tb_writer=self.tb_writer,
                                steps=self.steps)

                if self.stop:
                    break
            if self.stop:
                self.logger.info(
                    'Training ended since minimum lr %f was reached.',
                    self.learning_rate_min)
                break

            self.logger.info(
                'Epoch %3d: total training loss %.2f', epoch_no, epoch_loss)
        else:
            self.logger.info('Training ended after %3d epochs.', epoch_no)
        self.logger.info('Best validation result at step %8d: %6.2f %s.',
                         self.best_ckpt_iteration, self.best_ckpt_score,
                         self.early_stopping_metric)

        self.tb_writer.close()  # close Tensorboard writer
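Both this loop and the one above simulate a larger effective batch by only stepping the optimizer every `batch_multiplier` batches (the gradient-accumulation pattern described in the linked Medium post). A self-contained sketch of that pattern, assuming a toy model and random data in place of the trainer's internals:

import torch

model = torch.nn.Linear(8, 2)                       # toy stand-in for the NMT model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.MSELoss()
batch_multiplier = 4

optimizer.zero_grad()
for i, (x, y) in enumerate(
        ((torch.randn(3, 8), torch.randn(3, 2)) for _ in range(8)), 1):
    loss = loss_fn(model(x), y) / batch_multiplier  # scale so accumulated grads average
    loss.backward()                                 # gradients accumulate across batches
    if i % batch_multiplier == 0:                   # the `update` condition from the loop
        optimizer.step()
        optimizer.zero_grad()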
Пример #15
0
    def train_and_validate(self, train_data: Dataset, valid_data: Dataset, kb_task=None, train_kb: TranslationDataset =None,\
        train_kb_lkp: list = [], train_kb_lens: list = [], train_kb_truvals: TranslationDataset=None, valid_kb: Tuple=None, \
        valid_kb_lkp: list=[], valid_kb_lens: list = [], valid_kb_truvals:list=[],
        valid_data_canon: list=[]) \
            -> None:
        """
        Train the model and validate it from time to time on the validation set.

        :param train_data: training data
        :param valid_data: validation data
        :param kb_task: is not None if kb_task should be executed
        :param train_kb: TranslationDataset holding the loaded train KB data
        :param train_kb_lkp: list mapping train example indices to the corresponding KB indices
        :param train_kb_lens: list with the number of triples per KB
        :param valid_kb: TranslationDataset holding the loaded valid KB data
        :param valid_kb_lkp: list mapping valid example indices to the corresponding KB indices
        :param valid_kb_lens: list with the number of triples per KB
        :param valid_kb_truvals: FIXME TODO
        :param valid_data_canon: required to report the loss
        """

        if kb_task:
            train_iter = make_data_iter_kb(train_data,
                                           train_kb,
                                           train_kb_lkp,
                                           train_kb_lens,
                                           train_kb_truvals,
                                           batch_size=self.batch_size,
                                           batch_type=self.batch_type,
                                           train=True,
                                           shuffle=self.shuffle,
                                           canonize=self.model.canonize)
        else:
            train_iter = make_data_iter(train_data,
                                        batch_size=self.batch_size,
                                        batch_type=self.batch_type,
                                        train=True,
                                        shuffle=self.shuffle)

        with torch.autograd.set_detect_anomaly(True):
            for epoch_no in range(self.epochs):
                self.logger.info("EPOCH %d", epoch_no + 1)

                if self.scheduler is not None and self.scheduler_step_at == "epoch":
                    self.scheduler.step(epoch=epoch_no)

                self.model.train()

                start = time.time()
                total_valid_duration = 0
                processed_tokens = self.total_tokens
                count = self.batch_multiplier - 1
                epoch_loss = 0

                for batch in iter(train_iter):
                    # reactivate training
                    self.model.train()

                    # create a Batch object from torchtext batch
                    batch = Batch(batch, self.pad_index, use_cuda=self.use_cuda) if not kb_task else \
                        Batch_with_KB(batch, self.pad_index, use_cuda=self.use_cuda)

                    if kb_task:
                        assert hasattr(batch, "kbsrc"), dir(batch)
                        assert hasattr(batch, "kbtrg"), dir(batch)
                        assert hasattr(batch, "kbtrv"), dir(batch)

                    # only update every batch_multiplier batches
                    # see https://medium.com/@davidlmorton/
                    # increasing-mini-batch-size-without-increasing-
                    # memory-6794e10db672
                    update = count == 0

                    batch_loss = self._train_batch(batch, update=update)

                    if update:
                        self.tb_writer.add_scalar("train/train_batch_loss",
                                                  batch_loss, self.steps)

                    count = self.batch_multiplier if update else count
                    count -= 1
                    epoch_loss += batch_loss.detach().cpu().numpy()

                    if self.scheduler is not None and \
                            self.scheduler_step_at == "step" and update:
                        self.scheduler.step()

                    # log learning progress
                    if self.steps % self.logging_freq == 0 and update:
                        elapsed = time.time() - start - total_valid_duration
                        elapsed_tokens = self.total_tokens - processed_tokens
                        self.logger.info(
                            "Epoch %3d Step: %8d Batch Loss: %12.6f "
                            "Tokens per Sec: %8.0f, Lr: %.6f", epoch_no + 1,
                            self.steps, batch_loss, elapsed_tokens / elapsed,
                            self.optimizer.param_groups[0]["lr"])
                        start = time.time()
                        total_valid_duration = 0

                    # validate on the entire dev set
                    if self.steps % self.validation_freq == 0 and update:

                        if self.manage_decoder_timer:
                            self._log_decoder_timer_stats("train")
                            self.decoder_timer.reset()

                        valid_start_time = time.time()


                        valid_score, valid_loss, valid_ppl, valid_sources, \
                        valid_sources_raw, valid_references, valid_hypotheses, \
                            valid_hypotheses_raw, valid_attention_scores, valid_kb_att_scores, \
                            valid_ent_f1, valid_ent_mcc = \
                            validate_on_data(
                                batch_size=self.eval_batch_size,
                                data=valid_data,
                                eval_metric=self.eval_metric,
                                level=self.level,
                                model=self.model,
                                use_cuda=self.use_cuda,
                                max_output_length=self.max_output_length,
                                loss_function=self.loss,
                                beam_size=0,  # greedy validations #FIXME XXX NOTE TODO BUG set to 0 again!
                                batch_type=self.eval_batch_type,
                                kb_task=kb_task,
                                valid_kb=valid_kb,
                                valid_kb_lkp=valid_kb_lkp,
                                valid_kb_lens=valid_kb_lens,
                                valid_kb_truvals=valid_kb_truvals,
                                valid_data_canon=valid_data_canon,
                                report_on_canonicals=self.report_entf1_on_canonicals
                            )

                        if self.manage_decoder_timer:
                            self._log_decoder_timer_stats("valid")
                            self.decoder_timer.reset()

                        self.tb_writer.add_scalar("valid/valid_loss",
                                                  valid_loss, self.steps)
                        self.tb_writer.add_scalar("valid/valid_score",
                                                  valid_score, self.steps)
                        self.tb_writer.add_scalar("valid/valid_ppl", valid_ppl,
                                                  self.steps)

                        if self.early_stopping_metric == "loss":
                            ckpt_score = valid_loss
                        elif self.early_stopping_metric in [
                                "ppl", "perplexity"
                        ]:
                            ckpt_score = valid_ppl
                        else:
                            ckpt_score = valid_score

                        new_best = False
                        if self.is_best(ckpt_score):
                            self.best_ckpt_score = ckpt_score
                            self.best_ckpt_iteration = self.steps
                            self.logger.info(
                                'Hooray! New best validation result [%s]!',
                                self.early_stopping_metric)
                            if self.ckpt_queue.maxsize > 0:
                                self.logger.info("Saving new checkpoint.")
                                new_best = True
                                self._save_checkpoint()

                        if self.scheduler is not None \
                                and self.scheduler_step_at == "validation":
                            self.scheduler.step(ckpt_score)

                        # append to validation report
                        self._add_report(valid_score=valid_score,
                                         valid_loss=valid_loss,
                                         valid_ppl=valid_ppl,
                                         eval_metric=self.eval_metric,
                                         valid_ent_f1=valid_ent_f1,
                                         valid_ent_mcc=valid_ent_mcc,
                                         new_best=new_best)

                        # pylint: disable=unnecessary-comprehension
                        self._log_examples(
                            sources_raw=[v for v in valid_sources_raw],
                            sources=valid_sources,
                            hypotheses_raw=valid_hypotheses_raw,
                            hypotheses=valid_hypotheses,
                            references=valid_references)

                        valid_duration = time.time() - valid_start_time
                        total_valid_duration += valid_duration
                        self.logger.info(
                            'Validation result at epoch %3d, step %8d: %s: %6.2f, '
                            'loss: %8.4f, ppl: %8.4f, duration: %.4fs',
                            epoch_no + 1, self.steps, self.eval_metric,
                            valid_score, valid_loss, valid_ppl, valid_duration)

                        # store validation set outputs
                        self._store_outputs(valid_hypotheses)

                        valid_src = list(valid_data.src)
                        # store attention plots for selected valid sentences
                        if valid_attention_scores:
                            plot_success_ratio = store_attention_plots(
                                attentions=valid_attention_scores,
                                targets=valid_hypotheses_raw,
                                sources=valid_src,
                                indices=self.log_valid_sents,
                                output_prefix="{}/att.{}".format(
                                    self.model_dir, self.steps),
                                tb_writer=self.tb_writer,
                                steps=self.steps)
                            self.logger.info(
                                f"stored {plot_success_ratio} valid att scores!"
                            )
                        if valid_kb_att_scores:
                            plot_success_ratio = store_attention_plots(
                                attentions=valid_kb_att_scores,
                                targets=valid_hypotheses_raw,
                                sources=list(valid_kb.kbsrc),
                                indices=self.log_valid_sents,
                                output_prefix="{}/kbatt.{}".format(
                                    self.model_dir, self.steps),
                                tb_writer=self.tb_writer,
                                steps=self.steps,
                                kb_info=(valid_kb_lkp, valid_kb_lens,
                                         valid_kb_truvals),
                                on_the_fly_info=(valid_src, valid_kb,
                                                 self.model.canonize,
                                                 self.model.trg_vocab))
                            self.logger.info(
                                f"stored {plot_success_ratio} valid kb att scores!"
                            )
                        else:
                            self.logger.info(
                                "theres no valid kb att scores...")
                    if self.stop:
                        break
                if self.stop:
                    self.logger.info(
                        'Training ended since minimum lr %f was reached.',
                        self.learning_rate_min)
                    break

                self.logger.info('Epoch %3d: total training loss %.2f',
                                 epoch_no + 1, epoch_loss)
            else:
                self.logger.info('Training ended after %3d epochs.',
                                 epoch_no + 1)
            self.logger.info('Best validation result at step %8d: %6.2f %s.',
                             self.best_ckpt_iteration, self.best_ckpt_score,
                             self.early_stopping_metric)

        self.tb_writer.close()  # close Tensorboard writer
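The progress logging in these training loops reports tokens per second over wall-clock training time only, subtracting the time spent in validation since the last log. A small sketch of that bookkeeping, with a sleep standing in for actual batch processing and made-up token counts:

import time

start = time.time()
total_valid_duration = 0.0          # accumulated validation time since last log
processed_tokens = 0                # token count at the last log
total_tokens = 0
logging_freq = 2

for step in range(1, 7):
    time.sleep(0.01)                # stand-in for processing one batch
    total_tokens += 1000            # tokens in this (made-up) batch
    if step % logging_freq == 0:
        elapsed = time.time() - start - total_valid_duration
        elapsed_tokens = total_tokens - processed_tokens
        print(f"step {step}: {elapsed_tokens / elapsed:,.0f} tokens/sec")
        start = time.time()
        total_valid_duration = 0.0
        processed_tokens = total_tokens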
Пример #16
0
def validate_on_data(model: Model, data: Dataset,
                     batch_size: int,
                     use_cuda: bool, max_output_length: int,
                     src_level: str,
                     trg_level: str,
                     eval_metrics: Optional[Sequence[str]],
                     attn_metrics: Optional[Sequence[str]],
                     loss_function: torch.nn.Module = None,
                     beam_size: int = 0, beam_alpha: int = 0,
                     batch_type: str = "sentence",
                     save_attention: bool = False,
                     log_sparsity: bool = False,
                     apply_mask: bool = True  # hmm
                     ) \
        -> (float, float, float, List[str], List[List[str]], List[str],
            List[str], List[List[str]], List[np.array]):
    """
    Generate translations for the given data.
    If `loss_function` is not None and references are given,
    also compute the loss.

    :param model: model module
    :param data: dataset for validation
    :param batch_size: validation batch size
    :param use_cuda: if True, use CUDA
    :param max_output_length: maximum length for generated hypotheses
    :param src_level: source segmentation level, one of "char", "bpe", "word"
    :param trg_level: target segmentation level, one of "char", "bpe", "word"
    :param eval_metrics: evaluation metric, e.g. "bleu"
    :param loss_function: loss function that computes a scalar loss
        for given inputs and targets
    :param beam_size: beam size for validation.
        If 0 then greedy decoding (default).
    :param beam_alpha: beam search alpha for length penalty,
        disabled if set to 0 (default).
    :param batch_type: validation batch type (sentence or token)

    :return:
        - current_valid_score: current validation score [eval_metric],
        - valid_loss: validation loss,
        - valid_ppl: validation perplexity,
        - valid_sources: validation sources,
        - valid_sources_raw: raw validation sources (before post-processing),
        - valid_references: validation references,
        - valid_hypotheses: validation_hypotheses,
        - decoded_valid: raw validation hypotheses (before post-processing),
        - valid_attention_scores: attention scores for validation hypotheses
    """
    eval_funcs = {
        "bleu": bleu,
        "chrf": chrf,
        "token_accuracy": partial(token_accuracy, level=trg_level),
        "sequence_accuracy": sequence_accuracy,
        "wer": wer,
        "cer": partial(character_error_rate, level=trg_level)
    }
    selected_eval_metrics = {name: eval_funcs[name] for name in eval_metrics}

    valid_iter = make_data_iter(
        dataset=data, batch_size=batch_size, batch_type=batch_type,
        shuffle=False, train=False)
    valid_sources_raw = [s for s in data.src]
    pad_index = model.src_vocab.stoi[PAD_TOKEN]
    # disable dropout
    model.eval()
    # don't track gradients during validation
    scorer = partial(len_penalty, alpha=beam_alpha) if beam_alpha > 0 else None
    with torch.no_grad():
        all_outputs = []
        valid_attention_scores = defaultdict(list)
        total_loss = 0
        total_ntokens = 0
        total_nseqs = 0
        total_attended = defaultdict(int)
        greedy_steps = 0
        greedy_supported = 0
        for valid_batch in iter(valid_iter):
            # run as during training to get validation loss (e.g. xent)

            batch = Batch(valid_batch, pad_index, use_cuda=use_cuda)
            # sort batch now by src length and keep track of order
            sort_reverse_index = batch.sort_by_src_lengths()

            # run as during training with teacher forcing
            if loss_function is not None and batch.trg is not None:
                batch_loss = model.get_loss_for_batch(
                    batch, loss_function=loss_function)
                total_loss += batch_loss
                total_ntokens += batch.ntokens
                total_nseqs += batch.nseqs

            # run as during inference to produce translations
            output, attention_scores, probs = model.run_batch(
                batch=batch, beam_size=beam_size, scorer=scorer,
                max_output_length=max_output_length, log_sparsity=log_sparsity,
                apply_mask=apply_mask)
            if log_sparsity:
                lengths = torch.LongTensor((output == model.trg_vocab.stoi[EOS_TOKEN]).argmax(axis=1)).unsqueeze(1)
                batch_greedy_steps = lengths.sum().item()
                greedy_steps += lengths.sum().item()

                ix = torch.arange(output.shape[1]).unsqueeze(0).expand(output.shape[0], -1)
                mask = ix <= lengths
                supp = probs.exp().gt(0).sum(dim=-1).cpu()  # batch x len
                supp = torch.where(mask, supp, torch.tensor(0)).sum()
                greedy_supported += supp.float().item()

            # sort outputs back to original order
            all_outputs.extend(output[sort_reverse_index])

            if attention_scores is not None:
                # is attention_scores ever None?
                if save_attention:
                    # beam search currently does not support attention logging
                    for k, v in attention_scores.items():
                        valid_attention_scores[k].extend(v[sort_reverse_index])
                if attn_metrics:
                    # add to total_attended
                    for k, v in attention_scores.items():
                        total_attended[k] += (v > 0).sum()

        assert len(all_outputs) == len(data)

        if log_sparsity:
            print(greedy_supported / greedy_steps)

        valid_scores = dict()
        if loss_function is not None and total_ntokens > 0:
            # total validation loss
            valid_loss = total_loss
            valid_scores["loss"] = total_loss
            valid_scores["ppl"] = torch.exp(total_loss / total_ntokens)

        # decode back to symbols
        decoded_valid = model.trg_vocab.arrays_to_sentences(arrays=all_outputs,
                                                            cut_at_eos=True)

        # evaluate with metric on full dataset
        src_join_char = " " if src_level in ["word", "bpe"] else ""
        trg_join_char = " " if trg_level in ["word", "bpe"] else ""
        valid_sources = [src_join_char.join(s) for s in data.src]
        valid_references = [trg_join_char.join(t) for t in data.trg]
        valid_hypotheses = [trg_join_char.join(t) for t in decoded_valid]

        if attn_metrics:
            decoded_ntokens = sum(len(t) for t in decoded_valid)
            for attn_metric in attn_metrics:
                assert attn_metric == "support"
                for attn_name, tot_attended in total_attended.items():
                    score_name = attn_name + "_" + attn_metric
                    # this is not the right denominator
                    valid_scores[score_name] = tot_attended / decoded_ntokens

        # post-process
        if src_level == "bpe":
            valid_sources = [bpe_postprocess(s) for s in valid_sources]
        if trg_level == "bpe":
            valid_references = [bpe_postprocess(v) for v in valid_references]
            valid_hypotheses = [bpe_postprocess(v) for v in valid_hypotheses]

        languages = [language for language in data.language]
        by_language = defaultdict(list)
        seqs = zip(valid_references, valid_hypotheses) if valid_references else valid_hypotheses
        if languages:
            examples = zip(languages, seqs)
            for lang, seq in examples:
                by_language[lang].append(seq)
        else:
            by_language[None].extend(seqs)

        # if references are given, evaluate against them
        # incorrect if-condition?
        # scores_by_lang = {name: dict() for name in selected_eval_metrics}
        scores_by_lang = dict()
        if valid_references and eval_metrics is not None:
            assert len(valid_hypotheses) == len(valid_references)

            for eval_metric, eval_func in selected_eval_metrics.items():
                score_by_lang = dict()
                for lang, pairs in by_language.items():
                    # each pair is (reference, hypothesis), see `seqs` above
                    lang_refs, lang_hyps = zip(*pairs)
                    lang_score = eval_func(lang_hyps, lang_refs)
                    score_by_lang[lang] = lang_score

                score = sum(score_by_lang.values()) / len(score_by_lang)
                valid_scores[eval_metric] = score
                scores_by_lang[eval_metric] = score_by_lang

    if not languages:
        scores_by_lang = None
    return valid_scores, valid_sources, \
        valid_sources_raw, valid_references, valid_hypotheses, \
        decoded_valid, valid_attention_scores, scores_by_lang, by_language
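When the dataset carries a language tag, the function above scores each language separately and then macro-averages the per-language scores. A toy sketch of that grouping and averaging; the sentence pairs and dummy_metric are invented for illustration:

from collections import defaultdict

def dummy_metric(hypotheses, references):
    # placeholder metric: exact-match rate
    return sum(h == r for h, r in zip(hypotheses, references)) / len(references)

# (language, (reference, hypothesis)) tuples, mirroring zip(languages, seqs)
examples = [("de", ("ein test", "ein test")),
            ("de", ("noch einer", "noch eins")),
            ("fr", ("un essai", "un essai"))]

by_language = defaultdict(list)
for lang, pair in examples:
    by_language[lang].append(pair)

score_by_lang = {}
for lang, pairs in by_language.items():
    refs, hyps = zip(*pairs)            # each pair is (reference, hypothesis)
    score_by_lang[lang] = dummy_metric(hyps, refs)

macro_avg = sum(score_by_lang.values()) / len(score_by_lang)
print(score_by_lang, macro_avg)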
Пример #17
0
def validate_on_data(model: Model,
                     data: Dataset,
                     batch_size: int,
                     use_cuda: bool,
                     max_output_length: int,
                     level: str,
                     eval_metric: Optional[str],
                     loss_function: torch.nn.Module = None,
                     beam_size: int = 0,
                     beam_alpha: int = -1,
                     batch_type: str = "sentence",
                     kb_task = None,
                     valid_kb: Dataset = None,
                     valid_kb_lkp: list = [],
                     valid_kb_lens:list=[],
                     valid_kb_truvals: Dataset = None,
                     valid_data_canon: Dataset = None,
                     report_on_canonicals: bool = False,
                     ) \
        -> (float, float, float, List[str], List[List[str]], List[str],
            List[str], List[List[str]], List[np.array]):
    """
    Generate translations for the given data.
    If `loss_function` is not None and references are given,
    also compute the loss.

    :param model: model module
    :param data: dataset for validation
    :param batch_size: validation batch size
    :param use_cuda: if True, use CUDA
    :param max_output_length: maximum length for generated hypotheses
    :param level: segmentation level, one of "char", "bpe", "word"
    :param eval_metric: evaluation metric, e.g. "bleu"
    :param loss_function: loss function that computes a scalar loss
        for given inputs and targets
    :param beam_size: beam size for validation.
        If 0 then greedy decoding (default).
    :param beam_alpha: beam search alpha for length penalty,
        disabled if set to -1 (default).
    :param batch_type: validation batch type (sentence or token)
    :param kb_task: is not None if kb_task should be executed
    :param valid_kb: MonoDataset holding the loaded valid kb data
    :param valid_kb_lkp: list mapping valid example indices to the corresponding KB indices
    :param valid_kb_lens: list with the number of triples per KB
    :param valid_data_canon: TranslationDataset of valid data but with canonized target data (for loss reporting)


    :return:
        - current_valid_score: current validation score [eval_metric],
        - valid_loss: validation loss,
        - valid_ppl: validation perplexity,
        - valid_sources: validation sources,
        - valid_sources_raw: raw validation sources (before post-processing),
        - valid_references: validation references,
        - valid_hypotheses: validation_hypotheses,
        - decoded_valid: raw validation hypotheses (before post-processing),
        - valid_attention_scores: attention scores for validation hypotheses
        - valid_ent_f1: TODO FIXME
    """

    print(f"\n{'-'*10} ENTER VALIDATION {'-'*10}\n")

    print(f"\n{'-'*10}  VALIDATION DEBUG {'-'*10}\n")

    print("---data---")
    print(dir(data[0]))
    print([[
        getattr(example, attr) for attr in dir(example)
        if hasattr(getattr(example, attr), "__iter__") and "kb" in attr
        or "src" in attr or "trg" in attr
    ] for example in data[:3]])
    print(batch_size)
    print(use_cuda)
    print(max_output_length)
    print(level)
    print(eval_metric)
    print(loss_function)
    print(beam_size)
    print(beam_alpha)
    print(batch_type)
    print(kb_task)
    print("---valid_kb---")
    print(dir(valid_kb[0]))
    print([[
        getattr(example, attr) for attr in dir(example)
        if hasattr(getattr(example, attr), "__iter__") and "kb" in attr
        or "src" in attr or "trg" in attr
    ] for example in valid_kb[:3]])
    print(len(valid_kb_lkp), valid_kb_lkp[-5:])
    print(len(valid_kb_lens), valid_kb_lens[-5:])
    print("---valid_kb_truvals---")
    print(len(valid_kb_truvals), valid_kb_truvals[-5:])
    print([[
        getattr(example, attr) for attr in dir(example)
        if hasattr(getattr(example, attr), "__iter__") and "kb" in attr
        or "src" in attr or "trg" in attr or "trv" in attr
    ] for example in valid_kb_truvals[:3]])
    print("---valid_data_canon---")
    print(len(valid_data_canon), valid_data_canon[-5:])
    print([[
        getattr(example, attr) for attr in dir(example)
        if hasattr(getattr(example, attr), "__iter__") and "kb" in attr
        or "src" in attr or "trg" in attr or "trv" or "can" in attr
    ] for example in valid_data_canon[:3]])
    print(report_on_canonicals)

    print(f"\n{'-'*10} END VALIDATION DEBUG {'-'*10}\n")

    if not kb_task:
        valid_iter = make_data_iter(dataset=data,
                                    batch_size=batch_size,
                                    batch_type=batch_type,
                                    shuffle=False,
                                    train=False)
    else:
        # knowledgebase version of make data iter and also provide canonized target data
        # data: for bleu/ent f1
        # canon_data: for loss
        valid_iter = make_data_iter_kb(data,
                                       valid_kb,
                                       valid_kb_lkp,
                                       valid_kb_lens,
                                       valid_kb_truvals,
                                       batch_size=batch_size,
                                       batch_type=batch_type,
                                       shuffle=False,
                                       train=False,
                                       canonize=model.canonize,
                                       canon_data=valid_data_canon)

    valid_sources_raw = data.src
    pad_index = model.src_vocab.stoi[PAD_TOKEN]

    # disable dropout
    model.eval()
    # don't track gradients during validation
    with torch.no_grad():
        all_outputs = []
        valid_attention_scores = []
        valid_kb_att_scores = []
        total_loss = 0
        total_ntokens = 0
        total_nseqs = 0
        for valid_batch in iter(valid_iter):
            # run as during training to get validation loss (e.g. xent)

            batch = Batch(valid_batch, pad_index, use_cuda=use_cuda) \
                                if not kb_task else \
                Batch_with_KB(valid_batch, pad_index, use_cuda=use_cuda)

            assert hasattr(batch, "kbsrc") == bool(kb_task)

            # sort batch now by src length and keep track of order
            if not kb_task:
                sort_reverse_index = batch.sort_by_src_lengths()
            else:
                sort_reverse_index = list(range(batch.src.shape[0]))

            # run as during training with teacher forcing
            if loss_function is not None and batch.trg is not None:

                ntokens = batch.ntokens
                if hasattr(batch, "trgcanon") and batch.trgcanon is not None:
                    ntokens = batch.ntokenscanon  # normalize loss with num canonical tokens for perplexity
                # do a loss calculation without grad updates just to report valid loss
                # we can only do this when batch.trg exists, so not during actual translation/deployment
                batch_loss = model.get_loss_for_batch(
                    batch, loss_function=loss_function)
                # keep track of metrics for reporting
                total_loss += batch_loss
                total_ntokens += ntokens  # gold target tokens
                total_nseqs += batch.nseqs

            # run as during inference to produce translations
            output, attention_scores, kb_att_scores = model.run_batch(
                batch=batch,
                beam_size=beam_size,
                beam_alpha=beam_alpha,
                max_output_length=max_output_length)

            # sort outputs back to original order
            all_outputs.extend(output[sort_reverse_index])
            valid_attention_scores.extend(
                attention_scores[sort_reverse_index]
                if attention_scores is not None else [])
            valid_kb_att_scores.extend(kb_att_scores[sort_reverse_index]
                                       if kb_att_scores is not None else [])

        assert len(all_outputs) == len(data)

        if loss_function is not None and total_ntokens > 0:
            # total validation loss
            valid_loss = total_loss
            # exponent of token-level negative log likelihood
            # can be seen as 2^(cross_entropy of model on valid set); normalized by num tokens;
            # see https://en.wikipedia.org/wiki/Perplexity#Perplexity_per_word
            valid_ppl = torch.exp(valid_loss / total_ntokens)
        else:
            valid_loss = -1
            valid_ppl = -1

        # decode back to symbols
        decoding_vocab = model.trg_vocab if not kb_task else model.trv_vocab

        decoded_valid = decoding_vocab.arrays_to_sentences(arrays=all_outputs,
                                                           cut_at_eos=True)

        print(f"decoding_vocab.itos: {decoding_vocab.itos}")
        print(decoded_valid)

        # evaluate with metric on full dataset
        join_char = " " if level in ["word", "bpe"] else ""
        valid_sources = [join_char.join(s) for s in data.src]
        # TODO replace valid_references with uncanonicalized dev.car data ... requires writing new Dataset in data.py
        valid_references = [join_char.join(t) for t in data.trg]
        valid_hypotheses = [join_char.join(t) for t in decoded_valid]

        # post-process
        if level == "bpe":
            valid_sources = [bpe_postprocess(s) for s in valid_sources]
            valid_references = [bpe_postprocess(v) for v in valid_references]
            valid_hypotheses = [bpe_postprocess(v) for v in valid_hypotheses]

        # if references are given, evaluate against them
        if valid_references:
            assert len(valid_hypotheses) == len(valid_references)

            print(list(zip(valid_sources, valid_references, valid_hypotheses)))

            current_valid_score = 0
            if eval_metric.lower() == 'bleu':
                # this version does not use any tokenization
                current_valid_score = bleu(valid_hypotheses, valid_references)
            elif eval_metric.lower() == 'chrf':
                current_valid_score = chrf(valid_hypotheses, valid_references)
            elif eval_metric.lower() == 'token_accuracy':
                current_valid_score = token_accuracy(valid_hypotheses,
                                                     valid_references,
                                                     level=level)
            elif eval_metric.lower() == 'sequence_accuracy':
                current_valid_score = sequence_accuracy(
                    valid_hypotheses, valid_references)

            if kb_task:
                valid_ent_f1, valid_ent_mcc = calc_ent_f1_and_ent_mcc(
                    valid_hypotheses,
                    valid_references,
                    vocab=model.trv_vocab,
                    c_fun=model.canonize,
                    report_on_canonicals=report_on_canonicals)

            else:
                valid_ent_f1, valid_ent_mcc = -1, -1
        else:
            current_valid_score = -1

    print(f"\n{'-'*10} EXIT VALIDATION {'-'*10}\n")
    return current_valid_score, valid_loss, valid_ppl, valid_sources, \
        valid_sources_raw, valid_references, valid_hypotheses, \
        decoded_valid, valid_attention_scores, valid_kb_att_scores, \
        valid_ent_f1, valid_ent_mcc
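Both validation routines report perplexity as the exponential of the summed token-level cross-entropy, normalized by the number of gold target tokens. In isolation, with made-up numbers:

import math

total_loss = 345.6     # sum of per-batch cross-entropy losses (illustrative)
total_ntokens = 120    # gold target tokens in the validation set (illustrative)
valid_ppl = math.exp(total_loss / total_ntokens)
print(f"validation perplexity: {valid_ppl:.2f}")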
Пример #18
0
    def dev_network(self):
        """
        Report the current performance on the dev dataset, in terms of the
        total reward and the BLEU score.

        :return: current BLEU score
        """
        freeze_model(self.eval_net)
        for data_set_name, data_set in self.data_to_dev.items():
            #print(data_set_name)
            valid_iter = make_data_iter(dataset=data_set,
                                        batch_size=1,
                                        batch_type=self.batch_type,
                                        shuffle=False,
                                        train=False)
            valid_sources_raw = data_set.src

            # don't track gradients during validation
            r_total = 0
            roptimal_total = 0
            all_outputs = []
            i_sample = 0

            for valid_batch in iter(valid_iter):
                # run as during training to get validation loss (e.g. xent)

                batch = Batch(valid_batch,
                              self.pad_index,
                              use_cuda=self.use_cuda)

                encoder_output, encoder_hidden = self.model.encode(
                    batch.src, batch.src_lengths, batch.src_mask)

                # if maximum output length is
                # not globally specified, adapt to src len
                if self.max_output_length is None:
                    self.max_output_length = int(
                        max(batch.src_lengths.cpu().numpy()) * 1.5)

                batch_size = batch.src_mask.size(0)
                prev_y = batch.src_mask.new_full(size=[batch_size, 1],
                                                 fill_value=self.bos_index,
                                                 dtype=torch.long)
                output = []
                hidden = self.model.decoder._init_hidden(encoder_hidden)
                prev_att_vector = None
                finished = batch.src_mask.new_zeros((batch_size, 1)).byte()

                # pylint: disable=unused-variable
                for t in range(self.max_output_length):

                    # if i_sample == 0 or i_sample == 3 or i_sample == 6:
                    #     print("state on t = ", t, " : " , state)

                    # decode one single step
                    logits, hidden, att_probs, prev_att_vector = self.model.decoder(
                        encoder_output=encoder_output,
                        encoder_hidden=encoder_hidden,
                        src_mask=batch.src_mask,
                        trg_embed=self.model.trg_embed(prev_y),
                        hidden=hidden,
                        prev_att_vector=prev_att_vector,
                        unroll_steps=1)
                    # greedy decoding: choose the arg max over the vocabulary at each step

                    if self.state_type == 'hidden':
                        state = torch.cat(hidden,
                                          dim=2).squeeze(1).detach().cpu()[0]
                    else:
                        state = torch.FloatTensor(
                            prev_att_vector.squeeze(1).detach().cpu().numpy()
                            [0])

                    logits = self.eval_net(state)
                    logits = logits.reshape([1, 1, -1])
                    #print(type(logits), logits.shape, logits)
                    next_word = torch.argmax(logits, dim=-1)
                    a = next_word.squeeze(1).detach().cpu().numpy()[0]

                    output.append(next_word.squeeze(1).detach().cpu().numpy())
                    prev_y = next_word

                    # check if previous symbol was <eos>
                    is_eos = torch.eq(next_word, self.eos_index)
                    finished += is_eos
                    # stop predicting if <eos> reached for all elements in batch
                    if (finished >= 1).sum() == batch_size:
                        break
                stacked_output = np.stack(output, axis=1)  # batch, time

                #decode back to symbols
                decoded_valid_in = self.model.trg_vocab.arrays_to_sentences(
                    arrays=batch.src, cut_at_eos=True)
                decoded_valid_out_trg = self.model.trg_vocab.arrays_to_sentences(
                    arrays=batch.trg, cut_at_eos=True)
                decoded_valid_out = self.model.trg_vocab.arrays_to_sentences(
                    arrays=stacked_output, cut_at_eos=True)

                hyp = stacked_output

                r = self.Reward(batch.trg, hyp, show=False)

                if i_sample == 0 or i_sample == 3 or i_sample == 6:
                    print(
                        "\n Sample ", i_sample,
                        "-------------Target vs Eval_net prediction:--Raw---and---Decoded-----"
                    )
                    print("Target: ", batch.trg, decoded_valid_out_trg)
                    print("Eval  : ", stacked_output, decoded_valid_out, "\n")
                    print("Reward: ", r)

                #r = self.Reward1(batch.trg, hyp , show = False)
                r_total += sum(r[np.where(r > 0)])
                if i_sample == 0:
                    roptimal = self.Reward(batch.trg, batch.trg, show=False)
                    roptimal_total += sum(roptimal[np.where(roptimal > 0)])

                all_outputs.extend(stacked_output)
                i_sample += 1

            assert len(all_outputs) == len(data_set)

            # decode back to symbols
            decoded_valid = self.model.trg_vocab.arrays_to_sentences(
                arrays=all_outputs, cut_at_eos=True)

            # evaluate with metric on full dataset
            join_char = " " if self.level in ["word", "bpe"] else ""
            valid_sources = [join_char.join(s) for s in data_set.src]
            valid_references = [join_char.join(t) for t in data_set.trg]
            valid_hypotheses = [join_char.join(t) for t in decoded_valid]

            # post-process
            if self.level == "bpe":
                valid_sources = [bpe_postprocess(s) for s in valid_sources]
                valid_references = [
                    bpe_postprocess(v) for v in valid_references
                ]
                valid_hypotheses = [
                    bpe_postprocess(v) for v in valid_hypotheses
                ]

            # if references are given, evaluate against them
            if valid_references:
                assert len(valid_hypotheses) == len(valid_references)

                current_valid_score = 0
                if self.eval_metric.lower() == 'bleu':
                    # this version does not use any tokenization
                    current_valid_score = bleu(valid_hypotheses,
                                               valid_references)
                elif self.eval_metric.lower() == 'chrf':
                    current_valid_score = chrf(valid_hypotheses,
                                               valid_references)
                elif self.eval_metric.lower() == 'token_accuracy':
                    current_valid_score = token_accuracy(valid_hypotheses,
                                                         valid_references,
                                                         level=self.level)
                elif self.eval_metric.lower() == 'sequence_accuracy':
                    current_valid_score = sequence_accuracy(
                        valid_hypotheses, valid_references)
            else:
                current_valid_score = -1

            self.dev_network_count += 1
            self.tb_writer.add_scalar("dev/dev_reward", r_total,
                                      self.dev_network_count)
            self.tb_writer.add_scalar("dev/dev_bleu", current_valid_score,
                                      self.dev_network_count)

            print(self.dev_network_count, ' r_total and score: ', r_total,
                  current_valid_score)

            unfreeze_model(self.eval_net)
        return current_valid_score
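The decoding loops in these last examples all stop early once every sequence in the batch has produced <eos>, tracked with a per-sequence `finished` counter. A reduced sketch of just that stopping logic, with random token draws standing in for a real decoder step:

import torch

eos_index = 3
batch_size = 2
max_output_length = 10
finished = torch.zeros(batch_size, 1, dtype=torch.uint8)
output = []

for t in range(max_output_length):
    next_word = torch.randint(0, 5, (batch_size, 1))        # stand-in for argmax over logits
    output.append(next_word.squeeze(1).numpy())
    finished += torch.eq(next_word, eos_index).byte()        # mark sequences that emitted <eos>
    if (finished >= 1).sum() == batch_size:                  # all sequences finished
        break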
Пример #19
0
    def Collecting_experiences(self) -> None:
        """
        Main function. Runs the whole experience-collection process.

        :param exp_list: list of experiences; tuples (memory_counter, state, a, state_, is_eos[0,0])
        :param rew: rewards for every experience, of the length of the hypothesis
        """
        for epoch_no in range(self.epochs):
            print("EPOCH %d", epoch_no + 1)

            #beam_dqn = self.beam_min + int(self.beam_max * epoch_no/self.epochs)
            #egreed = self.egreed_max*(1 - epoch_no/(1.1*self.epochs))
            #self.gamma = self.gamma_max*(1 - epoch_no/(2*self.epochs))

            beam_dqn = 1
            egreed = 0.5
            #self.gamma = self.gamma_max
            self.gamma = 0.6

            self.tb_writer.add_scalar("parameters/beam_dqn", beam_dqn,
                                      epoch_no)
            self.tb_writer.add_scalar("parameters/egreed", egreed, epoch_no)
            self.tb_writer.add_scalar("parameters/gamma", self.gamma, epoch_no)
            if beam_dqn > self.actions_size:
                print("The beam_dqn cannot exceed the action size!")
                print("then the beam_dqn = action size")
                beam_dqn = self.actions_size

            print(' beam_dqn, egreed, gamma: ', beam_dqn, egreed, self.gamma)
            for _, data_set in self.data_to_train_dqn.items():

                valid_iter = make_data_iter(dataset=data_set,
                                            batch_size=1,
                                            batch_type=self.batch_type,
                                            shuffle=False,
                                            train=False)
                #valid_sources_raw = data_set.src
                # disable dropout
                #self.model.eval()

                i_sample = 0
                for valid_batch in iter(valid_iter):
                    freeze_model(self.model)
                    batch = Batch(valid_batch,
                                  self.pad_index,
                                  use_cuda=self.use_cuda)

                    encoder_output, encoder_hidden = self.model.encode(
                        batch.src, batch.src_lengths, batch.src_mask)
                    # if maximum output length is not globally specified, adapt to src len

                    if self.max_output_length is None:
                        self.max_output_length = int(
                            max(batch.src_lengths.cpu().numpy()) * 1.5)

                    batch_size = batch.src_mask.size(0)
                    prev_y = batch.src_mask.new_full(size=[batch_size, 1],
                                                     fill_value=self.bos_index,
                                                     dtype=torch.long)
                    output = []
                    hidden = self.model.decoder._init_hidden(encoder_hidden)
                    prev_att_vector = None
                    finished = batch.src_mask.new_zeros((batch_size, 1)).byte()

                    # print("Source_raw: ", batch.src)
                    # print("Target_raw: ", batch.trg_input)
                    # print("y0: ", prev_y)

                    exp_list = []
                    # pylint: disable=unused-variable
                    for t in range(self.max_output_length):
                        if t != 0:
                            if self.state_type == 'hidden':
                                state = torch.cat(
                                    hidden,
                                    dim=2).squeeze(1).detach().cpu().numpy()[0]
                            else:
                                if t == 0:
                                    state = hidden[0].squeeze(
                                        1).detach().cpu().numpy()[0]
                                else:
                                    state = prev_att_vector.squeeze(
                                        1).detach().cpu().numpy()[0]

                        # decode one single step
                        logits, hidden, att_probs, prev_att_vector = self.model.decoder(
                            encoder_output=encoder_output,
                            encoder_hidden=encoder_hidden,
                            src_mask=batch.src_mask,
                            trg_embed=self.model.trg_embed(prev_y),
                            hidden=hidden,
                            prev_att_vector=prev_att_vector,
                            unroll_steps=1)
                        # logits: batch x time=1 x vocab (logits)
                        if t != 0:
                            if self.state_type == 'hidden':
                                state_ = torch.cat(
                                    hidden,
                                    dim=2).squeeze(1).detach().cpu().numpy()[0]
                            else:
                                state_ = prev_att_vector.squeeze(
                                    1).detach().cpu().numpy()[0]

                        # if t == 0:
                        #     print('states0: ', state, state_)

                        # e-greedy decoding: with probability egreed pick one of
                        # the top-beam_dqn tokens at random, otherwise take the
                        # argmax over the vocabulary

                        if random.uniform(0, 1) < egreed:
                            i_ran = random.randint(0, beam_dqn - 1)
                            next_word = torch.argsort(logits,
                                                      descending=True)[:, :,
                                                                       i_ran]
                        else:
                            next_word = torch.argmax(logits,
                                                     dim=-1)  # batch x time=1
                        # if t != 0:
                        a = prev_y.squeeze(1).detach().cpu().numpy()[0]
                        #a = next_word.squeeze(1).detach().cpu().numpy()[0]
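                        # the stored action is the token fed to the decoder at
                        # this step (prev_y), not the newly predicted token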

                        # print("state ",t," : ", state )
                        # print("state_ ",t," : ", state_ )
                        # print("action ",t," : ", a )
                        # print("__________________________________________")

                        output.append(
                            next_word.squeeze(1).detach().cpu().numpy())

                        #tup = (self.memory_counter, state, a, state_)

                        prev_y = next_word
                        # check if previous symbol was <eos>
                        is_eos = torch.eq(next_word, self.eos_index)
                        finished += is_eos
                        if t != 0:
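                            # store a transition for every step after the first
                            # (the current state is only defined for t > 0)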
                            self.memory_counter += 1
                            tup = (self.memory_counter, state, a, state_, 1)
                            exp_list.append(tup)

                        #print(t)
                        # stop predicting if <eos> reached for all elements in batch
                        if (finished >= 1).sum() == batch_size:
                            a = next_word.squeeze(1).detach().cpu().numpy()[0]
                            self.memory_counter += 1
                            #tup = (self.memory_counter, state_, a, np.zeros([self.state_size]) , is_eos[0,0])
                            tup = (self.memory_counter, state_, a,
                                   np.zeros([self.state_size]), 0)
                            exp_list.append(tup)
                            #print('break')
                            break
                        if t == self.max_output_length - 1:
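                            # truncated episode: the maximum output length was
                            # reached without <eos>; store action 0 and a
                            # sentinel next state of -1s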
                            #print("reach the max output")
                            a = 0
                            self.memory_counter += 1
                            #tup = (self.memory_counter, state_, a, np.zeros([self.state_size]) , is_eos[0,0])
                            tup = (self.memory_counter, state_, a,
                                   -1 * np.ones([self.state_size]), 1)
                            exp_list.append(tup)

                    # collect the reward for the generated hypothesis
                    hyp = np.stack(output, axis=1)  # batch, time

                    # show a few sample rewards in the first epoch only
                    show = epoch_no == 0 and i_sample in (0, 3, 6)
                    r = self.Reward(batch.trg, hyp, show=show)  # 1 x (time-1)

                    # if i_sample == 0 or i_sample == 3 or i_sample == 6:
                    #     print("\n Sample Collected: ", i_sample, "-------------Target vs Eval_net prediction:--Raw---and---Decoded-----")
                    #     print("Target: ", batch.trg, decoded_valid_out_trg)
                    #     print("Eval  : ", stacked_output, decoded_valid_out)
                    #     print("Reward: ", r, "\n")

                    i_sample += 1
                    self.store_transition(exp_list, r)

                    # learning step: update the DQN once the replay memory is
                    # close to capacity
                    if self.memory_counter > self.mem_cap - self.max_output_length:
                        self.learn()

        self.tb_writer.close()