def train_step(self, samples, raise_oom=False):
    """Do forward, backward and parameter update."""
    if self._dummy_batch == "DUMMY":
        self._dummy_batch = samples[0]

    self._set_seed()
    self.model.train()
    self.criterion.train()
    self.zero_grad()

    metrics.log_start_time("train_wall", priority=800, round=0)

    # forward and backward pass
    logging_outputs, sample_size, ooms = [], 0, 0
    for i, sample in enumerate(samples):
        sample = self._prepare_sample(sample)
        if sample is None:
            # when sample is None, run forward/backward on a dummy batch
            # and ignore the resulting gradients
            sample = self._prepare_sample(self._dummy_batch)
            is_dummy_batch = True
        else:
            is_dummy_batch = False

        def maybe_no_sync():
            """
            Whenever *samples* contains more than one mini-batch, we
            want to accumulate gradients locally and only call
            all-reduce in the last backwards pass.
            """
            if (
                self.data_parallel_world_size > 1
                and hasattr(self.model, "no_sync")
                and i < len(samples) - 1
            ):
                return self.model.no_sync()
            else:
                return contextlib.ExitStack()  # dummy contextmanager

        try:
            with maybe_no_sync():
                # forward and backward
                loss, sample_size_i, logging_output = self.task.train_step(
                    sample=sample,
                    model=self.model,
                    criterion=self.criterion,
                    optimizer=self.optimizer,
                    update_num=self.get_num_updates(),
                    ignore_grad=is_dummy_batch,
                )
                del loss

            logging_outputs.append(logging_output)
            sample_size += sample_size_i

            # emptying the CUDA cache after the first step can
            # reduce the chance of OOM
            if self.cuda and self.get_num_updates() == 0:
                torch.cuda.empty_cache()
        except RuntimeError as e:
            if "out of memory" in str(e):
                self._log_oom(e)
                if raise_oom:
                    raise e
                logger.warning(
                    "attempting to recover from OOM in forward/backward pass"
                )
                ooms += 1
                self.zero_grad()
            else:
                raise e

        if self.tpu and i < len(samples) - 1:
            # tpu-comment: every XLA operation before marking step is
            # appended to the IR graph, and processing too many batches
            # before marking step can lead to OOM errors.
            # To handle gradient accumulation use case, we explicitly
            # mark step here for every forward pass without a backward pass
            import torch_xla.core.xla_model as xm
            xm.mark_step()

    if is_dummy_batch:
        if torch.is_tensor(sample_size):
            sample_size.zero_()
        else:
            sample_size *= 0.

    if torch.is_tensor(sample_size):
        sample_size = sample_size.float()
    else:
        sample_size = float(sample_size)

    # gather logging outputs from all replicas
    if self._sync_stats():
        logging_outputs, (sample_size, ooms) = self._aggregate_logging_outputs(
            logging_outputs, sample_size, ooms, ignore=is_dummy_batch,
        )

    overflow = False
    try:
        if self.tpu and self.data_parallel_world_size > 1:
            import torch_xla.core.xla_model as xm
            gradients = xm._fetch_gradients(self.optimizer.optimizer)
            xm.all_reduce(
                'sum', gradients, scale=1.0 / self.data_parallel_world_size
            )

        # multiply gradients by (# GPUs / sample_size) since DDP
        # already normalizes by the number of GPUs. Thus we get
        # (sum_of_gradients / sample_size).
        if not self.args.use_bmuf:
            self.optimizer.multiply_grads(
                self.data_parallel_world_size / sample_size
            )
        elif sample_size > 0:  # BMUF needs to check sample size
            num = self.data_parallel_world_size if self._sync_stats() else 1
            self.optimizer.multiply_grads(num / sample_size)

        # clip grads
        grad_norm = self.clip_grad_norm(self.args.clip_norm)

        # check that grad norms are consistent across workers
        if (
            not self.args.use_bmuf
            and self.args.distributed_wrapper != 'SlowMo'
            and not self.tpu
        ):
            self._check_grad_norms(grad_norm)

        # take an optimization step
        self.optimizer.step()
    except FloatingPointError:
        # re-run the forward and backward pass with hooks attached to print
        # out where it fails
        with NanDetector(self.model):
            self.task.train_step(
                sample, self.model, self.criterion, self.optimizer,
                self.get_num_updates(), ignore_grad=False
            )
        raise
    except OverflowError as e:
        overflow = True
        logger.info("NOTE: overflow detected, " + str(e))
        grad_norm = torch.tensor(0.).cuda()
        self.zero_grad()
    except RuntimeError as e:
        if "out of memory" in str(e):
            self._log_oom(e)
            logger.error("OOM during optimization, irrecoverable")
        raise e

    # Some distributed wrappers (e.g., SlowMo) need access to the optimizer
    # after the step
    if hasattr(self.model, 'perform_additional_optimizer_actions'):
        if hasattr(self.optimizer, 'fp32_params'):
            self.model.perform_additional_optimizer_actions(
                self.optimizer.optimizer, self.optimizer.fp32_params
            )
        else:
            self.model.perform_additional_optimizer_actions(
                self.optimizer.optimizer
            )

    # ensure logging_output is defined even when the update is skipped
    # (e.g. after an overflow without the SlowMo wrapper)
    logging_output = None
    if not overflow or self.args.distributed_wrapper == 'SlowMo':
        self.set_num_updates(self.get_num_updates() + 1)

        if self.tpu:
            # mark step on TPUs
            import torch_xla.core.xla_model as xm
            xm.mark_step()

            # only log stats every log_interval steps
            # this causes wps to be misreported when log_interval > 1
            logging_output = {}
            if self.get_num_updates() % self.args.log_interval == 0:
                logging_output = self._reduce_and_log_stats(
                    logging_outputs, sample_size, grad_norm,
                )

            # log whenever there's an XLA compilation, since these
            # slow down training and may indicate opportunities for
            # optimization
            self._check_xla_compilation()
        else:
            # log stats
            logging_output = self._reduce_and_log_stats(
                logging_outputs, sample_size, grad_norm,
            )

        # clear CUDA cache to reduce memory fragmentation
        if (
            self.cuda
            and self.args.empty_cache_freq > 0
            and (
                (self.get_num_updates() + self.args.empty_cache_freq - 1)
                % self.args.empty_cache_freq
            ) == 0
        ):
            torch.cuda.empty_cache()

    if self.args.fp16:
        metrics.log_scalar(
            "loss_scale", self.optimizer.scaler.loss_scale, priority=700, round=0
        )

    metrics.log_stop_time("train_wall")

    return logging_output
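# Illustrative sketch (not part of the trainer above): the maybe_no_sync()
# helper relies on DistributedDataParallel's no_sync() context manager to
# accumulate gradients locally and only all-reduce in the final backward
# pass. A minimal standalone version of that pattern is shown below; the
# names `model`, `loss_fn`, `optimizer` and `batches` are hypothetical
# placeholders, not identifiers from this file.
import contextlib


def accumulate_and_step(model, loss_fn, optimizer, batches):
    """Accumulate gradients over several mini-batches before one update."""
    optimizer.zero_grad()
    for i, (inputs, targets) in enumerate(batches):
        # skip the gradient all-reduce for every mini-batch except the last;
        # plain (non-DDP) modules fall back to a no-op context manager
        if hasattr(model, "no_sync") and i < len(batches) - 1:
            ctx = model.no_sync()
        else:
            ctx = contextlib.ExitStack()
        with ctx:
            loss = loss_fn(model(inputs), targets)
            loss.backward()
    optimizer.step()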
def lr_step_update(self):
    """Update the learning rate after each update."""
    new_lr = self.lr_scheduler.step_update(self.get_num_updates())
    metrics.log_scalar("lr", new_lr, weight=0, priority=300)
    return new_lr
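# Illustrative sketch (not part of the trainer above): lr_step_update() asks
# the LR scheduler for a new learning rate given the current update count. A
# minimal scheduler exposing the same step_update(num_updates) interface is
# sketched below with a made-up inverse-square-root decay after a linear
# warmup; `base_lr` and `warmup_updates` are hypothetical parameters, not
# values from this file.
class InverseSqrtSchedule:
    def __init__(self, base_lr=5e-4, warmup_updates=4000):
        self.base_lr = base_lr
        self.warmup_updates = warmup_updates

    def step_update(self, num_updates):
        # linear warmup followed by inverse-sqrt decay
        if num_updates < self.warmup_updates:
            return self.base_lr * num_updates / self.warmup_updates
        return self.base_lr * (self.warmup_updates / num_updates) ** 0.5


# e.g. InverseSqrtSchedule().step_update(16000) -> 2.5e-4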
def reduce_metrics(logging_outputs) -> None:
    """Aggregate logging outputs from data parallel training."""
    CrossEntropyCriterion.reduce_metrics(logging_outputs)

    num_corr = sum(log.get("num_corr", 0) for log in logging_outputs)
    num_tot = sum(log.get("num_tot", 0) for log in logging_outputs)
    metrics.log_scalar(
        "accuracy",
        num_corr.float() / num_tot * 100 if num_tot > 0 else 0.0,
        num_tot,
        round=3,
    )
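# Illustrative sketch (not part of the criterion above): reduce_metrics()
# sums per-replica `num_corr` / `num_tot` counts from the logging outputs and
# reports accuracy as a percentage. The tiny standalone example below mirrors
# that arithmetic with plain Python numbers; the logging outputs shown are
# made up for illustration.
def _example_accuracy(logging_outputs):
    num_corr = sum(log.get("num_corr", 0) for log in logging_outputs)
    num_tot = sum(log.get("num_tot", 0) for log in logging_outputs)
    return num_corr / num_tot * 100 if num_tot > 0 else 0.0


assert _example_accuracy([
    {"num_corr": 60, "num_tot": 64},  # replica 0
    {"num_corr": 36, "num_tot": 64},  # replica 1
]) == 75.0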
def set_num_updates(self, num_updates):
    """Set the number of parameter updates."""
    self._num_updates = num_updates
    self.lr_step_update()
    metrics.log_scalar("num_updates", self._num_updates, weight=0, priority=200)
def train_step(self, samples, raise_oom=False):
    """Do forward, backward and parameter update."""
    if self._dummy_batch == "DUMMY":
        self._dummy_batch = samples[0]

    self._set_seed()
    self.model.train()
    self.criterion.train()
    self.zero_grad()

    metrics.log_start_time("train_wall", priority=800, round=0)

    # forward and backward pass
    logging_outputs, sample_size, ooms = [], 0, 0
    for i, sample in enumerate(samples):
        sample = self._prepare_sample(sample)
        if sample is None:
            # when sample is None, run forward/backward on a dummy batch
            # and ignore the resulting gradients
            sample = self._prepare_sample(self._dummy_batch)
            is_dummy_batch = True
        else:
            is_dummy_batch = False

        def maybe_no_sync():
            """
            Whenever *samples* contains more than one mini-batch, we
            want to accumulate gradients locally and only call
            all-reduce in the last backwards pass.
            """
            if (
                self.args.distributed_world_size > 1
                and hasattr(self.model, "no_sync")
                and i < len(samples) - 1
            ):
                return self.model.no_sync()
            else:
                return contextlib.ExitStack()  # dummy contextmanager

        try:
            with maybe_no_sync():
                # forward and backward
                loss, sample_size_i, logging_output = self.task.train_step(
                    sample=sample,
                    model=self.model,
                    criterion=self.criterion,
                    optimizer=self.optimizer,
                    update_num=self.get_num_updates(),
                    ignore_grad=is_dummy_batch,
                )
                del loss

            logging_outputs.append(logging_output)
            if not is_dummy_batch:
                sample_size += sample_size_i

            # emptying the CUDA cache after the first step can
            # reduce the chance of OOM
            if self.cuda and self.get_num_updates() == 0:
                torch.cuda.empty_cache()
        except RuntimeError as e:
            if "out of memory" in str(e):
                self._log_oom(e)
                if raise_oom:
                    raise e
                logger.warning(
                    "attempting to recover from OOM in forward/backward pass"
                )
                ooms += 1
                self.zero_grad()
            else:
                raise e

    # gather logging outputs from all replicas
    if self._sync_stats():
        logging_outputs, (sample_size, ooms) = self._aggregate_logging_outputs(
            logging_outputs, sample_size, ooms, ignore=is_dummy_batch,
        )

    metrics.log_scalar("oom", ooms, len(samples), priority=600, round=3)
    if ooms == self.args.distributed_world_size * len(samples):
        logger.warning("OOM in all workers, skipping update")
        self.zero_grad()
        return None

    try:
        # normalize grads by sample size
        if sample_size > 0:
            if self._sync_stats():
                # multiply gradients by (# GPUs / sample_size) since DDP
                # already normalizes by the number of GPUs. Thus we get
                # (sum_of_gradients / sample_size).
                self.optimizer.multiply_grads(
                    self.args.distributed_world_size / sample_size
                )
            else:
                self.optimizer.multiply_grads(1 / sample_size)

        # clip grads
        grad_norm = self.optimizer.clip_grad_norm(self.args.clip_norm)

        # check that grad norms are consistent across workers
        if not self.args.use_bmuf:
            self._check_grad_norms(grad_norm)

        # take an optimization step
        self.optimizer.step()
        self.set_num_updates(self.get_num_updates() + 1)

        # log stats
        logging_output = self._reduce_and_log_stats(logging_outputs, sample_size)
        metrics.log_speed("ups", 1., ignore_first=10, priority=100, round=2)
        metrics.log_scalar("gnorm", utils.item(grad_norm), priority=400, round=3)
        metrics.log_scalar(
            "clip",
            100 if grad_norm > self.args.clip_norm > 0 else 0,
            priority=500,
            round=1,
        )

        # clear CUDA cache to reduce memory fragmentation
        if (
            self.args.empty_cache_freq > 0
            and (
                (self.get_num_updates() + self.args.empty_cache_freq - 1)
                % self.args.empty_cache_freq
            ) == 0
            and torch.cuda.is_available()
            and not self.args.cpu
        ):
            torch.cuda.empty_cache()
    except FloatingPointError:
        # re-run the forward and backward pass with hooks attached to print
        # out where it fails
        with NanDetector(self.model):
            self.task.train_step(
                sample, self.model, self.criterion, self.optimizer,
                self.get_num_updates(), ignore_grad=False
            )
        raise
    except OverflowError as e:
        logger.info("NOTE: overflow detected, " + str(e))
        self.zero_grad()
        logging_output = None
    except RuntimeError as e:
        if "out of memory" in str(e):
            self._log_oom(e)
            logger.error("OOM during optimization, irrecoverable")
        raise e

    if self.args.fp16:
        metrics.log_scalar(
            "loss_scale", self.optimizer.scaler.loss_scale, priority=700, round=0
        )

    metrics.log_stop_time("train_wall")

    return logging_output
def train_step(self, samples, raise_oom=False):
    """Do forward, backward and parameter update."""
    if self._dummy_batch == "DUMMY":
        self._dummy_batch = samples[0]

    self._set_seed()
    self.model.train()
    self.criterion.train()
    self.zero_grad()

    metrics.log_start_time("train_wall", priority=800, round=0)

    # forward and backward pass
    logging_outputs, sample_size, ooms = [], 0, 0
    for i, sample in enumerate(samples):
        sample = self._prepare_sample(sample)
        if sample is None:
            # when sample is None, run forward/backward on a dummy batch
            # and ignore the resulting gradients
            sample = self._prepare_sample(self._dummy_batch)
            is_dummy_batch = True
        else:
            is_dummy_batch = False

        def maybe_no_sync():
            """
            Whenever *samples* contains more than one mini-batch, we
            want to accumulate gradients locally and only call
            all-reduce in the last backwards pass.
            """
            if (
                self.data_parallel_world_size > 1
                and hasattr(self.model, "no_sync")
                and i < len(samples) - 1
            ):
                return self.model.no_sync()
            else:
                return contextlib.ExitStack()  # dummy contextmanager

        try:
            with maybe_no_sync():
                # forward and backward
                loss, sample_size_i, logging_output = self.task.train_step(
                    sample=sample,
                    model=self.model,
                    criterion=self.criterion,
                    optimizer=self.optimizer,
                    update_num=self.get_num_updates(),
                    ignore_grad=is_dummy_batch,
                )
                del loss

            logging_outputs.append(logging_output)
            sample_size += sample_size_i

            # emptying the CUDA cache after the first step can
            # reduce the chance of OOM
            if self.cuda and self.get_num_updates() == 0:
                torch.cuda.empty_cache()
        except RuntimeError as e:
            if "out of memory" in str(e):
                self._log_oom(e)
                if raise_oom:
                    raise e
                logger.warning(
                    "attempting to recover from OOM in forward/backward pass"
                )
                ooms += 1
                self.zero_grad()
            else:
                raise e

    if torch.is_tensor(sample_size):
        sample_size = sample_size.float()
    else:
        sample_size = float(sample_size)
    if is_dummy_batch:
        sample_size *= 0.  # multiply by 0 to preserve device

    # gather logging outputs from all replicas
    if self._sync_stats():
        logging_outputs, (sample_size, ooms) = self._aggregate_logging_outputs(
            logging_outputs, sample_size, ooms, ignore=is_dummy_batch,
        )

    overflow = False
    try:
        # multiply gradients by (# GPUs / sample_size) since DDP
        # already normalizes by the number of GPUs. Thus we get
        # (sum_of_gradients / sample_size).
        if not self.args.use_bmuf:
            multiplier = self.data_parallel_world_size
            self.optimizer.multiply_grads(multiplier / sample_size)
        elif sample_size > 0:  # BMUF needs to check sample size
            num = self.data_parallel_world_size if self._sync_stats() else 1
            self.optimizer.multiply_grads(num / sample_size)

        # clip grads
        grad_norm = self.clip_grad_norm(self.args.clip_norm)

        # check that grad norms are consistent across workers
        if not self.args.use_bmuf and self.args.distributed_wrapper != 'SlowMo':
            self._check_grad_norms(grad_norm)

        # take an optimization step
        self.optimizer.step()
    except FloatingPointError:
        # re-run the forward and backward pass with hooks attached to print
        # out where it fails
        with NanDetector(self.model):
            self.task.train_step(
                sample, self.model, self.criterion, self.optimizer,
                self.get_num_updates(), ignore_grad=False
            )
        raise
    except OverflowError as e:
        overflow = True
        logger.info("NOTE: overflow detected, " + str(e))
        grad_norm = torch.tensor(0.).cuda()
        self.zero_grad()
    except RuntimeError as e:
        if "out of memory" in str(e):
            self._log_oom(e)
            logger.error("OOM during optimization, irrecoverable")
        raise e

    # Some distributed wrappers (e.g., SlowMo) need access to the optimizer
    # after the step
    if hasattr(self.model, 'perform_additional_optimizer_actions'):
        if hasattr(self.optimizer, 'fp32_params'):
            self.model.perform_additional_optimizer_actions(
                self.optimizer.optimizer, self.optimizer.fp32_params
            )
        else:
            self.model.perform_additional_optimizer_actions(
                self.optimizer.optimizer
            )

    # ensure logging_output is defined even when the update is skipped
    # (e.g. after an overflow without the SlowMo wrapper)
    logging_output = None
    if not overflow or self.args.distributed_wrapper == 'SlowMo':
        self.set_num_updates(self.get_num_updates() + 1)

        # log stats
        logging_output = self._reduce_and_log_stats(
            logging_outputs, sample_size, grad_norm,
        )

        # clear CUDA cache to reduce memory fragmentation
        if (
            self.args.empty_cache_freq > 0
            and (
                (self.get_num_updates() + self.args.empty_cache_freq - 1)
                % self.args.empty_cache_freq
            ) == 0
            and torch.cuda.is_available()
            and not self.args.cpu
        ):
            torch.cuda.empty_cache()

    if self.args.fp16:
        metrics.log_scalar(
            "loss_scale", self.optimizer.scaler.loss_scale, priority=700, round=0
        )

    metrics.log_stop_time("train_wall")

    return logging_output
def train_step(self, samples, raise_oom=False):
    """Do forward, backward and parameter update."""
    self._set_seed()
    self.model.train()
    self.criterion.train()
    self.zero_grad()

    metrics.log_start_time("train_wall", priority=800, round=0)

    # forward and backward pass
    logging_outputs, sample_size, ooms = [], 0, 0
    for i, sample in enumerate(samples):
        sample, is_dummy_batch = self._prepare_sample(sample)

        def maybe_no_sync():
            """
            Whenever *samples* contains more than one mini-batch, we
            want to accumulate gradients locally and only call
            all-reduce in the last backwards pass.
            """
            if (
                self.data_parallel_world_size > 1
                and hasattr(self.model, "no_sync")
                and i < len(samples) - 1
            ):
                return self.model.no_sync()
            else:
                return contextlib.ExitStack()  # dummy contextmanager

        try:
            with maybe_no_sync():
                # forward and backward
                loss, sample_size_i, logging_output = self.task.train_step(
                    sample=sample,
                    model=self.model,
                    criterion=self.criterion,
                    optimizer=self.optimizer,
                    update_num=self.get_num_updates(),
                    ignore_grad=is_dummy_batch,
                )
                del loss

            logging_outputs.append(logging_output)
            sample_size += sample_size_i

            # emptying the CUDA cache after the first step can
            # reduce the chance of OOM
            if self.cuda and self.get_num_updates() == 0:
                torch.cuda.empty_cache()
        except RuntimeError as e:
            if "out of memory" in str(e):
                self._log_oom(e)
                if raise_oom:
                    raise e
                logger.warning(
                    "attempting to recover from OOM in forward/backward pass"
                )
                ooms += 1
                self.zero_grad()
                if self.cuda:
                    torch.cuda.empty_cache()
                if self.cfg.distributed_training.distributed_world_size == 1:
                    return None
            else:
                raise e

        if self.tpu and i < len(samples) - 1:
            # tpu-comment: every XLA operation before marking step is
            # appended to the IR graph, and processing too many batches
            # before marking step can lead to OOM errors.
            # To handle gradient accumulation use case, we explicitly
            # mark step here for every forward pass without a backward pass
            import torch_xla.core.xla_model as xm
            xm.mark_step()

    if is_dummy_batch:
        if torch.is_tensor(sample_size):
            sample_size.zero_()
        else:
            sample_size *= 0.0

    if torch.is_tensor(sample_size):
        sample_size = sample_size.float()
    else:
        sample_size = float(sample_size)

    # gather logging outputs from all replicas
    if self._sync_stats():
        train_time = self._local_cumulative_training_time()
        logging_outputs, (
            sample_size,
            ooms,
            total_train_time,
        ) = self._aggregate_logging_outputs(
            logging_outputs, sample_size, ooms, train_time, ignore=is_dummy_batch,
        )
        self._cumulative_training_time = (
            total_train_time / self.data_parallel_world_size
        )

    overflow = False
    try:
        with torch.autograd.profiler.record_function("reduce-grads"):
            self.optimizer.all_reduce_grads(self.model)
            if utils.has_parameters(self.criterion):
                self.optimizer.all_reduce_grads(self.criterion)

        with torch.autograd.profiler.record_function("multiply-grads"):
            # multiply gradients by (data_parallel_size / sample_size) since
            # DDP already normalizes by the number of data parallel workers.
            # Thus we get (sum_of_gradients / sample_size) at the end.
            if not self.cfg.optimization.use_bmuf:
                self.optimizer.multiply_grads(
                    self.data_parallel_world_size / sample_size
                )
            elif sample_size > 0:  # BMUF needs to check sample size
                num = self.data_parallel_world_size if self._sync_stats() else 1
                self.optimizer.multiply_grads(num / sample_size)

        with torch.autograd.profiler.record_function("clip-grads"):
            # clip grads
            grad_norm = self.clip_grad_norm(self.cfg.optimization.clip_norm)

        # check that grad norms are consistent across workers
        # on tpu check tensor is slow
        if not self.tpu:
            if (
                not self.cfg.optimization.use_bmuf
                and self.cfg.distributed_training.distributed_wrapper != "SlowMo"
            ):
                self._check_grad_norms(grad_norm)
            if not torch.isfinite(grad_norm).all():
                # check local gradnorm single GPU case, trigger NanDetector
                raise FloatingPointError("gradients are Nan/Inf")

        with torch.autograd.profiler.record_function("optimizer"):
            # take an optimization step
            self.task.optimizer_step(
                self.optimizer, model=self.model, update_num=self.get_num_updates()
            )
    except FloatingPointError:
        # re-run the forward and backward pass with hooks attached to print
        # out where it fails
        self.zero_grad()
        with NanDetector(self.get_model()):
            for _, sample in enumerate(samples):
                sample, _ = self._prepare_sample(sample)
                self.task.train_step(
                    sample,
                    self.model,
                    self.criterion,
                    self.optimizer,
                    self.get_num_updates(),
                    ignore_grad=False,
                )
        raise
    except OverflowError as e:
        overflow = True
        logger.info("NOTE: overflow detected, " + str(e))
        grad_norm = torch.tensor(0.0).cuda()
        self.zero_grad()
    except RuntimeError as e:
        if "out of memory" in str(e):
            self._log_oom(e)
            logger.error("OOM during optimization, irrecoverable")
        raise e

    # Some distributed wrappers (e.g., SlowMo) need access to the optimizer
    # after the step
    if hasattr(self.model, "perform_additional_optimizer_actions"):
        if hasattr(self.optimizer, "fp32_params"):
            self.model.perform_additional_optimizer_actions(
                self.optimizer.optimizer, self.optimizer.fp32_params
            )
        else:
            self.model.perform_additional_optimizer_actions(
                self.optimizer.optimizer
            )

    logging_output = None
    if (
        not overflow
        or self.cfg.distributed_training.distributed_wrapper == "SlowMo"
    ):
        self.set_num_updates(self.get_num_updates() + 1)

        if self.tpu:
            # mark step on TPUs
            import torch_xla.core.xla_model as xm
            xm.mark_step()

            # only log stats every log_interval steps
            # this causes wps to be misreported when log_interval > 1
            logging_output = {}
            if self.get_num_updates() % self.cfg.common.log_interval == 0:
                # log memory usage
                mem_info = xm.get_memory_info(self.device)
                gb_free = mem_info["kb_free"] / 1024 / 1024
                gb_total = mem_info["kb_total"] / 1024 / 1024
                metrics.log_scalar(
                    "gb_free", gb_free, priority=1500, round=1, weight=0,
                )
                metrics.log_scalar(
                    "gb_total", gb_total, priority=1600, round=1, weight=0,
                )

                logging_output = self._reduce_and_log_stats(
                    logging_outputs, sample_size, grad_norm,
                )

            # log whenever there's an XLA compilation, since these
            # slow down training and may indicate opportunities for
            # optimization
            self._check_xla_compilation()
        else:
            # log stats
            logging_output = self._reduce_and_log_stats(
                logging_outputs, sample_size, grad_norm,
            )

        # clear CUDA cache to reduce memory fragmentation
        if (
            self.cuda
            and self.cfg.common.empty_cache_freq > 0
            and (
                (self.get_num_updates() + self.cfg.common.empty_cache_freq - 1)
                % self.cfg.common.empty_cache_freq
            ) == 0
        ):
            torch.cuda.empty_cache()

    if self.cfg.common.fp16:
        metrics.log_scalar(
            "loss_scale",
            self.optimizer.scaler.loss_scale,
            priority=700,
            round=4,
            weight=0,
        )

    metrics.log_stop_time("train_wall")
    return logging_output
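# Illustrative sketch (not part of the trainer above): multiply_grads()
# compensates for DistributedDataParallel averaging gradients over the
# data-parallel workers. Multiplying by (world_size / sample_size) turns the
# per-worker average back into (sum_of_gradients / sample_size). The numbers
# below are made up purely to show the arithmetic.
world_size = 4                                  # data-parallel workers
per_worker_grad_sums = [2.0, 4.0, 6.0, 8.0]     # local gradient sums
sample_size = 8.0                               # total tokens in the update

ddp_averaged = sum(per_worker_grad_sums) / world_size      # what DDP leaves behind
rescaled = ddp_averaged * (world_size / sample_size)        # after multiply_grads
assert rescaled == sum(per_worker_grad_sums) / sample_size  # == 2.5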