Example #1
0
    def validate(self, valid_iter, step=0):
        """ Validate model.

        Args:
            valid_iter: validate data iterator; each batch must expose
                ``src``, ``labels``, ``segs``, ``clss``, ``mask`` and
                ``mask_cls``.
            step (int): step number passed on to ``_report_step``.

        Returns:
            :obj:`nmt.Statistics`: validation loss statistics
        """
        # Remember the caller's mode so that validating in the middle of
        # training does not leave the model stuck in eval mode afterwards.
        was_training = getattr(self.model, "training", False)
        self.model.eval()
        stats = Statistics()

        try:
            with torch.no_grad():
                for batch in valid_iter:
                    src = batch.src
                    labels = batch.labels
                    segs = batch.segs
                    clss = batch.clss
                    mask = batch.mask
                    mask_cls = batch.mask_cls

                    sent_scores, mask = self.model(src, segs, clss, mask,
                                                   mask_cls)

                    loss = self.loss(sent_scores, labels.float())
                    # Weight per-sentence losses by the mask returned by the
                    # model (presumably 0 for padded sentence slots) and sum.
                    loss = (loss * mask.float()).sum()
                    stats.update(Statistics(loss.item(), len(labels)))
                self._report_step(0, step, valid_stats=stats)
        finally:
            if was_training:
                self.model.train()
        return stats
Example #2
0
    def report_training(self,
                        step,
                        num_steps,
                        learning_rate,
                        report_stats,
                        multigpu=False):
        """
        This is the user-defined batch-level training progress
        report function.

        Args:
            step(int): current step count.
            num_steps(int): total number of batches.
            learning_rate(float): current learning rate.
            report_stats(Statistics): old Statistics instance.
            multigpu(bool): if True, ``report_stats`` is combined across
                processes with ``Statistics.all_gather_stats`` before it
                is reported.
        Returns:
            report_stats(Statistics): a fresh Statistics instance when a
                report was emitted (``step`` is a multiple of
                ``self.report_every``), otherwise ``report_stats`` unchanged.

        Raises:
            ValueError: if the manager was not started (``start_time < 0``).
        """
        if self.start_time < 0:
            # Single-line message; the former triple-quoted literal embedded
            # a newline plus indentation and lacked the closing parenthesis.
            raise ValueError("ReportMgr needs to be started "
                             "(set 'start_time' or use 'start()')")

        if step % self.report_every != 0:
            return report_stats

        if multigpu:
            report_stats = Statistics.all_gather_stats(report_stats)
        self._report_training(step, num_steps, learning_rate, report_stats)
        self.progress_step += 1
        return Statistics()
Example #3
0
    def _gradient_accumulation(self, true_batchs, normalization, total_stats,
                               report_stats):
        """Run forward/backward over ``true_batchs`` and step the optimizer.

        With ``grad_accum_count == 1`` every batch gets its own
        zero_grad / backward / optimizer step. With ``grad_accum_count > 1``
        gradients of all batches are accumulated and one step is taken at
        the end.

        Args:
            true_batchs: batches collected for this update.
            normalization: total ``batch_size`` over ``true_batchs``. Kept for
                interface compatibility. Statistics are now counted with each
                batch's own ``batch_size``; pairing every batch's loss with
                the total inflated the count by ``len(true_batchs)``.
            total_stats (Statistics): running statistics, updated in place.
            report_stats (Statistics): reporting statistics, updated in place.
        """
        if self.grad_accum_count > 1:
            self.model.zero_grad()

        for batch in true_batchs:
            if self.grad_accum_count == 1:
                self.model.zero_grad()

            sent_scores, mask = self.model(batch.src, batch.segs, batch.clss,
                                           batch.mask, batch.mask_cls)

            loss = self.loss(sent_scores, batch.labels.float())
            # Mask-weighted sum gives a scalar loss for the batch.
            loss = (loss * mask.float()).sum()
            # The summed loss is a scalar, so dividing by ``loss.numel()``
            # would divide by 1; back-propagate it directly (unnormalized).
            loss.backward()

            # Pair this batch's loss with this batch's example count, the
            # same pairing ``validate`` uses (loss, len(labels)).
            batch_stats = Statistics(loss.item(), batch.batch_size)

            total_stats.update(batch_stats)
            report_stats.update(batch_stats)

            if self.grad_accum_count == 1:
                self.optim.step()

        # Multi-step accumulation: a single update after all batches.
        if self.grad_accum_count > 1:
            self.optim.step()
Example #4
0
    def _report_training(self, step, num_steps, learning_rate, report_stats):
        """
        See base class method `ReportMgrBase.report_training`.

        Emits ``report_stats`` through its ``output`` method, forwards it to
        ``maybe_log_tensorboard`` under the "progress" prefix with
        ``self.progress_step`` as the x value, then returns a new, empty
        Statistics object.
        """
        report_stats.output(step, num_steps, learning_rate, self.start_time)

        # Progress is plotted against the report counter, not the raw step.
        x_value = self.progress_step
        self.maybe_log_tensorboard(report_stats, "progress", learning_rate,
                                   x_value)

        return Statistics()
Example #5
0
    def train(self, train_iter_fct, train_steps):
        """Run the training loop until step ``train_steps`` is done.

        Batches are collected until ``self.grad_accum_count`` of them are
        available, then one update is performed via
        ``self._gradient_accumulation``. Progress is reported after every
        update, and a checkpoint is saved every ``self.save_checkpoint_steps``
        steps when ``self.gpu_rank == 0``.

        Args:
            train_iter_fct (callable): zero-argument function returning an
                iterable of training batches. It is called once up front and
                again each time the previous iterable is exhausted before
                ``train_steps`` is reached.
            train_steps (int): last optimizer step to run (inclusive).

        Returns:
            Statistics: statistics accumulated over the whole run.

        Raises:
            ValueError: if an iterable returned by ``train_iter_fct`` yields
                no batch, since the loop could then never make progress.
        """
        step = self.optim._step + 1
        true_batchs = []
        accum = 0
        normalization = 0
        train_iter = train_iter_fct()

        total_stats = Statistics()
        report_stats = Statistics()
        self._start_report_manager(start_time=total_stats.start_time)

        while step <= train_steps:
            # Consume the current iterable to exhaustion. Creating a fresh
            # one after every single batch would restart the data at its
            # first batch on each step.
            got_batch = False
            for batch in train_iter:
                got_batch = True
                true_batchs.append(batch)
                normalization += batch.batch_size
                accum += 1
                if accum != self.grad_accum_count:
                    continue

                self._gradient_accumulation(true_batchs, normalization,
                                            total_stats, report_stats)

                report_stats = self._report_training(step, train_steps,
                                                     self.optim.learning_rate,
                                                     report_stats)

                true_batchs = []
                accum = 0
                normalization = 0
                if step % self.save_checkpoint_steps == 0 and self.gpu_rank == 0:
                    self._save(step)

                step += 1
                if step > train_steps:
                    break

            if not got_batch:
                raise ValueError("train_iter_fct() returned an empty iterator")
            if step <= train_steps:
                # Data exhausted before the last step: start another pass.
                train_iter = train_iter_fct()

        return total_stats
Example #6
0
    def _maybe_gather_stats(self, stat):
        """
        Combine statistics across processes when running on several GPUs.

        Args:
            stat(:obj:onmt.utils.Statistics): statistics to combine, or None.

        Returns:
            The result of ``Statistics.all_gather_stats(stat)`` when ``stat``
            is not None and ``self.n_gpu > 1``; otherwise ``stat`` itself
            (so None is passed straight through).
        """
        if stat is None or self.n_gpu <= 1:
            return stat
        return Statistics.all_gather_stats(stat)
Example #7
0
    def _gradient_accumulation(self, true_batchs, normalization, total_stats,
                               report_stats):
        """Forward/backward over ``true_batchs`` and update the parameters.

        With ``grad_accum_count == 1`` every batch gets its own zero_grad,
        backward, gradient sync and optimizer step. With
        ``grad_accum_count > 1`` gradients of all batches are accumulated
        and synced/stepped once at the end. With ``self.n_gpu > 1`` the
        gradients of trainable parameters are all-reduced across processes
        before each optimizer step.

        Args:
            true_batchs: batches collected for this update.
            normalization: total example count for this update (possibly
                summed over processes by the caller). Kept for interface
                compatibility. Statistics are counted with each batch's own
                ``batch_size``; pairing every batch's loss with the total
                inflated the count by ``len(true_batchs)`` (and again by the
                number of processes once the stats are gathered).
            total_stats (Statistics): running statistics, updated in place.
            report_stats (Statistics): reporting statistics, updated in place.
        """
        def _sync_grads():
            # Multi-GPU: all-reduce the gradients of trainable parameters
            # that received a gradient (same call and factor as before).
            if self.n_gpu > 1:
                grads = [
                    p.grad.data for p in self.model.parameters()
                    if p.requires_grad and p.grad is not None
                ]
                distributed.all_reduce_and_rescale_tensors(grads, 1.0)

        if self.grad_accum_count > 1:
            self.model.zero_grad()

        for batch in true_batchs:
            if self.grad_accum_count == 1:
                self.model.zero_grad()

            sent_scores, mask = self.model(batch.src, batch.segs, batch.clss,
                                           batch.mask, batch.mask_cls)

            loss = self.loss(sent_scores, batch.labels.float())
            # Mask-weighted sum gives a scalar loss for the batch.
            loss = (loss * mask.float()).sum()
            # The summed loss is a scalar, so dividing by ``loss.numel()``
            # would divide by 1; back-propagate it directly (unnormalized).
            loss.backward()

            batch_stats = Statistics(loss.item(), batch.batch_size)

            total_stats.update(batch_stats)
            report_stats.update(batch_stats)

            # Per-batch update when not accumulating.
            if self.grad_accum_count == 1:
                _sync_grads()
                self.optim.step()

        # In case of multi-step gradient accumulation,
        # update only after all accumulated batches.
        if self.grad_accum_count > 1:
            _sync_grads()
            self.optim.step()
Example #8
0
    def test(self, test_iter, step, cal_lead=False, cal_oracle=False):
        """ Extract summaries for a test set and write them to disk.

        For every example, sentences are selected as follows:

        - ``cal_lead``: in document order (lead baseline).
        - ``cal_oracle``: those whose label equals 1.
        - otherwise: ranked by the model's sentence scores. The model loss is
          also accumulated into the returned statistics.

        At most 3 sentences are kept unless ``cal_oracle`` or
        ``self.args.recall_eval`` is set. With ``self.args.block_trigram``,
        a sentence sharing a trigram with an already selected one is
        skipped. The selected sentences are joined with "<q>". With
        ``recall_eval``, the result is truncated to the reference's
        whitespace-token count.

        Predictions are written one per line to
        ``os.path.join(result_path, "_step<step>.candidate")`` and references
        to ``os.path.join(result_path, "_step<step>.gold")``. Examples with an
        empty ``src_str`` are skipped in both files.

        Args:
            test_iter: iterable of test batches.
            step (int): step used in the file names and report; ROUGE is
                computed only when ``step != -1`` and
                ``self.args.report_rouge`` is set.
            cal_lead (bool): use the lead baseline instead of the model.
            cal_oracle (bool): use the oracle labels instead of the model.

        Returns:
            :obj:`nmt.Statistics`: test loss statistics (empty unless the
            model was used).
        """

        def _get_ngrams(n, text):
            # Set of all n-grams (as tuples) of the token list ``text``.
            return {tuple(text[i:i + n]) for i in range(len(text) - n + 1)}

        def _block_tri(c, p):
            # True if sentence ``c`` shares a trigram with any sentence in ``p``.
            tri_c = _get_ngrams(3, c.split())
            return any(tri_c & _get_ngrams(3, s.split()) for s in p)

        use_model = not cal_lead and not cal_oracle
        was_training = False
        if use_model:
            # Set model in evaluation mode; the previous mode is restored below.
            was_training = getattr(self.model, "training", False)
            self.model.eval()
        stats = Statistics()

        can_path = os.path.join(self.args.result_path,
                                "_step{}.candidate".format(step))
        gold_path = os.path.join(self.args.result_path,
                                 "_step{}.gold".format(step))

        try:
            # Explicit UTF-8 so non-ASCII text does not depend on the
            # platform's default encoding.
            with open(can_path, "w", encoding="utf-8") as save_pred, \
                    open(gold_path, "w", encoding="utf-8") as save_gold:
                with torch.no_grad():
                    with tqdm(test_iter) as pbar:
                        for batch in pbar:
                            labels = batch.labels
                            n_sents = batch.clss.size(1)

                            if cal_lead:
                                selected_ids = [
                                    list(range(n_sents))
                                ] * batch.batch_size
                            elif cal_oracle:
                                selected_ids = [[
                                    j for j in range(n_sents)
                                    if labels[i][j] == 1
                                ] for i in range(batch.batch_size)]
                            else:
                                sent_scores, mask = self.model(
                                    batch.src, batch.segs, batch.clss,
                                    batch.mask, batch.mask_cls)

                                loss = self.loss(sent_scores, labels.float())
                                loss = (loss * mask.float()).sum()
                                stats.update(
                                    Statistics(loss.item(), len(labels)))

                                # Adding the mask presumably lifts real
                                # sentences above padded slots in the ranking
                                # (assumes scores lie in [0, 1]; confirm).
                                sent_scores = sent_scores + mask.float()
                                sent_scores = sent_scores.cpu().data.numpy()
                                selected_ids = np.argsort(-sent_scores, 1)

                            gold, pred = [], []
                            for i, ids in enumerate(selected_ids):
                                sents = batch.src_str[i]
                                if len(sents) == 0:
                                    continue

                                chosen = []
                                for j in ids[:len(sents)]:
                                    # Ranked ids may point at padding slots.
                                    if j >= len(sents):
                                        continue
                                    candidate = sents[j].strip()
                                    if not (self.args.block_trigram
                                            and _block_tri(candidate, chosen)):
                                        chosen.append(candidate)

                                    if (not cal_oracle
                                            and not self.args.recall_eval
                                            and len(chosen) == 3):
                                        break

                                summary = "<q>".join(chosen)
                                if self.args.recall_eval:
                                    summary = " ".join(
                                        summary.split()
                                        [:len(batch.tgt_str[i].split())])

                                pred.append(summary)
                                gold.append(batch.tgt_str[i])

                            for reference in gold:
                                save_gold.write(reference.strip() + "\n")
                            for prediction in pred:
                                save_pred.write(prediction.strip() + "\n")
        finally:
            if use_model and was_training:
                self.model.train()

        if step != -1 and self.args.report_rouge:
            rouges = test_rouge(self.args.temp_dir, can_path, gold_path)
            logger.info("Rouges at step %d \n%s", step,
                        rouge_results_to_str(rouges))

        self._report_step(0, step, valid_stats=stats)

        return stats
Example #9
0
    def train(self,
              train_iter_fct,
              train_steps,
              valid_iter_fct=None,
              valid_steps=-1):
        """
        The main training loop.

        Iterates over training data (obtained from ``train_iter_fct``) and
        performs one optimizer update per ``self.grad_accum_count`` batches
        assigned to this process. Progress is reported after each update,
        and a checkpoint is saved every ``self.save_checkpoint_steps`` steps
        on ``gpu_rank`` 0.

        Args:
            train_iter_fct(function): a function that returns the train
                iterator. e.g. something like
                train_iter_fct = lambda: generator(*args, **kwargs)
                It is called again whenever the iterator is exhausted before
                ``train_steps`` is reached.
            train_steps(int): last optimizer step to run (inclusive).
            valid_iter_fct(function): same as train_iter_fct, for valid data.
                Not used by this method.
            valid_steps(int): not used by this method.

        Returns:
            Statistics: statistics accumulated over all training steps.

        Raises:
            ValueError: if a pass over the data assigns no batch to this
                process; the loop could otherwise never reach ``train_steps``.
        """
        logger.info("| Start training ...")

        step = self.optim._step + 1
        true_batchs = []
        accum = 0
        normalization = 0
        train_iter = train_iter_fct()

        total_stats = Statistics()
        report_stats = Statistics()
        self._start_report_manager(start_time=total_stats.start_time)

        while step <= train_steps:
            got_batch = False
            with tqdm(train_iter) as pbar:
                for i, batch in enumerate(pbar):
                    # Batches are dealt out by index: this process keeps
                    # those with i % n_gpu == gpu_rank (all when n_gpu == 0).
                    if self.n_gpu != 0 and i % self.n_gpu != self.gpu_rank:
                        continue

                    got_batch = True
                    true_batchs.append(batch)
                    normalization += batch.batch_size
                    accum += 1
                    if accum != self.grad_accum_count:
                        continue

                    if self.n_gpu > 1:
                        # Total example count across all processes.
                        normalization = sum(
                            distributed.all_gather_list(normalization))
                    self._gradient_accumulation(true_batchs, normalization,
                                                total_stats, report_stats)

                    report_stats = self._maybe_report_training(
                        step, train_steps, self.optim.learning_rate,
                        report_stats)

                    true_batchs = []
                    accum = 0
                    normalization = 0
                    if (step % self.save_checkpoint_steps == 0
                            and self.gpu_rank == 0):
                        self._save(step)

                    step += 1
                    if step > train_steps:
                        break

            if not got_batch:
                raise ValueError(
                    "train_iter_fct() produced no batch for this process")
            if step <= train_steps:
                # Data exhausted before the last step: start another pass.
                train_iter = train_iter_fct()

        return total_stats