def reduce_metrics(self, logging_outputs, criterion):
    """Aggregate logging outputs from data parallel training."""
    # backward compatibility for tasks that override aggregate_logging_outputs
    base_func = FairseqTask.aggregate_logging_outputs
    self_func = getattr(self, "aggregate_logging_outputs").__func__
    if self_func is not base_func:
        utils.deprecation_warning(
            "Tasks should implement the reduce_metrics API. "
            "Falling back to deprecated aggregate_logging_outputs API."
        )
        agg_logging_outputs = self.aggregate_logging_outputs(
            logging_outputs, criterion
        )
        for k, v in agg_logging_outputs.items():
            metrics.log_scalar(k, v)
        return

    if not any("ntokens" in log for log in logging_outputs):
        warnings.warn(
            "ntokens not found in Criterion logging outputs, cannot log wpb or wps"
        )
    else:
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        metrics.log_scalar("wpb", ntokens, priority=180, round=1)
        metrics.log_speed("wps", ntokens, priority=90, round=1)

    if not any("nsentences" in log for log in logging_outputs):
        warnings.warn(
            "nsentences not found in Criterion logging outputs, cannot log bsz"
        )
    else:
        nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
        metrics.log_scalar("bsz", nsentences, priority=190, round=1)

    criterion.__class__.reduce_metrics(logging_outputs)
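# Illustrative sketch (not part of the code above): the backward-compatibility branch
# detects whether a subclass still overrides aggregate_logging_outputs by comparing the
# bound method's underlying function with the base-class attribute. The _Base/_Child
# classes below are hypothetical and exist only to demonstrate that comparison.
class _Base:
    def aggregate_logging_outputs(self, logging_outputs, criterion):
        raise NotImplementedError

class _Child(_Base):
    def aggregate_logging_outputs(self, logging_outputs, criterion):
        return {}

def _overrides_aggregate(task):
    # True only when the instance's class redefined the method,
    # i.e. the deprecated API is still in use.
    return getattr(task, "aggregate_logging_outputs").__func__ is not _Base.aggregate_logging_outputs

assert _overrides_aggregate(_Child()) is True
assert _overrides_aggregate(_Base()) is False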
def reduce_metrics(self, logging_outputs, criterion):
    if not any('ntokens' in log for log in logging_outputs):
        warnings.warn('ntokens not found in Criterion logging outputs, cannot log wpb or wps')
    else:
        ntokens = utils.item(sum(log.get('ntokens', 0) for log in logging_outputs))
        metrics.log_scalar('wpb', ntokens, priority=180, round=1)
        # TODO(urikz): Latest version of fairseq also has additional argument "ignore_first"
        metrics.log_speed('wps', ntokens, priority=90, round=1)

    if not any('nsentences' in log for log in logging_outputs):
        warnings.warn('nsentences not found in Criterion logging outputs, cannot log bsz')
    else:
        nsentences = utils.item(sum(log.get('nsentences', 0) for log in logging_outputs))
        metrics.log_scalar('ns', nsentences, priority=190, round=1)

    if not any('sample_size' in log for log in logging_outputs):
        warnings.warn('sample_size not found in Criterion logging outputs, cannot log bsz')
    else:
        sample_size = utils.item(sum(log.get('sample_size', 0) for log in logging_outputs))
        metrics.log_scalar('bsz', sample_size, priority=190, round=1)

    if 'n_bor_instances' in logging_outputs[0].keys():
        n_bor_instances = utils.item(sum(log.get('n_bor_instances', 0) for log in logging_outputs))
        metrics.log_scalar('bsz_bor', n_bor_instances, priority=195, round=1)

    if 'ntokens_AB' in logging_outputs[0].keys():
        ntokens_AB = utils.item(sum(log.get('ntokens_AB', 0) for log in logging_outputs))
        metrics.log_scalar('wpb_AB', ntokens_AB, priority=200, round=1)

    if 'ntokens_mem' in logging_outputs[0].keys():
        ntokens_mem = utils.item(sum(log.get('ntokens_mem', 0) for log in logging_outputs))
        metrics.log_scalar('wpb_mem', ntokens_mem, priority=200, round=1)

    criterion.__class__.reduce_metrics(logging_outputs, self.split)
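# Illustrative example (values are made up): each entry of *logging_outputs* is the dict
# returned by the criterion for one worker/micro-batch; the reductions above are plain
# sums over those dicts before the totals are handed to metrics.log_scalar / log_speed.
example_logging_outputs = [
    {"ntokens": 1024, "nsentences": 32, "sample_size": 1024},
    {"ntokens": 980, "nsentences": 30, "sample_size": 980},
]
total_ntokens = sum(log.get("ntokens", 0) for log in example_logging_outputs)          # 2004 -> "wpb"
total_sentences = sum(log.get("nsentences", 0) for log in example_logging_outputs)     # 62   -> "ns"
total_sample_size = sum(log.get("sample_size", 0) for log in example_logging_outputs)  # 2004 -> "bsz"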
def reduce_metrics(self, logging_outputs, criterion):
    keys = set([
        k
        for logging_output in logging_outputs
        for k in logging_output.keys()
    ])
    for key in keys:
        if key.endswith('ntokens'):
            prefix = key[:-len('ntokens')]
            ntokens = utils.item(
                sum(log.get(key, 0) for log in logging_outputs))
            metrics.log_scalar(prefix + 'wpb', ntokens, priority=80, round=1)
            metrics.log_speed(prefix + 'wps', ntokens, priority=50, round=1)
        elif key.endswith('nsentences'):
            prefix = key[:-len('nsentences')]
            nsentences = utils.item(
                sum(log.get(key, 0) for log in logging_outputs))
            metrics.log_scalar(prefix + 'ns', nsentences, priority=90, round=1)
        elif key.endswith('sample_size'):
            prefix = key[:-len('sample_size')]
            sample_size = utils.item(
                sum(log.get(key, 0) for log in logging_outputs))
            metrics.log_scalar(prefix + 'bsz', sample_size, priority=190, round=1)

    criterion.reduce_metrics(logging_outputs, self.split)
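# Illustrative sketch (the keys below are hypothetical): the suffix matching lets a
# criterion report the same statistic for several streams by prefixing the key; every
# "*ntokens" key yields its own "*wpb" scalar via the prefix slicing used above.
example_log = {"ntokens": 512, "mem_ntokens": 128, "nsentences": 16}
for key in example_log:
    if key.endswith("ntokens"):
        prefix = key[:-len("ntokens")]           # "" for "ntokens", "mem_" for "mem_ntokens"
        print(prefix + "wpb", example_log[key])  # -> "wpb 512" and "mem_wpb 128"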
def train_step(self, samples, dummy_batch=False, raise_oom=False):
    """Do forward, backward and parameter update."""
    if self._dummy_batch is None:
        self._dummy_batch = samples[0]

    self._set_seed()
    self.model.train()
    self.criterion.train()
    self.zero_grad()

    if not dummy_batch:
        metrics.log_start_time("train_wall", priority=800, round=0)

    # forward and backward pass
    logging_outputs, sample_size, ooms = [], 0, 0
    for i, sample in enumerate(samples):
        sample = self._prepare_sample(sample)
        if sample is None:
            # when sample is None, run forward/backward on a dummy batch
            # and ignore the resulting gradients
            sample = self._prepare_sample(self._dummy_batch)
            ignore_grad = True
        else:
            ignore_grad = False

        def maybe_no_sync():
            """
            Whenever *samples* contains more than one mini-batch, we
            want to accumulate gradients locally and only call
            all-reduce in the last backwards pass.
            """
            if (
                self.args.distributed_world_size > 1
                and hasattr(self.model, "no_sync")
                and i < len(samples) - 1
            ):
                return self.model.no_sync()
            else:
                return contextlib.ExitStack()  # dummy contextmanager

        try:
            with maybe_no_sync():
                # forward and backward
                loss, sample_size_i, logging_output = self.task.train_step(
                    sample, self.model, self.criterion, self.optimizer, ignore_grad
                )

            if not ignore_grad:
                logging_outputs.append(logging_output)
                sample_size += sample_size_i
        except RuntimeError as e:
            if "out of memory" in str(e):
                self._log_oom(e)
                if raise_oom:
                    raise e
                print(
                    "| WARNING: attempting to recover from OOM in forward/backward pass",
                    file=sys.stderr,
                )
                ooms += 1
                self.zero_grad()
            else:
                raise e

    if ooms > 0 and self._oom_batch is not None:
        self.handle_ooms(ooms)

    if dummy_batch:
        return None

    # gather logging outputs from all replicas
    if self._sync_stats():
        logging_outputs, sample_size, ooms = self._aggregate_logging_outputs(
            logging_outputs, sample_size, ooms,
        )

    metrics.log_scalar("oom", ooms, len(samples), priority=600, round=3)
    if ooms == self.args.distributed_world_size * len(samples):
        print("| WARNING: OOM in all workers, skipping update")
        self.zero_grad()
        return None

    try:
        # normalize grads by sample size
        if not self.args.use_bmuf:
            # multiply gradients by (# GPUs / sample_size) since DDP
            # already normalizes by the number of GPUs. Thus we get
            # (sum_of_gradients / sample_size).
            self.optimizer.multiply_grads(
                self.args.distributed_world_size / sample_size
            )
        elif sample_size > 0:
            # with BMUF, gradients are divided by sample_size during non-sync
            # steps; during sync steps (when computing the global model) the
            # accumulated gradients are divided by the number of GPUs, so we
            # multiply by #GPUs / sample_size here.
            if self._sync_stats():
                self.optimizer.multiply_grads(
                    self.args.distributed_world_size / sample_size
                )
            else:
                self.optimizer.multiply_grads(1 / sample_size)

        # clip grads
        grad_norm = self.optimizer.clip_grad_norm(self.args.clip_norm)

        # check that grad norms are consistent across workers
        if not self.args.use_bmuf:
            self._check_grad_norms(grad_norm)

        # take an optimization step
        self.optimizer.step()
        self.set_num_updates(self.get_num_updates() + 1)

        # task specific update per step
        self.task.update_step(self._num_updates)

        # log stats
        logging_output = self._reduce_and_log_stats(logging_outputs, sample_size)
        metrics.log_speed("ups", 1., priority=100, round=2)
        metrics.log_scalar("gnorm", grad_norm, priority=400, round=3)
        metrics.log_scalar(
            "clip",
            100 if grad_norm > self.args.clip_norm > 0 else 0,
            priority=500,
            round=1,
        )

        # clear CUDA cache to reduce memory fragmentation
        if (
            self.args.empty_cache_freq > 0
            and (
                (self.get_num_updates() + self.args.empty_cache_freq - 1)
                % self.args.empty_cache_freq
            ) == 0
            and torch.cuda.is_available()
            and not self.args.cpu
        ):
            torch.cuda.empty_cache()
    except OverflowError as e:
        print("| WARNING: overflow detected, " + str(e))
        self.zero_grad()
        logging_output = None
    except RuntimeError as e:
        if "out of memory" in str(e):
            self._log_oom(e)
            print("| ERROR: OOM during optimization, irrecoverable")
        raise e

    if self.args.fp16:
        metrics.log_scalar(
            "loss_scale", self.optimizer.scaler.loss_scale, priority=700, round=0
        )

    self.clear_buffered_stats()
    metrics.log_stop_time("train_wall")

    return logging_output
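# Illustrative arithmetic check (standalone, values made up): DDP's all-reduce leaves the
# *average* gradient over workers in .grad, i.e. sum_of_gradients / world_size. Scaling by
# world_size / sample_size therefore recovers sum_of_gradients / sample_size, which is
# exactly what the normalization comment in train_step above claims.
world_size = 4
sample_size = 2048                                   # total tokens behind the summed gradient
per_worker_grads = [1.0, 2.0, 3.0, 4.0]              # stand-ins for one gradient coordinate
ddp_averaged = sum(per_worker_grads) / world_size    # what DDP leaves after all-reduce
rescaled = ddp_averaged * (world_size / sample_size)
assert abs(rescaled - sum(per_worker_grads) / sample_size) < 1e-12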
def train_step(self, samples, raise_oom=False):
    """Do forward, backward and parameter update."""
    if self._dummy_batch == "DUMMY":
        self._dummy_batch = samples[0]

    self._set_seed()
    self.model.train()
    self.criterion.train()
    self.zero_grad()

    metrics.log_start_time("train_wall", priority=800, round=0)

    # forward and backward pass
    logging_outputs, sample_size, ooms = [], 0, 0
    for i, sample in enumerate(samples):
        sample = self._prepare_sample(sample)
        if sample is None:
            # when sample is None, run forward/backward on a dummy batch
            # and ignore the resulting gradients
            sample = self._prepare_sample(self._dummy_batch)
            is_dummy_batch = True
        else:
            is_dummy_batch = False

        def maybe_no_sync():
            """
            Whenever *samples* contains more than one mini-batch, we
            want to accumulate gradients locally and only call
            all-reduce in the last backwards pass.
            """
            if (
                self.args.distributed_world_size > 1
                and hasattr(self.model, "no_sync")
                and i < len(samples) - 1
            ):
                return self.model.no_sync()
            else:
                return contextlib.ExitStack()  # dummy contextmanager

        try:
            with maybe_no_sync():
                # forward and backward
                loss, sample_size_i, logging_output = self.task.train_step(
                    sample=sample,
                    model=self.model,
                    criterion=self.criterion,
                    optimizer=self.optimizer,
                    ignore_grad=is_dummy_batch,
                )
                del loss

            logging_outputs.append(logging_output)
            if not is_dummy_batch:
                sample_size += sample_size_i

            # emptying the CUDA cache after the first step can
            # reduce the chance of OOM
            if self.cuda and self.get_num_updates() == 0:
                torch.cuda.empty_cache()
        except RuntimeError as e:
            if "out of memory" in str(e):
                self._log_oom(e)
                if raise_oom:
                    raise e
                logger.warning(
                    "attempting to recover from OOM in forward/backward pass"
                )
                ooms += 1
                self.zero_grad()
            else:
                raise e

    # gather logging outputs from all replicas
    if self._sync_stats():
        logging_outputs, (sample_size, ooms) = self._aggregate_logging_outputs(
            logging_outputs, sample_size, ooms, ignore=is_dummy_batch,
        )

    metrics.log_scalar("oom", ooms, len(samples), priority=600, round=3)
    if ooms == self.args.distributed_world_size * len(samples):
        logger.warning("OOM in all workers, skipping update")
        self.zero_grad()
        return None

    try:
        # normalize grads by sample size
        if sample_size > 0:
            if self._sync_stats():
                # multiply gradients by (# GPUs / sample_size) since DDP
                # already normalizes by the number of GPUs. Thus we get
                # (sum_of_gradients / sample_size).
                self.optimizer.multiply_grads(
                    self.args.distributed_world_size / sample_size
                )
            else:
                self.optimizer.multiply_grads(1 / sample_size)

        # clip grads
        grad_norm = self.optimizer.clip_grad_norm(self.args.clip_norm)

        # check that grad norms are consistent across workers
        if not self.args.use_bmuf:
            self._check_grad_norms(grad_norm)

        # take an optimization step
        self.optimizer.step()
        self.set_num_updates(self.get_num_updates() + 1)

        # task specific update per step
        self.task.update_step(self.get_num_updates())

        # log stats
        logging_output = self._reduce_and_log_stats(logging_outputs, sample_size)
        metrics.log_speed("ups", 1., ignore_first=10, priority=100, round=2)
        metrics.log_scalar("gnorm", utils.item(grad_norm), priority=400, round=3)
        metrics.log_scalar(
            "clip",
            100 if grad_norm > self.args.clip_norm > 0 else 0,
            priority=500,
            round=1,
        )

        # clear CUDA cache to reduce memory fragmentation
        if (
            self.args.empty_cache_freq > 0
            and (
                (self.get_num_updates() + self.args.empty_cache_freq - 1)
                % self.args.empty_cache_freq
            ) == 0
            and torch.cuda.is_available()
            and not self.args.cpu
        ):
            torch.cuda.empty_cache()
    except OverflowError as e:
        logger.info("NOTE: overflow detected, " + str(e))
        self.zero_grad()
        logging_output = None
    except RuntimeError as e:
        if "out of memory" in str(e):
            self._log_oom(e)
            logger.error("OOM during optimization, irrecoverable")
        raise e

    if self.args.fp16:
        metrics.log_scalar(
            "loss_scale", self.optimizer.scaler.loss_scale, priority=700, round=0
        )

    metrics.log_stop_time("train_wall")

    return logging_output
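# Illustrative sketch (helper name and arguments are hypothetical): maybe_no_sync() in both
# train_step variants returns either DDP's no_sync() context or a no-op context.
# contextlib.ExitStack() with nothing registered is a convenient no-op context manager, so
# the "with" statement reads the same whether or not gradient synchronization is suppressed.
import contextlib

def _sync_context(model, is_last_microbatch, world_size):
    # Suppress the all-reduce for every micro-batch except the last one,
    # mirroring the i < len(samples) - 1 condition used above.
    if world_size > 1 and hasattr(model, "no_sync") and not is_last_microbatch:
        return model.no_sync()
    return contextlib.ExitStack()  # no-op context manager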