def _reduce_and_log_stats(self, logging_outputs, sample_size, grad_norm=None):
    if grad_norm is not None:
        metrics.log_speed("ups", 1.0, priority=100, round=2)
        metrics.log_scalar("gnorm", grad_norm, priority=400, round=3)
        if self.args.clip_norm > 0:
            metrics.log_scalar(
                "clip",
                torch.where(
                    grad_norm > self.args.clip_norm,
                    grad_norm.new_tensor(100),
                    grad_norm.new_tensor(0),
                ),
                priority=500,
                round=1,
            )

    with metrics.aggregate() as agg:
        if logging_outputs is not None:
            self.task.reduce_metrics(logging_outputs, self.get_criterion())

        # support legacy interface
        logging_output = agg.get_smoothed_values()
        logging_output["sample_size"] = sample_size
        for key_to_delete in ["ppl", "wps", "wpb", "bsz"]:
            if key_to_delete in logging_output:
                del logging_output[key_to_delete]
        return logging_output
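# --- Hedged sketch (standalone, not part of the trainer above): the "clip"
# metric logs 100 when the gradient norm exceeded clip_norm and 0 otherwise,
# so its smoothed average reads as the percentage of clipped updates. The
# norms below are made up for illustration.
import torch


def clip_indicator(grad_norm: torch.Tensor, clip_norm: float) -> torch.Tensor:
    # same construction as above; new_tensor keeps grad_norm's device/dtype
    return torch.where(
        grad_norm > clip_norm,
        grad_norm.new_tensor(100),
        grad_norm.new_tensor(0),
    )


if __name__ == "__main__":
    norms = [torch.tensor(0.8), torch.tensor(1.7), torch.tensor(0.3)]
    flags = [clip_indicator(n, clip_norm=1.0) for n in norms]
    print(sum(flags) / len(flags))  # tensor(33.3333): 1 of 3 updates clipped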
def _reduce_and_log_stats(self, logging_outputs, sample_size, grad_norm=None):
    # VAE-style extra losses; guard against an empty/None logging_outputs,
    # which the rest of this function explicitly allows
    if logging_outputs:
        metrics.log_scalar("kl_loss", round(logging_outputs[0]["kl_loss"].item(), 3))
        metrics.log_scalar("kld", round(logging_outputs[0]["kld"].item(), 3))
        metrics.log_scalar("bow_loss", round(logging_outputs[0]["bow_loss"].item(), 3))

    # skip logging a non-finite gnorm (e.g. after an fp16 overflow)
    if grad_norm is not None and (
        not torch.is_tensor(grad_norm) or torch.isfinite(grad_norm)
    ):
        metrics.log_speed("ups", 1.0, priority=100, round=2)
        metrics.log_scalar("gnorm", grad_norm, priority=400, round=3)
        if self.cfg.optimization.clip_norm > 0:
            metrics.log_scalar(
                "clip",
                torch.where(
                    grad_norm > self.cfg.optimization.clip_norm,
                    grad_norm.new_tensor(100),
                    grad_norm.new_tensor(0),
                ),
                priority=500,
                round=1,
            )

    with metrics.aggregate() as agg:
        if logging_outputs is not None:
            self.task.reduce_metrics(logging_outputs, self.get_criterion())
            del logging_outputs

        # extra warning for criterions that don't properly log a loss value
        if "loss" not in agg:
            if "loss" not in self._warn_once:
                self._warn_once.add("loss")
                logger.warning(
                    "Criterion.reduce_metrics did not log a 'loss' value, "
                    "which may break some functionality"
                )
            metrics.log_scalar("loss", -1)

    # support legacy interface
    if self.tpu:
        logging_output = {}
    else:
        logging_output = agg.get_smoothed_values()
        logging_output["sample_size"] = sample_size
        for key_to_delete in ["ppl", "wps", "wpb", "bsz"]:
            if key_to_delete in logging_output:
                del logging_output[key_to_delete]

    return logging_output
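# --- Hedged sketch: the isfinite() guard in the variant above matters under
# fp16, where an overflowed step can yield an inf/nan gradient norm that would
# poison the smoothed "gnorm" average. Standalone illustration of the same
# condition; the values checked are assumed.
import math

import torch


def should_log_gnorm(grad_norm) -> bool:
    # mirrors the guard above: plain floats always pass, tensors only if finite
    return grad_norm is not None and (
        not torch.is_tensor(grad_norm) or bool(torch.isfinite(grad_norm))
    )


assert should_log_gnorm(0.5)
assert not should_log_gnorm(torch.tensor(math.inf))
assert not should_log_gnorm(None)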
def train_step(self, samples, raise_oom=False):
    """Do forward, backward and parameter update."""
    if self._dummy_batch == "DUMMY":
        self._dummy_batch = samples[0]

    self._set_seed()
    self.model.train()
    self.criterion.train()
    self.zero_grad()

    metrics.log_start_time("train_wall", priority=800, round=0)

    # forward and backward pass
    logging_outputs, sample_size, ooms = [], 0, 0
    for i, sample in enumerate(samples):
        sample = self._prepare_sample(sample)
        if sample is None:
            # when sample is None, run forward/backward on a dummy batch
            # and ignore the resulting gradients
            sample = self._prepare_sample(self._dummy_batch)
            is_dummy_batch = True
        else:
            is_dummy_batch = False

        def maybe_no_sync():
            """
            Whenever *samples* contains more than one mini-batch, we
            want to accumulate gradients locally and only call
            all-reduce in the last backwards pass.
            """
            if (
                self.args.distributed_world_size > 1
                and hasattr(self.model, "no_sync")
                and i < len(samples) - 1
            ):
                return self.model.no_sync()
            else:
                return contextlib.ExitStack()  # dummy contextmanager

        try:
            with maybe_no_sync():
                # forward and backward
                loss, sample_size_i, logging_output = self.task.train_step(
                    sample=sample,
                    model=self.model,
                    criterion=self.criterion,
                    optimizer=self.optimizer,
                    update_num=self.get_num_updates(),
                    ignore_grad=is_dummy_batch,
                )
                del loss

            logging_outputs.append(logging_output)
            if not is_dummy_batch:
                sample_size += sample_size_i

            # emptying the CUDA cache after the first step can
            # reduce the chance of OOM
            if self.cuda and self.get_num_updates() == 0:
                torch.cuda.empty_cache()
        except RuntimeError as e:
            if "out of memory" in str(e):
                self._log_oom(e)
                if raise_oom:
                    raise e
                logger.warning(
                    "attempting to recover from OOM in forward/backward pass"
                )
                ooms += 1
                self.zero_grad()
            else:
                raise e

    # gather logging outputs from all replicas
    if self._sync_stats():
        logging_outputs, (sample_size, ooms) = self._aggregate_logging_outputs(
            logging_outputs, sample_size, ooms, ignore=is_dummy_batch,
        )

    metrics.log_scalar("oom", ooms, len(samples), priority=600, round=3)
    if ooms == self.args.distributed_world_size * len(samples):
        logger.warning("OOM in all workers, skipping update")
        self.zero_grad()
        return None

    try:
        # normalize grads by sample size
        if sample_size > 0:
            if self._sync_stats():
                # multiply gradients by (# GPUs / sample_size) since DDP
                # already normalizes by the number of GPUs. Thus we get
                # (sum_of_gradients / sample_size).
                self.optimizer.multiply_grads(
                    self.args.distributed_world_size / sample_size
                )
            else:
                self.optimizer.multiply_grads(1 / sample_size)

        # clip grads
        grad_norm = self.optimizer.clip_grad_norm(self.args.clip_norm)

        # check that grad norms are consistent across workers
        if not self.args.use_bmuf:
            self._check_grad_norms(grad_norm)

        # take an optimization step
        self.optimizer.step()
        self.set_num_updates(self.get_num_updates() + 1)

        # log stats
        logging_output = self._reduce_and_log_stats(logging_outputs, sample_size)
        metrics.log_speed("ups", 1.0, ignore_first=10, priority=100, round=2)
        metrics.log_scalar("gnorm", utils.item(grad_norm), priority=400, round=3)
        metrics.log_scalar(
            "clip",
            100 if grad_norm > self.args.clip_norm > 0 else 0,
            priority=500,
            round=1,
        )

        # clear CUDA cache to reduce memory fragmentation
        if (
            self.args.empty_cache_freq > 0
            and (
                (self.get_num_updates() + self.args.empty_cache_freq - 1)
                % self.args.empty_cache_freq
            ) == 0
            and torch.cuda.is_available()
            and not self.args.cpu
        ):
            torch.cuda.empty_cache()
    except FloatingPointError:
        # re-run the forward and backward pass with hooks attached to print
        # out where it fails
        with NanDetector(self.model):
            self.task.train_step(
                sample, self.model, self.criterion, self.optimizer,
                ignore_grad=False,
            )
        raise
    except OverflowError as e:
        logger.info("NOTE: overflow detected, " + str(e))
        self.zero_grad()
        logging_output = None
    except RuntimeError as e:
        if "out of memory" in str(e):
            self._log_oom(e)
            logger.error("OOM during optimization, irrecoverable")
        raise e

    if self.args.fp16:
        metrics.log_scalar(
            "loss_scale", self.optimizer.scaler.loss_scale, priority=700, round=0
        )

    metrics.log_stop_time("train_wall")
    return logging_output
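# --- Hedged sketch (standalone, not the trainer's API): maybe_no_sync plus
# multiply_grads above is the standard DDP gradient-accumulation recipe --
# skip all-reduce on every micro-batch except the last, then rescale the
# summed gradients by world_size / sample_size because DDP's all-reduce
# averages over workers. model, micro_batches, and loss_fn are placeholders
# with assumed signatures.
import contextlib


def accumulate_and_step(model, optimizer, micro_batches, loss_fn, world_size):
    total_samples = 0
    for i, batch in enumerate(micro_batches):
        # sync gradients only on the last micro-batch
        if hasattr(model, "no_sync") and i < len(micro_batches) - 1:
            ctx = model.no_sync()
        else:
            ctx = contextlib.ExitStack()  # no-op context
        with ctx:
            loss, n_samples = loss_fn(model, batch)  # assumed to return both
            loss.backward()  # grads accumulate across micro-batches
        total_samples += n_samples
    # recover sum_of_gradients / total_samples despite DDP's per-worker average
    for p in model.parameters():
        if p.grad is not None:
            p.grad.mul_(world_size / total_samples)
    optimizer.step()
    optimizer.zero_grad()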
def _reduce_and_log_stats(
    self,
    logging_outputs,
    sample_size,
    grad_norm=None,
    gvar=None,
    adam_mom2=None,
    gvar_diff=None,
    xstd=None,
    ams_mom=None,
    acc_ratio=None,
    real_var=None,
    real_var_diff=None,
    ad_beta=None,
    lr_min=None,
    lr_max=None,
    lr_median=None,
    update_min=None,
    update_max=None,
    update_median=None,
    valid_ratio=None,
    var_adapt=None,
):
    if grad_norm is not None:
        metrics.log_speed("ups", 1.0, priority=100, round=2)
        metrics.log_scalar("gnorm", grad_norm, priority=400, round=3)
        if self.args.clip_norm > 0:
            metrics.log_scalar(
                "clip",
                torch.where(
                    grad_norm > self.args.clip_norm,
                    grad_norm.new_tensor(100),
                    grad_norm.new_tensor(0),
                ),
                priority=500,
                round=1,
            )

    # optimizer diagnostics: each is logged only when the caller provides it
    optional_stats = [
        ("gvar", gvar, 100),
        ("adam_mom2", adam_mom2, 100),
        ("gvar_diff", gvar_diff, 100),
        ("xstd", xstd, 100),
        ("ams_mom", ams_mom, 100),
        ("acc_ratio", acc_ratio, 50),
        ("real_var", real_var, 50),
        ("real_var_diff", real_var_diff, 50),
        ("ad_beta", ad_beta, 50),
        ("lr_min", lr_min, 50),
        ("lr_max", lr_max, 50),
        ("lr_median", lr_median, 50),
        ("update_min", update_min, 50),
        ("update_median", update_median, 50),
        ("update_max", update_max, 50),
        ("valid_ratio", valid_ratio, 49),
        ("var_adapt", var_adapt, 1),
    ]
    for name, value, priority in optional_stats:
        if value is not None:
            metrics.log_scalar(name, value, priority=priority)

    with metrics.aggregate() as agg:
        if logging_outputs is not None:
            self.task.reduce_metrics(logging_outputs, self.get_criterion())
            # collect per-batch predictions/labels when the criterion
            # provides them
            preds, labels = [], []
            for log_output in logging_outputs:
                if "preds" in log_output:
                    preds.append(log_output["preds"])
                    labels.append(log_output["labels"])
        else:
            preds = None
            labels = None

        # support legacy interface
        logging_output = agg.get_smoothed_values()
        logging_output["sample_size"] = sample_size
        for key_to_delete in ["ppl", "wps", "wpb", "bsz"]:
            if key_to_delete in logging_output:
                del logging_output[key_to_delete]
        return logging_output, preds, labels
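# --- Hedged usage sketch: this last variant also returns the collected
# predictions and labels so a caller can compute task metrics after the
# update. The helper below is hypothetical; only the (logging_output, preds,
# labels) return shape is taken from the function above.
import torch


def accuracy_from_outputs(preds_list, labels_list) -> float:
    if preds_list is None:
        return float("nan")  # logging_outputs was None upstream
    preds = torch.cat([p.view(-1) for p in preds_list])
    labels = torch.cat([y.view(-1) for y in labels_list])
    return (preds == labels).float().mean().item()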