Example #1
    def validate(self, valid_iter):
        """ Validate model.
            valid_iter: validate data iterator
        Returns:
            :obj:`nmt.Statistics`: validation loss statistics
        """

        self.model.eval()
        total_stats = Statistics()
        self.model.zero_grad()
        for batch in valid_iter:
            kld_loss = 0.
            normalization = batch.batch_size

            src, src_lengths = batch.src, batch.src_Ls
            trg, trg_lengths = batch.trg, batch.trg_Ls
            ref = batch.trg[1:]
            # Forward prop through the model; the variational models also
            # return a KL-divergence term.
            if isinstance(self.model, (VRNMTModel, VNMTModel)):
                outputs, _, _, kld_loss = self.model(src, src_lengths, trg)
            elif isinstance(self.model, NMTModel):
                outputs, _, _ = self.model(src, src_lengths, trg)

            probs = self.model.generator(outputs.view(-1, outputs.size(2)))
            loss, batch_stats = self.loss_func.compute_batch_loss(
                probs, ref, normalization, kld_loss=kld_loss)

            # Update statistics.
            total_stats.update(batch_stats)
            del outputs, probs, batch_stats, loss
        # Set the model back to training mode.
        self.model.train()
        return total_stats
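
A note on the loop above: the explicit zero_grad() and del calls keep memory in check, but the idiomatic way in modern PyTorch is to wrap validation in torch.no_grad() so no autograd graph is built at all. A minimal standalone sketch, assuming the same batch and loss interfaces as in the example:

import torch

def validate(model, generator, loss_func, valid_iter):
    model.eval()
    total_loss, num_batches = 0.0, 0
    with torch.no_grad():  # no graphs are built, so nothing to free manually
        for batch in valid_iter:
            src, src_lengths = batch.src, batch.src_Ls
            trg, ref = batch.trg, batch.trg[1:]
            outputs = model(src, src_lengths, trg)[0]
            probs = generator(outputs.view(-1, outputs.size(2)))
            loss, _ = loss_func.compute_batch_loss(
                probs, ref, batch.batch_size, kld_loss=0.)
            total_loss += float(loss)
            num_batches += 1
    model.train()  # restore training mode for the caller
    return total_loss / max(num_batches, 1)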
Example #2
    def train(self, current_epoch, epochs, train_data, valid_data,
              num_batches):
        """ Train next epoch.
        Args:
            train_data (BatchDataIterator): training dataset iterator
            valid_data (BatchDataIterator): validation dataset iterator
            epoch (int): the epoch number
            num_batches (int): the batch number
        Returns:
            stats (Statistics): epoch loss statistics
        """
        self.model.train()

        if self.stop:
            return
        header = '-' * 30 + "Epoch [%d]" + '-' * 30
        trace(header % current_epoch)
        train_stats = Statistics()
        num_batches = train_data.num_batches

        batch_cache = []
        for idx, batch in enumerate(iter(train_data), 1):
            batch_cache.append(batch)
            if len(batch_cache) == self.accum_grad_count or idx == num_batches:
                stats = self.train_each_batch(batch_cache, current_epoch, idx,
                                              num_batches)
                batch_cache = []
                train_stats.update(stats)
                if idx % self.report_every == 0 or idx == num_batches:
                    trace(
                        stats.report(current_epoch, idx, num_batches,
                                     self.optim.lr))
            if idx % (self.report_every * 10) == 0 and self.early_stop:
                valid_stats = self.validate(valid_data)
                trace("Validation: " + valid_stats.report(
                    current_epoch, idx, num_batches, self.optim.lr))
                if self.early_stop(valid_stats.ppl()):
                    self.stop = True
                    break
        valid_stats = self.validate(valid_data)
        trace(str(valid_stats))
        suffix = ".acc{0:.2f}.ppl{1:.2f}.e{2:d}".format(
            valid_stats.accuracy(), valid_stats.ppl(), current_epoch)
        self.optim.update_lr(valid_stats.ppl(), current_epoch)
        dump_checkpoint(self.model, self.save_model, suffix)
        return train_stats
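
The loop treats self.early_stop as a callable that receives the validation perplexity and returns whether training should stop. Its implementation is not shown here; a hypothetical patience-based version (class name and default are assumptions, not the repo's code) could look like:

class EarlyStopping:
    def __init__(self, patience=3):
        self.patience = patience      # allowed consecutive non-improvements
        self.best_ppl = float("inf")  # best validation perplexity seen so far
        self.bad_checks = 0

    def __call__(self, ppl):
        if ppl < self.best_ppl:
            self.best_ppl = ppl
            self.bad_checks = 0
        else:
            self.bad_checks += 1
        # True once perplexity has stalled `patience` checks in a row
        return self.bad_checks >= self.patience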
Example #3
    def train(self, train_iter, epoch, num_batches):
        """ Train next epoch.
        Args:
            train_iter (BatchDataIterator): training data iterator
            epoch (int): the epoch number
            num_batches (int): the batch number
        Returns:
            stats (Statistics): epoch loss statistics
        """
        self.model.train()

        total_stats = Statistics()
        self.loss_func.kld_weight_step(epoch,
                                       self.config.start_increase_kld_at)
        for idx, batch in enumerate(train_iter):
            self.model.zero_grad()
            src, src_lengths = batch.src, batch.src_Ls
            trg, trg_lengths = batch.trg, batch.trg_Ls
            ref = batch.trg[1:]
            kld_loss = 0.
            normalization = batch.batch_size
            # Forward prop through the model; the variational models also
            # return a KL-divergence term.
            if isinstance(self.model, (VRNMTModel, VNMTModel)):
                outputs, _, _, kld_loss = self.model(src, src_lengths, trg)
            elif isinstance(self.model, NMTModel):
                outputs, _, _ = self.model(src, src_lengths, trg)

            probs = self.model.generator(outputs.view(-1, outputs.size(2)))

            loss, batch_stats = self.loss_func.compute_batch_loss(
                probs, ref, normalization, kld_loss=kld_loss)

            loss.backward()
            # Update the parameters and statistics.
            self.optim.step()

            del loss, outputs, probs

            report_stats(batch_stats, epoch, idx + 1, num_batches,
                         self.progress_step, self.optim.lr)

            total_stats.update(batch_stats)
            self.progress_step += 1

        return total_stats
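
The call to kld_weight_step() suggests the KL term's weight is annealed over epochs, a common guard against posterior collapse in variational models such as VNMTModel and VRNMTModel. The actual schedule lives in the loss class; a hypothetical linear ramp (ramp_epochs is an assumed parameter) might be:

def kld_weight(epoch, start_increase_kld_at, ramp_epochs=10):
    """Linear KL annealing: weight 0 before the start epoch, then ramp to 1."""
    if epoch < start_increase_kld_at:
        return 0.0
    return min(1.0, (epoch - start_increase_kld_at) / float(ramp_epochs))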
Example #4
    def validate(self, valid_data):
        """ Validate the model.
        Args:
            valid_data (BatchDataIterator): validation data iterator
        Returns:
            :obj:`nmt.Statistics`: validation loss statistics
        """
        self.model.eval()
        valid_stats = Statistics()
        for batch in iter(valid_data):
            normalization = batch.batch_size
            src, src_lengths = batch.src, batch.src_Ls
            trg, ref = batch.trg[:-1], batch.trg[1:]
            outputs = self.model(src, src_lengths, trg)[0]
            probs = self.model.generator(outputs)
            loss, stats = self.valid_loss.compute(probs, ref, normalization)
            valid_stats.update(stats)
            del outputs, probs, stats, loss
        # Set the model back to training mode.
        self.model.train()
        return valid_stats
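
The trg[:-1] / trg[1:] split used above implements teacher forcing: the decoder consumes the gold target up to step t and is scored against the token at step t+1. A small illustration, assuming a <bos>/<eos> framing of the target sequence:

trg = ["<bos>", "A", "B", "<eos>"]
decoder_input = trg[:-1]  # ['<bos>', 'A', 'B'] is fed to the decoder
reference = trg[1:]       # ['A', 'B', '<eos>'] is what the output is scored against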
Example #5
    def create_stats(self, loss, probs, golds, loss_dict):
        """
        Args:
            loss (`FloatTensor`): the loss computed by the loss criterion.
            scores (`FloatTensor`): a score for each possible output
            target (`FloatTensor`): true targets

        Returns:
            `Statistics` : statistics for this batch.
        """
        # Greedy predictions: the index of the highest-scoring token.
        preds = probs.data.topk(1, dim=-1)[1]
        # Exclude padding positions from the word and accuracy counts.
        non_padding = golds.ne(self.padding_idx)
        correct = preds.squeeze().eq(golds).masked_select(non_padding)
        num_words = non_padding.long().sum()
        num_correct = correct.long().sum()
        return Statistics(float(loss), int(num_words), int(num_correct),
                          loss_dict)
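
The masking logic can be checked in isolation. A runnable toy example, assuming padding index 0: two of the three non-padding positions are predicted correctly, so num_words is 3 and num_correct is 2.

import torch

padding_idx = 0
probs = torch.tensor([[0.1, 0.7, 0.2],     # argmax -> 1 (correct)
                      [0.6, 0.3, 0.1],     # argmax -> 0 (wrong, gold is 2)
                      [0.2, 0.2, 0.6],     # argmax -> 2 (correct)
                      [0.9, 0.05, 0.05]])  # argmax -> 0 (padding, ignored)
golds = torch.tensor([1, 2, 2, 0])         # the last position is padding

preds = probs.topk(1, dim=-1)[1]
non_padding = golds.ne(padding_idx)
correct = preds.squeeze().eq(golds).masked_select(non_padding)
print(int(non_padding.long().sum()), int(correct.long().sum()))  # 3 2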
Example #6
    def train_each_batch(self, batch_cache, current_epoch, idx, num_batches):
        """ Run forward/backward over the cached batches, then take a single
            optimizer step (gradient accumulation).
        """
        self.model.zero_grad()
        batch_stats = Statistics()
        normalization = 0

        while batch_cache:
            kld = 0
            batch = batch_cache.pop(0)
            src, src_length = batch.src, batch.src_Ls
            trg, ref = batch.trg[:-1], batch.trg[1:]
            # Note: normalization grows with every cached batch, so earlier
            # sub-batches are divided by a smaller total than later ones.
            normalization += batch.batch_size
            args = self.model(src, src_length, trg)
            outputs = args[0]
            # kld = args[-1]
            probs = self.model.generator(outputs)
            loss, stats = self.train_loss.compute(probs, ref, normalization)
            loss.backward(retain_graph=True)
            batch_stats.update(stats)
            del probs, outputs, loss
        self.optim.step()

        batch_stats.report_and_flush(current_epoch, idx, num_batches,
                                     self.optim.lr)
        return batch_stats
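
Because normalization is accumulated inside the loop, the sub-batches are not scaled uniformly. A sketch of an alternative (it mirrors the method above but is not the original implementation) fixes the normalizer up front; it also drops retain_graph=True, which is unneeded since each iteration builds a fresh graph:

    def train_each_batch(self, batch_cache, current_epoch, idx, num_batches):
        self.model.zero_grad()
        batch_stats = Statistics()
        # Sum the sizes first so every sub-batch shares one normalizer.
        normalization = sum(b.batch_size for b in batch_cache)
        for batch in batch_cache:
            src, src_length = batch.src, batch.src_Ls
            trg, ref = batch.trg[:-1], batch.trg[1:]
            outputs = self.model(src, src_length, trg)[0]
            probs = self.model.generator(outputs)
            loss, stats = self.train_loss.compute(probs, ref, normalization)
            loss.backward()  # each forward pass builds its own graph
            batch_stats.update(stats)
        self.optim.step()
        batch_stats.report_and_flush(current_epoch, idx, num_batches,
                                     self.optim.lr)
        return batch_stats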