Ejemplo n.º 1
0
    def _reduce_and_log_stats(self,
                              logging_outputs,
                              sample_size,
                              grad_norm=None):
        """Log per-update stats and return the aggregated logging output.

        When *grad_norm* is given, the "ups" speed meter and the "gnorm"
        scalar are logged, and, if ``self.args.clip_norm`` is positive, a
        "clip" scalar that is 100 where the norm exceeds the threshold and 0
        otherwise. *grad_norm* is used through ``new_tensor``, so it is
        presumably a torch tensor -- confirm against the caller.

        The task's ``reduce_metrics`` is then run inside a fresh
        ``metrics.aggregate()`` context (skipped when *logging_outputs* is
        None).

        Returns:
            The smoothed values of that aggregation, with "sample_size" set
            to *sample_size* and the "ppl", "wps", "wpb" and "bsz" entries
            removed.
        """
        if grad_norm is not None:
            metrics.log_speed("ups", 1., priority=100, round=2)
            metrics.log_scalar("gnorm", grad_norm, priority=400, round=3)
            if self.args.clip_norm > 0:
                exceeded = grad_norm > self.args.clip_norm
                clip_value = torch.where(
                    exceeded,
                    grad_norm.new_tensor(100),
                    grad_norm.new_tensor(0),
                )
                metrics.log_scalar("clip", clip_value, priority=500, round=1)

        with metrics.aggregate() as agg:
            if logging_outputs is not None:
                criterion = self.get_criterion()
                self.task.reduce_metrics(logging_outputs, criterion)

            # support legacy interface
            stats = agg.get_smoothed_values()
            stats["sample_size"] = sample_size
            for legacy_key in ("ppl", "wps", "wpb", "bsz"):
                if legacy_key not in stats:
                    continue
                del stats[legacy_key]
            return stats
Ejemplo n.º 2
0
    def _reduce_and_log_stats(self,
                              logging_outputs,
                              sample_size,
                              grad_norm=None):
        """Log per-update stats and return the aggregated logging output.

        Logs, in order:

        * "kl_loss", "kld" and "bow_loss", read from the FIRST entry of
          *logging_outputs* only and rounded to 3 decimals. Each value is
          read through ``.item()``, so the entries are presumably scalar
          tensors -- confirm against the criterion.
        * "ups", "gnorm" and (when ``cfg.optimization.clip_norm`` > 0)
          "clip", provided *grad_norm* is not None and is either a
          non-tensor or a finite tensor.

        The task's ``reduce_metrics`` is then run inside a fresh
        ``metrics.aggregate()`` context. If that did not produce a "loss"
        entry, a warning is emitted once and a placeholder loss of -1 is
        logged.

        Args:
            logging_outputs: sequence of per-batch logging dicts, or None.
            sample_size: value stored under "sample_size" in the result.
            grad_norm: gradient norm as a tensor or plain number, or None.

        Returns:
            An empty dict when ``self.tpu`` is truthy; otherwise the smoothed
            values of the aggregation with "sample_size" added and the
            "ppl", "wps", "wpb" and "bsz" entries removed.
        """
        # The code below treats ``logging_outputs is None`` as a valid input,
        # so do not index into it unconditionally: None or an empty sequence
        # simply means there are no auxiliary losses to report.
        if logging_outputs:
            first_output = logging_outputs[0]
            for aux_key in ("kl_loss", "kld", "bow_loss"):
                metrics.log_scalar(aux_key,
                                   round(first_output[aux_key].item(), 3))

        if grad_norm is not None and (not torch.is_tensor(grad_norm)
                                      or torch.isfinite(grad_norm)):
            metrics.log_speed("ups", 1.0, priority=100, round=2)
            metrics.log_scalar("gnorm", grad_norm, priority=400, round=3)
            clip_norm = self.cfg.optimization.clip_norm
            if clip_norm > 0:
                if torch.is_tensor(grad_norm):
                    clip_value = torch.where(
                        grad_norm > clip_norm,
                        grad_norm.new_tensor(100),
                        grad_norm.new_tensor(0),
                    )
                else:
                    # The guard above admits plain numbers, which have no
                    # ``new_tensor``; compute the same 100/0 flag directly.
                    clip_value = 100 if grad_norm > clip_norm else 0
                metrics.log_scalar(
                    "clip",
                    clip_value,
                    priority=500,
                    round=1,
                )

        with metrics.aggregate() as agg:

            if logging_outputs is not None:
                self.task.reduce_metrics(logging_outputs, self.get_criterion())
                # drop the local reference once the outputs have been reduced
                del logging_outputs

            # extra warning for criterions that don't properly log a loss value
            if "loss" not in agg:
                if "loss" not in self._warn_once:
                    self._warn_once.add("loss")
                    logger.warning(
                        "Criterion.reduce_metrics did not log a 'loss' value, "
                        "which may break some functionality")
                metrics.log_scalar("loss", -1)
            # support legacy interface
            if self.tpu:
                logging_output = {}
            else:
                logging_output = agg.get_smoothed_values()
                logging_output["sample_size"] = sample_size
                for key_to_delete in ["ppl", "wps", "wpb", "bsz"]:
                    if key_to_delete in logging_output:
                        del logging_output[key_to_delete]
            return logging_output
Ejemplo n.º 3
0
    def train_step(self, samples, raise_oom=False):
        """Do forward, backward and parameter update.

        Args:
            samples: sequence of mini-batches; one forward/backward pass is
                run per entry and the gradients are accumulated before a
                single optimizer step.
            raise_oom: when True, an out-of-memory RuntimeError raised during
                forward/backward is re-raised instead of being recovered
                from.

        Returns:
            The logging output produced by ``_reduce_and_log_stats``, or None
            when the update was skipped (every worker ran out of memory on
            every sample, or the optimizer reported an overflow).

        Raises:
            FloatingPointError: re-raised after the last sample has been
                re-run under ``NanDetector``.
            RuntimeError: any non-OOM error from forward/backward, and every
                RuntimeError raised during the optimization phase.
        """
        if self._dummy_batch == "DUMMY":
            self._dummy_batch = samples[0]

        self._set_seed()
        self.model.train()
        self.criterion.train()
        self.zero_grad()

        metrics.log_start_time("train_wall", priority=800, round=0)

        # forward and backward pass
        logging_outputs, sample_size, ooms = [], 0, 0
        for i, sample in enumerate(samples):
            sample = self._prepare_sample(sample)
            if sample is None:
                # when sample is None, run forward/backward on a dummy batch
                # and ignore the resulting gradients
                sample = self._prepare_sample(self._dummy_batch)
                is_dummy_batch = True
            else:
                is_dummy_batch = False

            def maybe_no_sync():
                """
                Whenever *samples* contains more than one mini-batch, we
                want to accumulate gradients locally and only call
                all-reduce in the last backwards pass.
                """
                if (self.args.distributed_world_size > 1
                        and hasattr(self.model, "no_sync")
                        and i < len(samples) - 1):
                    return self.model.no_sync()
                else:
                    return contextlib.ExitStack()  # dummy contextmanager

            try:
                with maybe_no_sync():
                    # forward and backward
                    loss, sample_size_i, logging_output = self.task.train_step(
                        sample=sample,
                        model=self.model,
                        criterion=self.criterion,
                        optimizer=self.optimizer,
                        update_num=self.get_num_updates(),
                        ignore_grad=is_dummy_batch,
                    )
                    del loss

                logging_outputs.append(logging_output)
                if not is_dummy_batch:
                    sample_size += sample_size_i

                # emptying the CUDA cache after the first step can
                # reduce the chance of OOM
                if self.cuda and self.get_num_updates() == 0:
                    torch.cuda.empty_cache()
            except RuntimeError as e:
                if "out of memory" in str(e):
                    self._log_oom(e)
                    if raise_oom:
                        raise e
                    logger.warning(
                        "attempting to recover from OOM in forward/backward pass"
                    )
                    ooms += 1
                    self.zero_grad()
                else:
                    raise e

        # gather logging outputs from all replicas
        if self._sync_stats():
            logging_outputs, (sample_size,
                              ooms) = self._aggregate_logging_outputs(
                                  logging_outputs,
                                  sample_size,
                                  ooms,
                                  ignore=is_dummy_batch,
                              )

        metrics.log_scalar("oom", ooms, len(samples), priority=600, round=3)
        if ooms == self.args.distributed_world_size * len(samples):
            logger.warning("OOM in all workers, skipping update")
            self.zero_grad()
            return None

        try:
            # normalize grads by sample size
            if sample_size > 0:
                if self._sync_stats():
                    # multiply gradients by (# GPUs / sample_size) since DDP
                    # already normalizes by the number of GPUs. Thus we get
                    # (sum_of_gradients / sample_size).
                    self.optimizer.multiply_grads(
                        self.args.distributed_world_size / sample_size)
                else:
                    self.optimizer.multiply_grads(1 / sample_size)

            # clip grads
            grad_norm = self.optimizer.clip_grad_norm(self.args.clip_norm)

            # check that grad norms are consistent across workers
            if not self.args.use_bmuf:
                self._check_grad_norms(grad_norm)

            # take an optimization step
            self.optimizer.step()
            self.set_num_updates(self.get_num_updates() + 1)

            # log stats
            # (grad_norm is not forwarded here; "ups", "gnorm" and "clip"
            # are logged explicitly right below instead)
            logging_output = self._reduce_and_log_stats(
                logging_outputs, sample_size)
            metrics.log_speed("ups",
                              1.,
                              ignore_first=10,
                              priority=100,
                              round=2)
            metrics.log_scalar("gnorm",
                               utils.item(grad_norm),
                               priority=400,
                               round=3)
            metrics.log_scalar(
                "clip",
                100 if grad_norm > self.args.clip_norm > 0 else 0,
                priority=500,
                round=1,
            )

            # clear CUDA cache to reduce memory fragmentation
            if (self.args.empty_cache_freq > 0 and
                ((self.get_num_updates() + self.args.empty_cache_freq - 1) %
                 self.args.empty_cache_freq) == 0
                    and torch.cuda.is_available() and not self.args.cpu):
                torch.cuda.empty_cache()
        except FloatingPointError:
            # re-run the forward and backward pass with hooks attached to print out where it fails
            with NanDetector(self.model):
                # Pass update_num exactly as the main forward/backward call
                # does; omitting it could raise a TypeError in tasks that
                # require the argument and hide the FloatingPointError.
                self.task.train_step(sample,
                                     self.model,
                                     self.criterion,
                                     self.optimizer,
                                     update_num=self.get_num_updates(),
                                     ignore_grad=False)
            raise
        except OverflowError as e:
            # skip this update: discard the gradients and report no stats
            logger.info("NOTE: overflow detected, " + str(e))
            self.zero_grad()
            logging_output = None
        except RuntimeError as e:
            if "out of memory" in str(e):
                self._log_oom(e)
                logger.error("OOM during optimization, irrecoverable")
            raise e

        if self.args.fp16:
            metrics.log_scalar("loss_scale",
                               self.optimizer.scaler.loss_scale,
                               priority=700,
                               round=0)

        metrics.log_stop_time("train_wall")

        return logging_output
Ejemplo n.º 4
0
    def _reduce_and_log_stats(self, logging_outputs, sample_size, grad_norm=None, gvar=None, adam_mom2=None,
                              gvar_diff=None, xstd=None, ams_mom=None, acc_ratio=None, real_var=None, real_var_diff=None,
                              ad_beta=None, lr_min=None, lr_max=None, lr_median=None,
                              update_min=None, update_max=None, update_median=None, valid_ratio=None, var_adapt=None):
        """Log per-update stats and return the aggregated output plus predictions.

        When *grad_norm* is given, "ups" and "gnorm" are logged, plus "clip"
        (100 where the norm exceeds ``self.args.clip_norm``, else 0) if that
        threshold is positive. *grad_norm* is used through ``new_tensor``, so
        it is presumably a torch tensor -- confirm against the caller.

        Every other optional keyword is logged as a scalar under its own name
        when it is not None.

        The task's ``reduce_metrics`` is then run inside a fresh
        ``metrics.aggregate()`` context (skipped when *logging_outputs* is
        None).

        Returns:
            A 3-tuple ``(logging_output, preds, labels)``:

            * ``logging_output``: smoothed values of the aggregation with
              "sample_size" added and "ppl", "wps", "wpb", "bsz" removed.
            * ``preds`` / ``labels``: lists holding the 'preds' / 'labels'
              entry of every logging output, in order. Both are None when
              *logging_outputs* is None or when any logging output has no
              'preds' entry.
        """
        if grad_norm is not None:
            metrics.log_speed("ups", 1., priority=100, round=2)
            metrics.log_scalar("gnorm", grad_norm, priority=400, round=3)
            if self.args.clip_norm > 0:
                metrics.log_scalar(
                    "clip",
                    torch.where(
                        grad_norm > self.args.clip_norm,
                        grad_norm.new_tensor(100),
                        grad_norm.new_tensor(0),
                    ),
                    priority=500,
                    round=1,
                )

        # (name, value, priority) for each optional diagnostic, listed in the
        # order in which they are logged.
        optional_scalars = (
            ("gvar", gvar, 100),
            ("adam_mom2", adam_mom2, 100),
            ("gvar_diff", gvar_diff, 100),
            ("xstd", xstd, 100),
            ("ams_mom", ams_mom, 100),
            ("acc_ratio", acc_ratio, 50),
            ("real_var", real_var, 50),
            ("real_var_diff", real_var_diff, 50),
            ("ad_beta", ad_beta, 50),
            ("lr_min", lr_min, 50),
            ("lr_max", lr_max, 50),
            ("lr_median", lr_median, 50),
            ("update_min", update_min, 50),
            ("update_median", update_median, 50),
            ("update_max", update_max, 50),
            ("valid_ratio", valid_ratio, 49),
            ("var_adapt", var_adapt, 1),
        )
        for scalar_name, scalar_value, scalar_priority in optional_scalars:
            if scalar_value is not None:
                metrics.log_scalar(scalar_name, scalar_value, priority=scalar_priority)

        with metrics.aggregate() as agg:
            if logging_outputs is not None:
                self.task.reduce_metrics(logging_outputs, self.get_criterion())

            preds, labels = [], []
            if logging_outputs is None:
                # nothing to collect from; iterating None would raise
                preds = labels = None
            else:
                for log_output in logging_outputs:
                    if 'preds' not in log_output:
                        # One output without predictions invalidates the whole
                        # collection. Stop here: continuing would call
                        # ``append`` on None for any later output that does
                        # carry 'preds'.
                        preds = labels = None
                        break
                    preds.append(log_output['preds'])
                    labels.append(log_output['labels'])

            # support legacy interface
            logging_output = agg.get_smoothed_values()
            logging_output["sample_size"] = sample_size
            for key_to_delete in ["ppl", "wps", "wpb", "bsz"]:
                if key_to_delete in logging_output:
                    del logging_output[key_to_delete]
            return logging_output, preds, labels