Example 1
 def set_num_updates(self, num_updates):
     """Set the number of parameters updates."""
     self._num_updates = num_updates
     self.lr_step_update()
     if self.quantizer:
         self.quantizer.step_update(self._num_updates)
     metrics.log_scalar("num_updates", self._num_updates, weight=0, priority=200)
Example 2
    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
        ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
        nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
        sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)

        metrics.log_scalar('loss', loss_sum / sample_size, sample_size, round=3)

        if 'ce_loss' in logging_outputs[0]:
            ce_loss = sum(log['ce_loss'] for log in logging_outputs) / ntokens
            metrics.log_scalar('ce_loss', ce_loss, ntokens, round=3)
        if 'qua_loss' in logging_outputs[0]:
            qua_loss = sum(log['qua_loss'] for log in logging_outputs) / nsentences
            metrics.log_scalar('qua_loss', qua_loss, nsentences, round=3)

        c_errors = sum(log.get("c_errors", 0) for log in logging_outputs)
        metrics.log_scalar("_c_errors", c_errors)
        c_total = sum(log.get("c_total", 0) for log in logging_outputs)
        metrics.log_scalar("_c_total", c_total)

        if c_total > 0:
            metrics.log_derived(
                "uer",
                lambda meters: safe_round(meters["_c_errors"].sum * 100.0 / meters["_c_total"].sum, 3)
                if meters["_c_total"].sum > 0
                else float("nan"),
            )
Example 3
 @classmethod
 def reduce_metrics(cls, logging_outputs: List[Dict[str, Any]]) -> None:
     """Aggregate logging outputs from data parallel training."""
     utils.deprecation_warning(
         'Criterions should implement the reduce_metrics API. '
         'Falling back to deprecated aggregate_logging_outputs API.')
     agg_logging_outputs = cls.aggregate_logging_outputs(logging_outputs)
     for k, v in agg_logging_outputs.items():
         if k in {'nsentences', 'ntokens', 'sample_size'}:
             continue
         metrics.log_scalar(k, v)
Example 4
    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
        sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)

        metrics.log_scalar('loss',
                           loss_sum / sample_size / math.log(2),
                           sample_size,
                           round=3)
        metrics.log_derived(
            'ppl', lambda meters: utils.get_perplexity(meters['loss'].avg))
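
The division by math.log(2) in these snippets converts the natural-log loss into bits per token, which is why perplexity can later be recovered from the averaged meter. A minimal, hypothetical stand-in for utils.get_perplexity under that assumption (a sketch, not fairseq's actual implementation):

    # Hypothetical stand-in: the logged loss is assumed to be in bits
    # (natural-log loss divided by math.log(2)), so perplexity = 2 ** loss.
    def get_perplexity(loss_in_bits, ndigits=2):
        if loss_in_bits is None:
            return 0.0
        try:
            return round(2 ** loss_in_bits, ndigits)
        except OverflowError:
            return float("inf")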
Example 5
    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = utils.item(sum(log.get('loss', 0) for log in logging_outputs))
        ntokens = utils.item(sum(log.get('ntokens', 0) for log in logging_outputs))
        sample_size = utils.item(sum(log.get('sample_size', 0) for log in logging_outputs))

        metrics.log_scalar('loss', loss_sum / sample_size / math.log(2), sample_size, round=3)
        if sample_size != ntokens:
            metrics.log_scalar('nll_loss', loss_sum / ntokens / math.log(2), ntokens, round=3)
            metrics.log_derived('ppl', lambda meters: utils.get_perplexity(meters['nll_loss'].avg))
        else:
            metrics.log_derived('ppl', lambda meters: utils.get_perplexity(meters['loss'].avg))
Example 6
    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        lm_loss_sum = sum(log.get('lm_loss', 0) for log in logging_outputs)
        sentence_loss_sum = sum(
            log.get('sentence_loss', 0) for log in logging_outputs)
        ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
        nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
        sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
        agg_loss = sum(log.get('loss', 0) for log in logging_outputs)

        metrics.log_scalar('loss',
                           agg_loss / sample_size /
                           math.log(2) if sample_size > 0 else 0.,
                           sample_size,
                           round=3)
        metrics.log_scalar('lm_loss',
                           lm_loss_sum / ntokens /
                           math.log(2) if ntokens > 0 else 0.,
                           ntokens,
                           round=3)
        metrics.log_scalar('sentence_loss',
                           sentence_loss_sum / nsentences /
                           math.log(2) if nsentences > 0 else 0.,
                           nsentences,
                           round=3)
        metrics.log_scalar('nll_loss',
                           lm_loss_sum / ntokens /
                           math.log(2) if ntokens > 0 else 0.,
                           ntokens,
                           round=3)
Example 7
    def reduce_metrics(self, logging_outputs, criterion):
        """Aggregate logging outputs from data parallel training."""
        # backward compatibility for tasks that override aggregate_logging_outputs
        base_func = FairseqTask.aggregate_logging_outputs
        self_func = getattr(self, "aggregate_logging_outputs").__func__
        if self_func is not base_func:
            utils.deprecation_warning(
                "Tasks should implement the reduce_metrics API. "
                "Falling back to deprecated aggregate_logging_outputs API.")
            agg_logging_outputs = self.aggregate_logging_outputs(
                logging_outputs, criterion)
            for k, v in agg_logging_outputs.items():
                metrics.log_scalar(k, v)
            return

        if not any("ntokens" in log for log in logging_outputs):
            warnings.warn(
                "ntokens not found in Criterion logging outputs, cannot log wpb or wps"
            )
        else:
            ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
            metrics.log_scalar("wpb", ntokens, priority=180, round=1)
            metrics.log_speed("wps", ntokens, priority=90, round=1)

        if not any("nsentences" in log for log in logging_outputs):
            warnings.warn(
                "nsentences not found in Criterion logging outputs, cannot log bsz"
            )
        else:
            nsentences = sum(
                log.get("nsentences", 0) for log in logging_outputs)
            metrics.log_scalar("bsz", nsentences, priority=190, round=1)

        criterion.__class__.reduce_metrics(logging_outputs)
Example 8
    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        sample_size = utils.item(
            sum(log.get("sample_size", 0) for log in logging_outputs))
        loss = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
        nll_loss = utils.item(
            sum(log.get("nll_loss", 0) for log in logging_outputs))

        metrics.log_scalar('loss',
                           loss / sample_size / math.log(2),
                           sample_size,
                           round=3)
        metrics.log_scalar('nll_loss',
                           nll_loss / sample_size / math.log(2),
                           sample_size,
                           round=3)
        metrics.log_derived(
            'ppl', lambda meters: utils.get_perplexity(meters['loss'].avg))

        for key in logging_outputs[0]:
            if key[-5:] == "-loss":
                val = sum(log.get(key, 0) for log in logging_outputs)
                metrics.log_scalar(
                    key[:-5],
                    val / sample_size /
                    math.log(2) if sample_size > 0 else 0.0,
                    sample_size,
                    round=3,
                )
Example 9
    def _reduce_and_log_stats(self,
                              logging_outputs,
                              sample_size,
                              grad_norm=None):
        if grad_norm is not None:
            metrics.log_speed("ups", 1., priority=100, round=2)
            metrics.log_scalar("gnorm", grad_norm, priority=400, round=3)
            if self.args.clip_norm > 0:
                metrics.log_scalar(
                    "clip",
                    torch.where(
                        grad_norm > self.args.clip_norm,
                        grad_norm.new_tensor(100),
                        grad_norm.new_tensor(0),
                    ),
                    priority=500,
                    round=1,
                )

        with metrics.aggregate() as agg:
            if logging_outputs is not None:
                self.task.reduce_metrics(logging_outputs, self.get_criterion())
                del logging_outputs

            # extra warning for criterions that don't properly log a loss value
            if "loss" not in agg:
                if "loss" not in self._warn_once:
                    self._warn_once.add("loss")
                    logger.warning(
                        "Criterion.reduce_metrics did not log a 'loss' value, "
                        "which may break some functionality")
                metrics.log_scalar("loss", -1)

            # support legacy interface
            logging_output = agg.get_smoothed_values()
            logging_output["sample_size"] = sample_size
            for key_to_delete in ["ppl", "wps", "wpb", "bsz"]:
                if key_to_delete in logging_output:
                    del logging_output[key_to_delete]
            return logging_output
Example 10
    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
        ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
        nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
        sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)

        metrics.log_scalar('loss',
                           loss_sum / sample_size / math.log(2),
                           sample_size,
                           round=3)
        if sample_size != ntokens:
            metrics.log_scalar('nll_loss',
                               loss_sum / ntokens / math.log(2),
                               ntokens,
                               round=3)

        if len(logging_outputs) > 0 and 'ncorrect' in logging_outputs[0]:
            ncorrect = sum(log.get('ncorrect', 0) for log in logging_outputs)
            metrics.log_scalar('accuracy',
                               100.0 * ncorrect / nsentences,
                               nsentences,
                               round=1)
Example 11
    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""

        loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
        ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
        nsentences = utils.item(
            sum(log.get("nsentences", 0) for log in logging_outputs)
        )
        sample_size = utils.item(
            sum(log.get("sample_size", 0) for log in logging_outputs)
        )

        metrics.log_scalar(
            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )
        metrics.log_scalar("ntokens", ntokens)
        metrics.log_scalar("nsentences", nsentences)
        if sample_size != ntokens:
            metrics.log_scalar(
                "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
            )

        c_errors = sum(log.get("c_errors", 0) for log in logging_outputs)
        metrics.log_scalar("_c_errors", c_errors)
        c_total = sum(log.get("c_total", 0) for log in logging_outputs)
        metrics.log_scalar("_c_total", c_total)
        w_errors = sum(log.get("w_errors", 0) for log in logging_outputs)
        metrics.log_scalar("_w_errors", w_errors)
        wv_errors = sum(log.get("wv_errors", 0) for log in logging_outputs)
        metrics.log_scalar("_wv_errors", wv_errors)
        w_total = sum(log.get("w_total", 0) for log in logging_outputs)
        metrics.log_scalar("_w_total", w_total)

        if c_total > 0:
            metrics.log_derived(
                "uer",
                lambda meters: safe_round(meters["_c_errors"].sum * 100.0 / meters["_c_total"].sum, 3)
                if meters["_c_total"].sum > 0
                else float("nan"),
            )
        if w_total > 0:
            metrics.log_derived(
                "wer",
                lambda meters: safe_round(meters["_w_errors"].sum * 100.0 / meters["_w_total"].sum, 3)
                if meters["_w_total"].sum > 0
                else float("nan"),
            )
            metrics.log_derived(
                "raw_wer",
                lambda meters: safe_round(meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 3)
                if meters["_w_total"].sum > 0
                else float("nan"),
            )
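
Examples 2 and 11 follow the same pattern: raw error and total counts are summed across workers and logged under underscore-prefixed keys (used here only as internal counters backing the derived metrics), and metrics.log_derived registers a callback that turns those counters into a displayed percentage when the meters are read. A minimal sketch of that pattern with hypothetical keys (errors/total); the metrics import path is an assumption and varies across fairseq versions:

    # Sketch only, assuming fairseq's metrics module; key names are hypothetical.
    from fairseq.logging import metrics  # older releases: from fairseq import metrics

    def log_error_rate(logging_outputs):
        n_err = sum(log.get("errors", 0) for log in logging_outputs)
        n_tot = sum(log.get("total", 0) for log in logging_outputs)
        metrics.log_scalar("_errors", n_err)  # raw counter backing the derived metric
        metrics.log_scalar("_total", n_tot)
        metrics.log_derived(
            "error_rate",
            lambda meters: meters["_errors"].sum * 100.0 / meters["_total"].sum
            if meters["_total"].sum > 0
            else float("nan"),
        )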
Example 12
 def lr_step_update(self):
     """Update the learning rate after each update."""
     new_lr = self.lr_scheduler.step_update(self.get_num_updates())
     metrics.log_scalar("lr", new_lr, weight=0, priority=300)
     return new_lr
Example 13
    def train_step(self, samples, raise_oom=False):
        """Do forward, backward and parameter update."""
        if self._dummy_batch == "DUMMY":
            self._dummy_batch = samples[0]

        self._set_seed()
        self.model.train()
        self.criterion.train()
        self.zero_grad()

        metrics.log_start_time("train_wall", priority=800, round=0)

        # forward and backward pass
        logging_outputs, sample_size, ooms = [], 0, 0
        for i, sample in enumerate(samples):
            sample = self._prepare_sample(sample)
            if sample is None:
                # when sample is None, run forward/backward on a dummy batch
                # and ignore the resulting gradients
                sample = self._prepare_sample(self._dummy_batch)
                is_dummy_batch = True
            else:
                is_dummy_batch = False

            def maybe_no_sync():
                """
                Whenever *samples* contains more than one mini-batch, we
                want to accumulate gradients locally and only call
                all-reduce in the last backwards pass.
                """
                if (self.data_parallel_world_size > 1
                        and hasattr(self.model, "no_sync")
                        and i < len(samples) - 1):
                    return self.model.no_sync()
                else:
                    return contextlib.ExitStack()  # dummy contextmanager

            try:
                with maybe_no_sync():
                    # forward and backward
                    loss, sample_size_i, logging_output = self.task.train_step(
                        sample=sample,
                        model=self.model,
                        criterion=self.criterion,
                        optimizer=self.optimizer,
                        update_num=self.get_num_updates(),
                        ignore_grad=is_dummy_batch,
                    )
                    del loss

                logging_outputs.append(logging_output)
                sample_size += sample_size_i

                # emptying the CUDA cache after the first step can
                # reduce the chance of OOM
                if self.cuda and self.get_num_updates() == 0:
                    torch.cuda.empty_cache()
            except RuntimeError as e:
                if "out of memory" in str(e):
                    self._log_oom(e)
                    if raise_oom:
                        raise e
                    logger.warning(
                        "attempting to recover from OOM in forward/backward pass"
                    )
                    ooms += 1
                    self.zero_grad()
                    if self.cuda:
                        torch.cuda.empty_cache()
                    if self.args.distributed_world_size == 1:
                        return None
                else:
                    raise e

        if is_dummy_batch:
            if torch.is_tensor(sample_size):
                sample_size.zero_()
            else:
                sample_size *= 0.

        if torch.is_tensor(sample_size):
            sample_size = sample_size.float()
        else:
            sample_size = float(sample_size)

        # gather logging outputs from all replicas
        if self._sync_stats():
            train_time = self._local_cumulative_training_time()
            logging_outputs, (
                sample_size, ooms,
                total_train_time) = self._aggregate_logging_outputs(
                    logging_outputs,
                    sample_size,
                    ooms,
                    train_time,
                    ignore=is_dummy_batch,
                )
            self._cumulative_training_time = total_train_time / self.data_parallel_world_size

        if hasattr(self.model, 'all_reduce'):
            self.model.all_reduce()

        overflow = False
        try:
            with torch.autograd.profiler.record_function("multiply-grads"):
                # multiply gradients by (# GPUs / sample_size) since DDP
                # already normalizes by the number of GPUs. Thus we get
                # (sum_of_gradients / sample_size).
                if not self.args.use_bmuf:
                    self.optimizer.multiply_grads(
                        self.data_parallel_world_size / sample_size)
                elif sample_size > 0:  # BMUF needs to check sample size
                    num = self.data_parallel_world_size if self._sync_stats() else 1
                    self.optimizer.multiply_grads(num / sample_size)

            with torch.autograd.profiler.record_function("clip-grads"):
                # clip grads
                grad_norm = self.clip_grad_norm(self.args.clip_norm)

            # check that grad norms are consistent across workers
            if (not self.args.use_bmuf
                    and self.args.distributed_wrapper != 'SlowMo'):
                self._check_grad_norms(grad_norm)

            with torch.autograd.profiler.record_function("optimizer"):
                # take an optimization step
                self.optimizer.step()
        except FloatingPointError:
            # re-run the forward and backward pass with hooks attached to print
            # out where it fails
            with NanDetector(self.model):
                self.task.train_step(sample,
                                     self.model,
                                     self.criterion,
                                     self.optimizer,
                                     self.get_num_updates(),
                                     ignore_grad=False)
            raise
        except OverflowError as e:
            overflow = True
            logger.info("NOTE: overflow detected, " + str(e))
            grad_norm = torch.tensor(0.).cuda()
            self.zero_grad()
        except RuntimeError as e:
            if "out of memory" in str(e):
                self._log_oom(e)
                logger.error("OOM during optimization, irrecoverable")
            raise e

        # Some distributed wrappers (e.g., SlowMo) need access to the optimizer after the step
        if hasattr(self.model, 'perform_additional_optimizer_actions'):
            if hasattr(self.optimizer, 'fp32_params'):
                self.model.perform_additional_optimizer_actions(
                    self.optimizer.optimizer, self.optimizer.fp32_params)
            else:
                self.model.perform_additional_optimizer_actions(
                    self.optimizer.optimizer)

        if not overflow or self.args.distributed_wrapper == 'SlowMo':
            self.set_num_updates(self.get_num_updates() + 1)

            # log stats
            logging_output = self._reduce_and_log_stats(
                logging_outputs,
                sample_size,
                grad_norm,
            )

            # clear CUDA cache to reduce memory fragmentation
            if (self.cuda and self.args.empty_cache_freq > 0 and
                ((self.get_num_updates() + self.args.empty_cache_freq - 1) %
                 self.args.empty_cache_freq) == 0):
                torch.cuda.empty_cache()

        if self.args.fp16:
            metrics.log_scalar(
                "loss_scale",
                self.optimizer.scaler.loss_scale,
                priority=700,
                round=4,
                weight=0,
            )

        metrics.log_stop_time("train_wall")

        return logging_output
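
The maybe_no_sync helper in Example 13 is the standard DistributedDataParallel gradient-accumulation idiom: skip gradient synchronization on every backward pass except the last one in the group. A minimal standalone sketch of that idiom (model, batches, and loss_fn are hypothetical placeholders; it assumes a DDP-wrapped model exposing no_sync):

    import contextlib

    def accumulate_and_step(model, batches, loss_fn, optimizer):
        optimizer.zero_grad()
        for i, batch in enumerate(batches):
            # Suppress the all-reduce on all but the last backward so
            # gradients are synchronized only once per optimizer step.
            ctx = (model.no_sync()
                   if hasattr(model, "no_sync") and i < len(batches) - 1
                   else contextlib.ExitStack())
            with ctx:
                loss_fn(model, batch).backward()
        optimizer.step()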
Example 14
    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = utils.item(sum(log.get('loss', 0) for log in logging_outputs))
        ntokens = utils.item(sum(log.get('ntokens', 0) for log in logging_outputs))
        nsentences = utils.item(sum(log.get('nsentences', 0) for log in logging_outputs))
        sample_size = utils.item(sum(log.get('sample_size', 0) for log in logging_outputs))

        metrics.log_scalar('loss', loss_sum / sample_size / math.log(2), sample_size, round=3)
        metrics.log_scalar('ntokens', ntokens)
        metrics.log_scalar('nsentences', nsentences)

        correct = sum(log.get("correct", 0) for log in logging_outputs)
        metrics.log_scalar("_correct", correct)

        total = sum(log.get("count", 0) for log in logging_outputs)
        metrics.log_scalar("_total", total)


        if total > 0:
            metrics.log_derived(
                "accuracy",
                lambda meters: safe_round(meters["_correct"].sum / meters["_total"].sum, 5)
                if meters["_total"].sum > 0
                else float("nan"),
            )

        builtin_keys = {'loss', 'ntokens', 'nsentences', 'sample_size', 'correct', 'count'}

        for k in logging_outputs[0]:
            if k not in builtin_keys:
                val = sum(log.get(k, 0) for log in logging_outputs) / len(logging_outputs)
                if k.startswith('loss'):
                    metrics.log_scalar(k, val / sample_size / math.log(2), sample_size)
                else:
                    metrics.log_scalar(k, val, round=3)