def train_step(self, samples, raise_oom=False):
    """Do forward, backward and parameter update."""
    if self._dummy_batch == "DUMMY":
        self._dummy_batch = samples[0]

    self._set_seed()
    self.model.train()
    self.criterion.train()
    self.zero_grad()

    metrics.log_start_time("train_wall", priority=800, round=0)

    # forward and backward pass
    logging_outputs, sample_size, ooms = [], 0, 0
    for i, sample in enumerate(samples):
        sample = self._prepare_sample(sample)
        if sample is None:
            # when sample is None, run forward/backward on a dummy batch
            # and ignore the resulting gradients
            sample = self._prepare_sample(self._dummy_batch)
            is_dummy_batch = True
        else:
            is_dummy_batch = False

        def maybe_no_sync():
            """
            Whenever *samples* contains more than one mini-batch, we
            want to accumulate gradients locally and only call
            all-reduce in the last backwards pass.
            """
            if (
                self.args.distributed_world_size > 1
                and hasattr(self.model, "no_sync")
                and i < len(samples) - 1
            ):
                return self.model.no_sync()
            else:
                return contextlib.ExitStack()  # dummy contextmanager

        try:
            with maybe_no_sync():
                # forward and backward
                loss, sample_size_i, logging_output = self.task.train_step(
                    sample=sample,
                    model=self.model,
                    criterion=self.criterion,
                    optimizer=self.optimizer,
                    update_num=self.get_num_updates(),
                    ignore_grad=is_dummy_batch,
                )
                del loss

            logging_outputs.append(logging_output)
            if not is_dummy_batch:
                sample_size += sample_size_i

            # emptying the CUDA cache after the first step can
            # reduce the chance of OOM
            if self.cuda and self.get_num_updates() == 0:
                torch.cuda.empty_cache()
        except RuntimeError as e:
            if "out of memory" in str(e):
                self._log_oom(e)
                if raise_oom:
                    raise e
                logger.warning(
                    "attempting to recover from OOM in forward/backward pass"
                )
                ooms += 1
                self.zero_grad()
            else:
                raise e

    # gather logging outputs from all replicas
    if self._sync_stats():
        logging_outputs, (sample_size, ooms) = self._aggregate_logging_outputs(
            logging_outputs, sample_size, ooms, ignore=is_dummy_batch,
        )

    metrics.log_scalar("oom", ooms, len(samples), priority=600, round=3)
    if ooms == self.args.distributed_world_size * len(samples):
        logger.warning("OOM in all workers, skipping update")
        self.zero_grad()
        return None

    try:
        # normalize grads by sample size
        if sample_size > 0:
            if self._sync_stats():
                # multiply gradients by (# GPUs / sample_size) since DDP
                # already normalizes by the number of GPUs. Thus we get
                # (sum_of_gradients / sample_size).
                self.optimizer.multiply_grads(
                    self.args.distributed_world_size / sample_size
                )
            else:
                self.optimizer.multiply_grads(1 / sample_size)

        # clip grads
        grad_norm = self.optimizer.clip_grad_norm(self.args.clip_norm)

        # check that grad norms are consistent across workers
        if not self.args.use_bmuf:
            self._check_grad_norms(grad_norm)

        # take an optimization step
        self.optimizer.step()
        self.set_num_updates(self.get_num_updates() + 1)

        # log stats
        logging_output = self._reduce_and_log_stats(logging_outputs, sample_size)
        metrics.log_speed("ups", 1., ignore_first=10, priority=100, round=2)
        metrics.log_scalar("gnorm", utils.item(grad_norm), priority=400, round=3)
        metrics.log_scalar(
            "clip",
            100 if grad_norm > self.args.clip_norm > 0 else 0,
            priority=500,
            round=1,
        )

        # clear CUDA cache to reduce memory fragmentation
        if (
            self.args.empty_cache_freq > 0
            and (
                (self.get_num_updates() + self.args.empty_cache_freq - 1)
                % self.args.empty_cache_freq
            ) == 0
            and torch.cuda.is_available()
            and not self.args.cpu
        ):
            torch.cuda.empty_cache()
    except FloatingPointError:
        # re-run the forward and backward pass with hooks attached to print
        # out where it fails
        with NanDetector(self.model):
            self.task.train_step(
                sample,
                self.model,
                self.criterion,
                self.optimizer,
                self.get_num_updates(),
                ignore_grad=False,
            )
        raise
    except OverflowError as e:
        logger.info("NOTE: overflow detected, " + str(e))
        self.zero_grad()
        logging_output = None
    except RuntimeError as e:
        if "out of memory" in str(e):
            self._log_oom(e)
            logger.error("OOM during optimization, irrecoverable")
        raise e

    if self.args.fp16:
        metrics.log_scalar(
            "loss_scale", self.optimizer.scaler.loss_scale, priority=700, round=0
        )

    metrics.log_stop_time("train_wall")

    return logging_output
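
# --- Illustrative sketch (not part of the trainer above) ---
# The "multiply gradients by (# GPUs / sample_size)" comment relies on the fact
# that DDP's all-reduce *averages* gradients across workers. A tiny worked
# example with hypothetical numbers: if the workers' gradients sum to g, DDP
# leaves g / world_size in .grad, so scaling by (world_size / sample_size)
# recovers g / sample_size, i.e. a per-token average.
def rescale_after_ddp(ddp_averaged_grad, world_size, sample_size):
    # ddp_averaged_grad == sum_of_worker_grads / world_size
    return ddp_averaged_grad * (world_size / sample_size)

# Example: 8 workers, 4096 target tokens in the update, and a made-up summed
# gradient value of 512.0 for one parameter element.
summed_grad = 512.0
ddp_grad = summed_grad / 8                      # what DDP leaves in .grad
per_token_grad = rescale_after_ddp(ddp_grad, 8, 4096)
assert abs(per_token_grad - summed_grad / 4096) < 1e-12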
def train_step(self, samples, raise_oom=False):
    """Do forward, backward and parameter update."""
    self._set_seed()
    self.model.train()
    self.criterion.train()
    self.zero_grad()

    metrics.log_start_time("train_wall", priority=800, round=0)

    # forward and backward pass
    logging_outputs, sample_size, ooms = [], 0, 0
    for i, sample in enumerate(samples):
        sample = self._prepare_sample(sample)
        if sample is None:
            # when sample is None, run forward/backward on a dummy batch
            # and ignore the resulting gradients
            sample = self._prepare_sample(self._dummy_batch)
            is_dummy_batch = True
        else:
            if self._dummy_batch == "DUMMY":
                self._dummy_batch = sample
            is_dummy_batch = False

        def maybe_no_sync():
            """
            Whenever *samples* contains more than one mini-batch, we
            want to accumulate gradients locally and only call
            all-reduce in the last backwards pass.
            """
            if (
                self.data_parallel_world_size > 1
                and hasattr(self.model, "no_sync")
                and i < len(samples) - 1
            ):
                return self.model.no_sync()
            else:
                return contextlib.ExitStack()  # dummy contextmanager

        try:
            with maybe_no_sync():
                # forward and backward
                loss, sample_size_i, logging_output = self.task.train_step(
                    sample=sample,
                    model=self.model,
                    criterion=self.criterion,
                    optimizer=self.optimizer,
                    update_num=self.get_num_updates(),
                    ignore_grad=is_dummy_batch,
                )
                del loss

            logging_outputs.append(logging_output)
            sample_size += sample_size_i

            # emptying the CUDA cache after the first step can
            # reduce the chance of OOM
            if self.cuda and self.get_num_updates() == 0:
                torch.cuda.empty_cache()
        except RuntimeError as e:
            if "out of memory" in str(e):
                self._log_oom(e)
                if raise_oom:
                    raise e
                logger.warning(
                    "attempting to recover from OOM in forward/backward pass"
                )
                ooms += 1
                self.zero_grad()
                if self.cuda:
                    torch.cuda.empty_cache()
                if self.cfg.distributed_training.distributed_world_size == 1:
                    return None
            else:
                raise e

        if self.tpu and i < len(samples) - 1:
            # tpu-comment: every XLA operation before marking step is
            # appended to the IR graph, and processing too many batches
            # before marking step can lead to OOM errors.
            # To handle gradient accumulation use case, we explicitly
            # mark step here for every forward pass without a backward pass
            import torch_xla.core.xla_model as xm
            xm.mark_step()

    if is_dummy_batch:
        if torch.is_tensor(sample_size):
            sample_size.zero_()
        else:
            sample_size *= 0.0

    if torch.is_tensor(sample_size):
        sample_size = sample_size.float()
    else:
        sample_size = float(sample_size)

    # gather logging outputs from all replicas
    if self._sync_stats():
        train_time = self._local_cumulative_training_time()
        logging_outputs, (
            sample_size,
            ooms,
            total_train_time,
        ) = self._aggregate_logging_outputs(
            logging_outputs, sample_size, ooms, train_time, ignore=is_dummy_batch,
        )
        self._cumulative_training_time = (
            total_train_time / self.data_parallel_world_size
        )

    overflow = False
    try:
        with torch.autograd.profiler.record_function("reduce-grads"):
            self.optimizer.all_reduce_grads(self.model)
            if utils.has_parameters(self.criterion):
                self.optimizer.all_reduce_grads(self.criterion)

        with torch.autograd.profiler.record_function("multiply-grads"):
            # multiply gradients by (data_parallel_size / sample_size) since
            # DDP already normalizes by the number of data parallel workers.
            # Thus we get (sum_of_gradients / sample_size) at the end.
            if not self.cfg.optimization.use_bmuf:
                self.optimizer.multiply_grads(
                    self.data_parallel_world_size / sample_size
                )
            elif sample_size > 0:  # BMUF needs to check sample size
                num = self.data_parallel_world_size if self._sync_stats() else 1
                self.optimizer.multiply_grads(num / sample_size)

        with torch.autograd.profiler.record_function("clip-grads"):
            # clip grads
            grad_norm = self.clip_grad_norm(self.cfg.optimization.clip_norm)

        # check that grad norms are consistent across workers
        # on tpu check tensor is slow
        if not self.tpu:
            if (
                not self.cfg.optimization.use_bmuf
                and self.cfg.distributed_training.distributed_wrapper != "SlowMo"
            ):
                self._check_grad_norms(grad_norm)
            if not torch.isfinite(grad_norm).all():
                # check local gradnorm single GPU case, trigger NanDetector
                raise FloatingPointError("gradients are Nan/Inf")

        with torch.autograd.profiler.record_function("optimizer"):
            # take an optimization step
            self.optimizer.step()
    except FloatingPointError:
        # re-run the forward and backward pass with hooks attached to print
        # out where it fails
        with NanDetector(self.get_model()):
            self.task.train_step(
                sample,
                self.model,
                self.criterion,
                self.optimizer,
                self.get_num_updates(),
                ignore_grad=False,
            )
        raise
    except OverflowError as e:
        overflow = True
        logger.info("NOTE: overflow detected, " + str(e))
        grad_norm = torch.tensor(0.0).cuda()
        self.zero_grad()
    except RuntimeError as e:
        if "out of memory" in str(e):
            self._log_oom(e)
            logger.error("OOM during optimization, irrecoverable")
        raise e

    # Some distributed wrappers (e.g., SlowMo) need access to the optimizer
    # after the step
    if hasattr(self.model, "perform_additional_optimizer_actions"):
        if hasattr(self.optimizer, "fp32_params"):
            self.model.perform_additional_optimizer_actions(
                self.optimizer.optimizer, self.optimizer.fp32_params
            )
        else:
            self.model.perform_additional_optimizer_actions(
                self.optimizer.optimizer
            )

    logging_output = None
    if not overflow or self.cfg.distributed_training.distributed_wrapper == "SlowMo":
        self.set_num_updates(self.get_num_updates() + 1)

        if self.tpu:
            # mark step on TPUs
            import torch_xla.core.xla_model as xm
            xm.mark_step()

            # only log stats every log_interval steps
            # this causes wps to be misreported when log_interval > 1
            logging_output = {}
            if self.get_num_updates() % self.cfg.common.log_interval == 0:
                # log memory usage
                mem_info = xm.get_memory_info(self.device)
                gb_free = mem_info["kb_free"] / 1024 / 1024
                gb_total = mem_info["kb_total"] / 1024 / 1024
                metrics.log_scalar(
                    "gb_free", gb_free, priority=1500, round=1, weight=0,
                )
                metrics.log_scalar(
                    "gb_total", gb_total, priority=1600, round=1, weight=0,
                )

                logging_output = self._reduce_and_log_stats(
                    logging_outputs, sample_size, grad_norm,
                )

            # log whenever there's an XLA compilation, since these
            # slow down training and may indicate opportunities for
            # optimization
            self._check_xla_compilation()
        else:
            # log stats
            logging_output = self._reduce_and_log_stats(
                logging_outputs, sample_size, grad_norm,
            )

            # clear CUDA cache to reduce memory fragmentation
            if (
                self.cuda
                and self.cfg.common.empty_cache_freq > 0
                and (
                    (self.get_num_updates() + self.cfg.common.empty_cache_freq - 1)
                    % self.cfg.common.empty_cache_freq
                ) == 0
            ):
                torch.cuda.empty_cache()

    if self.cfg.common.fp16:
        metrics.log_scalar(
            "loss_scale",
            self.optimizer.scaler.loss_scale,
            priority=700,
            round=4,
            weight=0,
        )

    metrics.log_stop_time("train_wall")

    return logging_output
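
# --- Illustrative sketch (not part of the trainer above) ---
# maybe_no_sync() above implements plain DDP gradient accumulation: skip the
# all-reduce on every micro-batch except the last, so gradients are summed
# locally and synchronized once per update. A minimal, generic PyTorch version
# of that pattern (the names and loss function here are illustrative, not
# fairseq APIs):
import contextlib

def accumulate_and_step(ddp_model, optimizer, loss_fn, micro_batches):
    optimizer.zero_grad()
    for i, (inputs, targets) in enumerate(micro_batches):
        # DistributedDataParallel.no_sync() suppresses the gradient all-reduce
        # for backward passes run inside the context.
        ctx = (
            ddp_model.no_sync()
            if i < len(micro_batches) - 1
            else contextlib.nullcontext()
        )
        with ctx:
            loss_fn(ddp_model(inputs), targets).backward()
    optimizer.step()  # gradients were all-reduced only on the last backward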
def train_step(self, samples, raise_oom=False):
    """Do forward, backward and parameter update."""
    if self._dummy_batch == "DUMMY":
        self._dummy_batch = samples[0]

    self._set_seed()
    self.model.train()
    self.criterion.train()
    self.zero_grad()

    metrics.log_start_time("train_wall", priority=800, round=0)

    # forward and backward pass
    logging_outputs, sample_size, ooms = [], 0, 0
    for i, sample in enumerate(samples):
        sample = self._prepare_sample(sample)
        if sample is None:
            # when sample is None, run forward/backward on a dummy batch
            # and ignore the resulting gradients
            sample = self._prepare_sample(self._dummy_batch)
            is_dummy_batch = True
        else:
            is_dummy_batch = False

        def maybe_no_sync():
            """
            Whenever *samples* contains more than one mini-batch, we
            want to accumulate gradients locally and only call
            all-reduce in the last backwards pass.
            """
            if (
                self.data_parallel_world_size > 1
                and hasattr(self.model, "no_sync")
                and i < len(samples) - 1
            ):
                return self.model.no_sync()
            else:
                return contextlib.ExitStack()  # dummy contextmanager

        try:
            with maybe_no_sync():
                # forward and backward
                loss, sample_size_i, logging_output = self.task.train_step(
                    sample=sample,
                    model=self.model,
                    criterion=self.criterion,
                    optimizer=self.optimizer,
                    update_num=self.get_num_updates(),
                    ignore_grad=is_dummy_batch,
                )
                del loss

            logging_outputs.append(logging_output)
            sample_size += sample_size_i

            # emptying the CUDA cache after the first step can
            # reduce the chance of OOM
            if self.cuda and self.get_num_updates() == 0:
                torch.cuda.empty_cache()
        except RuntimeError as e:
            if "out of memory" in str(e):
                self._log_oom(e)
                if raise_oom:
                    raise e
                logger.warning(
                    "attempting to recover from OOM in forward/backward pass"
                )
                ooms += 1
                self.zero_grad()
            else:
                raise e

        if self.tpu and i < len(samples) - 1:
            # tpu-comment: every XLA operation before marking step is
            # appended to the IR graph, and processing too many batches
            # before marking step can lead to OOM errors.
            # To handle gradient accumulation use case, we explicitly
            # mark step here for every forward pass without a backward pass
            import torch_xla.core.xla_model as xm
            xm.mark_step()

    if is_dummy_batch:
        if torch.is_tensor(sample_size):
            sample_size.zero_()
        else:
            sample_size *= 0.

    if torch.is_tensor(sample_size):
        sample_size = sample_size.float()
    else:
        sample_size = float(sample_size)

    # gather logging outputs from all replicas
    if self._sync_stats():
        logging_outputs, (sample_size, ooms) = self._aggregate_logging_outputs(
            logging_outputs, sample_size, ooms, ignore=is_dummy_batch,
        )

    overflow = False
    try:
        if self.tpu and self.data_parallel_world_size > 1:
            import torch_xla.core.xla_model as xm
            gradients = xm._fetch_gradients(self.optimizer.optimizer)
            xm.all_reduce(
                'sum', gradients, scale=1.0 / self.data_parallel_world_size
            )

        # multiply gradients by (# GPUs / sample_size) since DDP
        # already normalizes by the number of GPUs. Thus we get
        # (sum_of_gradients / sample_size).
        if not self.args.use_bmuf:
            self.optimizer.multiply_grads(
                self.data_parallel_world_size / sample_size
            )
        elif sample_size > 0:  # BMUF needs to check sample size
            num = self.data_parallel_world_size if self._sync_stats() else 1
            self.optimizer.multiply_grads(num / sample_size)

        # clip grads
        grad_norm = self.clip_grad_norm(self.args.clip_norm)

        # check that grad norms are consistent across workers
        if (
            not self.args.use_bmuf
            and self.args.distributed_wrapper != 'SlowMo'
            and not self.tpu
        ):
            self._check_grad_norms(grad_norm)

        # take an optimization step
        self.optimizer.step()
    except FloatingPointError:
        # re-run the forward and backward pass with hooks attached to print
        # out where it fails
        with NanDetector(self.model):
            self.task.train_step(
                sample,
                self.model,
                self.criterion,
                self.optimizer,
                self.get_num_updates(),
                ignore_grad=False,
            )
        raise
    except OverflowError as e:
        overflow = True
        logger.info("NOTE: overflow detected, " + str(e))
        grad_norm = torch.tensor(0.).cuda()
        self.zero_grad()
    except RuntimeError as e:
        if "out of memory" in str(e):
            self._log_oom(e)
            logger.error("OOM during optimization, irrecoverable")
        raise e

    # Some distributed wrappers (e.g., SlowMo) need access to the optimizer
    # after the step
    if hasattr(self.model, 'perform_additional_optimizer_actions'):
        if hasattr(self.optimizer, 'fp32_params'):
            self.model.perform_additional_optimizer_actions(
                self.optimizer.optimizer, self.optimizer.fp32_params
            )
        else:
            self.model.perform_additional_optimizer_actions(
                self.optimizer.optimizer
            )

    if not overflow or self.args.distributed_wrapper == 'SlowMo':
        self.set_num_updates(self.get_num_updates() + 1)

        if self.tpu:
            # mark step on TPUs
            import torch_xla.core.xla_model as xm
            xm.mark_step()

            # only log stats every log_interval steps
            # this causes wps to be misreported when log_interval > 1
            logging_output = {}
            if self.get_num_updates() % self.args.log_interval == 0:
                logging_output = self._reduce_and_log_stats(
                    logging_outputs, sample_size, grad_norm,
                )

            # log whenever there's an XLA compilation, since these
            # slow down training and may indicate opportunities for
            # optimization
            self._check_xla_compilation()
        else:
            # log stats
            logging_output = self._reduce_and_log_stats(
                logging_outputs, sample_size, grad_norm,
            )

            # clear CUDA cache to reduce memory fragmentation
            if (
                self.cuda
                and self.args.empty_cache_freq > 0
                and (
                    (self.get_num_updates() + self.args.empty_cache_freq - 1)
                    % self.args.empty_cache_freq
                ) == 0
            ):
                torch.cuda.empty_cache()

    if self.args.fp16:
        metrics.log_scalar(
            "loss_scale", self.optimizer.scaler.loss_scale, priority=700, round=0
        )

    metrics.log_stop_time("train_wall")

    return logging_output
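
# --- Illustrative sketch (not part of the trainer above) ---
# The tpu-comment above boils down to: XLA traces operations lazily, so without
# an explicit xm.mark_step() between accumulation micro-batches the pending
# graph keeps growing and can OOM. A minimal sketch of that idea, assuming a
# working torch_xla setup (the model and batches are placeholders):
import torch_xla.core.xla_model as xm

def forward_only_micro_batches(model, micro_batches):
    for i, batch in enumerate(micro_batches):
        model(batch)
        if i < len(micro_batches) - 1:
            # materialize the traced XLA graph before tracing the next batch
            xm.mark_step()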
def train_step(self, samples, raise_oom=False):
    """Do forward, backward and parameter update."""
    if self._dummy_batch == "DUMMY":
        self._dummy_batch = samples[0]

    self._set_seed()
    self.model.train()
    self.criterion.train()
    self.zero_grad()

    metrics.log_start_time("train_wall", priority=800, round=0)

    # forward and backward pass
    logging_outputs, sample_size, ooms = [], 0, 0

    # added by Junxian
    lambda_stats_sum = 0
    nsentences = 0

    for i, sample in enumerate(samples):
        sample = self._prepare_sample(sample)
        if sample is None:
            # when sample is None, run forward/backward on a dummy batch
            # and ignore the resulting gradients
            sample = self._prepare_sample(self._dummy_batch)
            is_dummy_batch = True
        else:
            is_dummy_batch = False

        def maybe_no_sync():
            """
            Whenever *samples* contains more than one mini-batch, we
            want to accumulate gradients locally and only call
            all-reduce in the last backwards pass.
            """
            if (
                self.args.distributed_world_size > 1
                and hasattr(self.model, "no_sync")
                and i < len(samples) - 1
            ):
                return self.model.no_sync()
            else:
                return contextlib.ExitStack()  # dummy contextmanager

        try:
            with maybe_no_sync():
                if is_dummy_batch:
                    print("dummy batch!")

                # try:
                # forward and backward
                loss, sample_size_i, logging_output = self.task.train_step(
                    sample=sample,
                    model=self.model,
                    criterion=self.criterion,
                    optimizer=self.optimizer,
                    update_num=self.get_num_updates(),
                    ignore_grad=is_dummy_batch,
                )
                del loss

            logging_outputs.append(logging_output)
            sample_size += sample_size_i
            nsentences += logging_output['nsentences']
            # except:
            #     pass

            # Added by Junxian: manually update lambda
            # (support distributed training)
            with torch.no_grad():
                lambda_stats_sum_i = self.task.collect_lambda_stats(
                    self.model, sample
                )
                lambda_stats_sum = lambda_stats_sum + lambda_stats_sum_i

            # emptying the CUDA cache after the first step can
            # reduce the chance of OOM
            if self.cuda and self.get_num_updates() == 0:
                torch.cuda.empty_cache()
        except RuntimeError as e:
            if "out of memory" in str(e):
                self._log_oom(e)
                if raise_oom:
                    raise e
                logger.warning(
                    "attempting to recover from OOM in forward/backward pass"
                )
                ooms += 1
                self.zero_grad()
            else:
                raise e

    if is_dummy_batch:
        if torch.is_tensor(sample_size):
            sample_size.zero_()
        else:
            sample_size *= 0.0

    if torch.is_tensor(sample_size):
        sample_size = sample_size.float()
    else:
        sample_size = float(sample_size)

    # gather logging outputs from all replicas
    if self._sync_stats():
        logging_outputs, (sample_size, ooms) = self._aggregate_logging_outputs(
            logging_outputs, sample_size, ooms, ignore=is_dummy_batch,
        )

    try:
        # multiply gradients by (# GPUs / sample_size) since DDP
        # already normalizes by the number of GPUs. Thus we get
        # (sum_of_gradients / sample_size).
        if not self.args.use_bmuf:
            self.optimizer.multiply_grads(
                self.args.distributed_world_size / sample_size
            )
        elif sample_size > 0:  # BMUF needs to check sample size
            num = self.args.distributed_world_size if self._sync_stats() else 1
            self.optimizer.multiply_grads(num / sample_size)

        # clip grads
        grad_norm = self.optimizer.clip_grad_norm(self.args.clip_norm)

        # Added by Junxian: manually update lambda (support distributed training)
        with torch.no_grad():
            if self._sync_stats():
                torch.distributed.all_reduce(
                    lambda_stats_sum, op=torch.distributed.ReduceOp.SUM
                )
                # print('nsentences {} per gpu'.format(nsentences))
                nsentences_t = torch.tensor(nsentences, device=self.device)
                torch.distributed.all_reduce(
                    nsentences_t, op=torch.distributed.ReduceOp.SUM
                )
                nsentences = nsentences_t.item()
                # print('nsentences total {}'.format(nsentences))

            # TODO(junxian): is_dummy_batch might be different across GPUs and
            # would potentially cause lambda mismatch among GPUs
            self.task.distributed_update_lambda(
                model=self.model,
                lambda_stats_sum=lambda_stats_sum,
                nsentences=nsentences,
                update_num=self.get_num_updates(),
                ignore_grad=is_dummy_batch,
            )

        # check that grad norms are consistent across workers
        if not self.args.use_bmuf:
            self._check_grad_norms(grad_norm)
            # added by Junxian to check some manually updated params
            # self._check_grad_norms(
            #     torch.tensor([self.model._lambda_t], device=self.device))
            self._check_grad_norms(self.model.get_lambda().max())

        # take an optimization step
        self.optimizer.step()
        self.set_num_updates(self.get_num_updates() + 1)

        # log stats
        logging_output = self._reduce_and_log_stats(
            logging_outputs, sample_size, grad_norm,
        )

        # clear CUDA cache to reduce memory fragmentation
        if (
            self.args.empty_cache_freq > 0
            and (
                (self.get_num_updates() + self.args.empty_cache_freq - 1)
                % self.args.empty_cache_freq
            ) == 0
            and torch.cuda.is_available()
            and not self.args.cpu
        ):
            torch.cuda.empty_cache()
    except FloatingPointError:
        # re-run the forward and backward pass with hooks attached to print
        # out where it fails
        with NanDetector(self.model):
            self.task.train_step(
                sample,
                self.model,
                self.criterion,
                self.optimizer,
                self.get_num_updates(),
                ignore_grad=False,
            )
        raise
    except OverflowError as e:
        logger.info("NOTE: overflow detected, " + str(e))
        self.zero_grad()
        logging_output = None
    except RuntimeError as e:
        if "out of memory" in str(e):
            self._log_oom(e)
            logger.error("OOM during optimization, irrecoverable")
        raise e

    if self.args.fp16:
        metrics.log_scalar(
            "loss_scale", self.optimizer.scaler.loss_scale, priority=700, round=0
        )

    metrics.log_stop_time("train_wall")

    return logging_output
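
# --- Illustrative sketch (not part of the trainer above) ---
# The manual lambda update above sums a per-GPU statistic across workers with
# torch.distributed.all_reduce before applying it. A minimal, generic version
# of that aggregation step (the statistic tensor here is a placeholder):
import torch
import torch.distributed as dist

def sum_stat_across_workers(stat: torch.Tensor) -> torch.Tensor:
    # In-place SUM reduction over all ranks in the default process group;
    # a no-op when distributed training is not initialized.
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(stat, op=dist.ReduceOp.SUM)
    return stat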
def train_step(self, samples, raise_oom=False):
    """Do forward, backward and parameter update."""
    if self._dummy_batch == "DUMMY":
        self._dummy_batch = samples[0]

    self._set_seed()
    self.model.train()
    self.criterion.train()
    self.zero_grad()

    metrics.log_start_time("train_wall", priority=800, round=0)

    # forward and backward pass
    logging_outputs, sample_size, ooms = [], 0, 0
    for i, sample in enumerate(samples):
        sample = self._prepare_sample(sample)
        if sample is None:
            # when sample is None, run forward/backward on a dummy batch
            # and ignore the resulting gradients
            sample = self._prepare_sample(self._dummy_batch)
            is_dummy_batch = True
        else:
            is_dummy_batch = False

        def maybe_no_sync():
            """
            Whenever *samples* contains more than one mini-batch, we
            want to accumulate gradients locally and only call
            all-reduce in the last backwards pass.
            """
            if (
                self.data_parallel_world_size > 1
                and hasattr(self.model, "no_sync")
                and i < len(samples) - 1
            ):
                return self.model.no_sync()
            else:
                return contextlib.ExitStack()  # dummy contextmanager

        try:
            with maybe_no_sync():
                # forward and backward
                loss, sample_size_i, logging_output = self.task.train_step(
                    sample=sample,
                    model=self.model,
                    criterion=self.criterion,
                    optimizer=self.optimizer,
                    update_num=self.get_num_updates(),
                    ignore_grad=is_dummy_batch,
                )
                del loss

            logging_outputs.append(logging_output)
            sample_size += sample_size_i

            # emptying the CUDA cache after the first step can
            # reduce the chance of OOM
            if self.cuda and self.get_num_updates() == 0:
                torch.cuda.empty_cache()
        except RuntimeError as e:
            if "out of memory" in str(e):
                self._log_oom(e)
                if raise_oom:
                    raise e
                logger.warning(
                    "attempting to recover from OOM in forward/backward pass"
                )
                ooms += 1
                self.zero_grad()
            else:
                raise e

    if torch.is_tensor(sample_size):
        sample_size = sample_size.float()
    else:
        sample_size = float(sample_size)
    if is_dummy_batch:
        sample_size *= 0.  # multiply by 0 to preserve device

    # gather logging outputs from all replicas
    if self._sync_stats():
        logging_outputs, (sample_size, ooms) = self._aggregate_logging_outputs(
            logging_outputs, sample_size, ooms, ignore=is_dummy_batch,
        )

    overflow = False
    try:
        # multiply gradients by (# GPUs / sample_size) since DDP
        # already normalizes by the number of GPUs. Thus we get
        # (sum_of_gradients / sample_size).
        if not self.args.use_bmuf:
            multiplier = self.data_parallel_world_size
            self.optimizer.multiply_grads(multiplier / sample_size)
        elif sample_size > 0:  # BMUF needs to check sample size
            num = self.data_parallel_world_size if self._sync_stats() else 1
            self.optimizer.multiply_grads(num / sample_size)

        # clip grads
        grad_norm = self.clip_grad_norm(self.args.clip_norm)

        # check that grad norms are consistent across workers
        if not self.args.use_bmuf and self.args.distributed_wrapper != 'SlowMo':
            self._check_grad_norms(grad_norm)

        # take an optimization step
        self.optimizer.step()
    except FloatingPointError:
        # re-run the forward and backward pass with hooks attached to print
        # out where it fails
        with NanDetector(self.model):
            self.task.train_step(
                sample,
                self.model,
                self.criterion,
                self.optimizer,
                self.get_num_updates(),
                ignore_grad=False,
            )
        raise
    except OverflowError as e:
        overflow = True
        logger.info("NOTE: overflow detected, " + str(e))
        grad_norm = torch.tensor(0.).cuda()
        self.zero_grad()
    except RuntimeError as e:
        if "out of memory" in str(e):
            self._log_oom(e)
            logger.error("OOM during optimization, irrecoverable")
        raise e

    # Some distributed wrappers (e.g., SlowMo) need access to the optimizer
    # after the step
    if hasattr(self.model, 'perform_additional_optimizer_actions'):
        if hasattr(self.optimizer, 'fp32_params'):
            self.model.perform_additional_optimizer_actions(
                self.optimizer.optimizer, self.optimizer.fp32_params
            )
        else:
            self.model.perform_additional_optimizer_actions(
                self.optimizer.optimizer
            )

    if not overflow or self.args.distributed_wrapper == 'SlowMo':
        self.set_num_updates(self.get_num_updates() + 1)

    # unwrap the underlying optimizer to read its per-parameter state
    try:
        opt_ = self.optimizer.fp32_optimizer._optimizer
    except AttributeError:
        opt_ = self.optimizer._optimizer
    states = opt_.state[opt_.param_groups[0]['params'][0]]

    gvar = None
    adam_mom2 = None
    gvar_diff = None
    xstd = None
    ams_mom = None
    acc_ratio = None
    real_var = None
    real_var_diff = None
    if self.args.optimizer == "varscale_sgd":
        adam_mom2 = torch.mean(states['g_sq_est']).item()

    if getattr(opt_, "adaptive_lrs", None) is not None:
        lr_min, lr_max, lr_median = opt_.adaptive_lrs
    else:
        lr_min, lr_max, lr_median = None, None, None

    if getattr(opt_, "update_size", None) is not None:
        update_min, update_max, update_median = opt_.update_size
    else:
        update_min, update_max, update_median = None, None, None

    valid_ratio = getattr(opt_, "valid_ratio", None)
    ad_beta = getattr(opt_, "adaptive_beta", None)
    var_adapt = getattr(opt_, "var_adapt", None)

    # log stats
    logging_output = self._reduce_and_log_stats(
        logging_outputs,
        sample_size,
        grad_norm,
        gvar=gvar,
        adam_mom2=adam_mom2,
        gvar_diff=gvar_diff,
        xstd=xstd,
        ams_mom=ams_mom,
        acc_ratio=acc_ratio,
        real_var=real_var,
        real_var_diff=real_var_diff,
        ad_beta=ad_beta,
        lr_min=lr_min,
        lr_max=lr_max,
        lr_median=lr_median,
        update_min=update_min,
        update_max=update_max,
        update_median=update_median,
        valid_ratio=valid_ratio,
        var_adapt=var_adapt,
    )

    # clear CUDA cache to reduce memory fragmentation
    if (
        self.args.empty_cache_freq > 0
        and (
            (self.get_num_updates() + self.args.empty_cache_freq - 1)
            % self.args.empty_cache_freq
        ) == 0
        and torch.cuda.is_available()
        and not self.args.cpu
    ):
        torch.cuda.empty_cache()

    if self.args.fp16:
        metrics.log_scalar(
            "loss_scale", self.optimizer.scaler.loss_scale, priority=700, round=0
        )

    metrics.log_stop_time("train_wall")

    return logging_output
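
# --- Illustrative sketch (not part of the trainer above) ---
# The block above unwraps the optimizer and reads per-parameter state
# (e.g. 'g_sq_est' for the custom varscale_sgd optimizer). The general pattern,
# shown here with stock Adam (whose state keys differ), is simply
# torch.optim.Optimizer.state indexed by a parameter object:
import torch

param = torch.nn.Parameter(torch.randn(4))
opt = torch.optim.Adam([param])
param.sum().backward()
opt.step()
state = opt.state[opt.param_groups[0]["params"][0]]
print(sorted(state.keys()))  # e.g. ['exp_avg', 'exp_avg_sq', 'step'] for Adam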
def train_step(self, samples, raise_oom=False):
    """Do forward, backward and parameter update."""
    if self._dummy_batch == "DUMMY":
        self._dummy_batch = samples[0]

    self._set_seed()
    self.model.train()
    self.criterion.train()
    self.zero_grad()

    metrics.log_start_time("train_wall", priority=800, round=0)

    # forward and backward pass
    logging_outputs, sample_size, ooms = [], 0, 0
    for i, sample in enumerate(samples):
        sample = self._prepare_sample(sample)
        if sample is None:
            # when sample is None, run forward/backward on a dummy batch
            # and ignore the resulting gradients
            sample = self._prepare_sample(self._dummy_batch)
            is_dummy_batch = True
        else:
            is_dummy_batch = False

        def maybe_no_sync():
            """
            Whenever *samples* contains more than one mini-batch, we
            want to accumulate gradients locally and only call
            all-reduce in the last backwards pass.
            """
            if (
                self.data_parallel_world_size > 1
                and hasattr(self.model, "no_sync")
                and i < len(samples) - 1
            ):
                return self.model.no_sync()
            else:
                return contextlib.ExitStack()  # dummy contextmanager

        try:
            with maybe_no_sync():
                # forward and backward
                loss, sample_size_i, logging_output = self.task.train_step(
                    sample=sample,
                    model=self.model,
                    criterion=self.criterion,
                    optimizer=self.optimizer,
                    update_num=self.get_num_updates(),
                    ignore_grad=is_dummy_batch,
                )
                del loss

            logging_outputs.append(logging_output)
            sample_size += sample_size_i

            # emptying the CUDA cache after the first step can
            # reduce the chance of OOM
            if self.cuda and self.get_num_updates() == 0:
                torch.cuda.empty_cache()
        except RuntimeError as e:
            if "out of memory" in str(e):
                self._log_oom(e)
                if raise_oom:
                    raise e
                logger.warning(
                    "attempting to recover from OOM in forward/backward pass"
                )
                ooms += 1
                self.zero_grad()
            else:
                raise e

    if torch.is_tensor(sample_size):
        sample_size = sample_size.float()
    else:
        sample_size = float(sample_size)
    if is_dummy_batch:
        sample_size *= 0.  # multiply by 0 to preserve device

    # gather logging outputs from all replicas
    if self._sync_stats():
        logging_outputs, (sample_size, ooms) = self._aggregate_logging_outputs(
            logging_outputs, sample_size, ooms, ignore=is_dummy_batch,
        )

    try:
        # multiply gradients by (# GPUs / sample_size) since DDP
        # already normalizes by the number of GPUs. Thus we get
        # (sum_of_gradients / sample_size).
        if not self.args.use_bmuf:
            multiplier = self.data_parallel_world_size
            self.optimizer.multiply_grads(multiplier / sample_size)
        elif sample_size > 0:  # BMUF needs to check sample size
            num = self.data_parallel_world_size if self._sync_stats() else 1
            self.optimizer.multiply_grads(num / sample_size)

        """
        # simple-ln
        for name, param in self.model.named_parameters():
            if 'layer_norm' in name:
                param.grad = None
        """

        # clip grads
        grad_norm = self.clip_grad_norm(self.args.clip_norm)

        """
        for name, param in self.model.named_parameters():
            if param.grad == None or 'bias' in name:
                continue
            else:
                param.grad = param.grad.data.float() / param.grad.data.float().norm()
                # print(name)
                # print(torch.mean(torch.abs(param.grad)))
        """

        # check that grad norms are consistent across workers
        if not self.args.use_bmuf:
            self._check_grad_norms(grad_norm)

        beta3 = 0.99
        if self.get_num_updates() > 0:
            for name, param in self.model.named_parameters():
                if 'decoder.layers.0.fc2.weight' in name:
                    # layer0 = self.optimizer._optimizer.state[param]['exp_avg'].data.float().norm()
                    layer0 = param.grad.data.float().norm()
                elif 'decoder.layers.5.fc2.weight' in name:
                    # layer5 = self.optimizer._optimizer.state[param]['exp_avg'].data.float().norm()
                    layer5 = param.grad.data.float().norm()
                else:
                    pass
            current_ratio = layer0.item() / layer5.item()
            decay_ratio = (
                current_ratio * (1 - beta3 ** self.get_num_updates()) / self.ratio
            )
            # print(decay_ratio)
            if (decay_ratio > 1.5 or decay_ratio < 0.75) and self.get_num_updates() < 4000:
                self.optimizer._optimizer.decay = True
                # pass
            else:
                pass
            self.ratio = beta3 * self.ratio + (1 - beta3) * current_ratio

        # take an optimization step
        self.optimizer.step()
        self.set_num_updates(self.get_num_updates() + 1)
        self.optimizer._optimizer.decay = False

        """
        # print(self.get_num_updates())
        for name, param in self.model.named_parameters():
            if 'decoder.layers.0' in name and 'fc2' in name and 'weight' in name:
                # exp_avg = self.optimizer._optimizer.state[param]['exp_avg']
                # denom = torch.sqrt(self.optimizer._optimizer.state[param]['exp_avg_sq']) + 1e-8
                # target = (math.sqrt(1 - 0.98 ** self.get_num_updates()) / (1 - 0.9 ** self.get_num_updates())) * exp_avg / denom
                # target = exp_avg / (1 - 0.9 ** self.get_num_updates())
                # print(name)
                # print(torch.mean(torch.abs(target)))
                # print(torch.mean(torch.abs(self.optimizer._optimizer.state[param]['update_term'])))
                print(self.optimizer._optimizer.state[param]['ratio'])
        """

        if self.get_num_updates() == 1:
            for name, param in self.model.named_parameters():
                if 'decoder.layers.0.fc2.weight' in name:
                    layer0 = self.optimizer._optimizer.state[param]['exp_avg'].data.float().norm()
                elif 'decoder.layers.5.fc2.weight' in name:
                    layer5 = self.optimizer._optimizer.state[param]['exp_avg'].data.float().norm()
                else:
                    pass
            current_ratio = layer0.item() / layer5.item()
            self.ratio = beta3 * self.ratio + (1 - beta3) * current_ratio

        # visualize lr
        """
        for name, param in self.model.named_parameters():
            if 'decoder.layers.0.fc2.weight' in name:
                print(self.optimizer._optimizer.state[param]['ratio'])
                break
        """

        # log stats
        logging_output = self._reduce_and_log_stats(
            logging_outputs, sample_size, grad_norm,
        )

        # clear CUDA cache to reduce memory fragmentation
        if (
            self.args.empty_cache_freq > 0
            and (
                (self.get_num_updates() + self.args.empty_cache_freq - 1)
                % self.args.empty_cache_freq
            ) == 0
            and torch.cuda.is_available()
            and not self.args.cpu
        ):
            torch.cuda.empty_cache()
    except FloatingPointError:
        # re-run the forward and backward pass with hooks attached to print
        # out where it fails
        with NanDetector(self.model):
            self.task.train_step(
                sample,
                self.model,
                self.criterion,
                self.optimizer,
                self.get_num_updates(),
                ignore_grad=False,
            )
        raise
    except OverflowError as e:
        logger.info("NOTE: overflow detected, " + str(e))
        self.zero_grad()
        logging_output = None
    except RuntimeError as e:
        if "out of memory" in str(e):
            self._log_oom(e)
            logger.error("OOM during optimization, irrecoverable")
        raise e

    if self.args.fp16:
        metrics.log_scalar(
            "loss_scale", self.optimizer.scaler.loss_scale, priority=700, round=0
        )

    metrics.log_stop_time("train_wall")

    return logging_output
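
# --- Illustrative sketch (not part of the trainer above) ---
# The beta3 block above tracks the ratio of two layers' gradient norms with an
# exponential moving average and flags the update when the bias-corrected
# instantaneous ratio drifts too far from the running estimate. A condensed,
# self-contained restatement of that heuristic (thresholds copied from above;
# the two gradient tensors are placeholders):
def update_layer_ratio_ema(ema_ratio, grad_layer0, grad_layer1, step, beta3=0.99):
    current = grad_layer0.float().norm().item() / grad_layer1.float().norm().item()
    drift = current * (1 - beta3 ** step) / ema_ratio
    trigger_decay = (drift > 1.5 or drift < 0.75) and step < 4000
    new_ema = beta3 * ema_ratio + (1 - beta3) * current
    return new_ema, trigger_decay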
def train_step(self, samples, raise_oom=False):
    """Do forward, backward and parameter update."""
    if self._dummy_batch == "DUMMY":
        self._dummy_batch = samples[0]

    self._set_seed()
    # self.model.train()
    # self.criterion.train()
    # self.zero_grad()

    # print(len(samples))
    # print('In ORT Train Step')

    metrics.log_start_time("train_wall", priority=800, round=0)

    # forward and backward pass
    logging_outputs, sample_size, ooms = [], 0, 0
    for i, sample in enumerate(samples):
        sample = self._prepare_sample(sample)
        if sample is not None:
            # print('Token Ids: ', sample['id'])
            '''
            net_input = sample['net_input']
            src_tokens = net_input['src_tokens']
            print('ORT_TRAIN_STEP: src_tokens size: {}'.format(src_tokens.size()))
            src_lengths = net_input['src_lengths']
            prev_output_tokens = net_input['prev_output_tokens']
            target = sample['target']
            target = target.view(-1)
            print('ORT_TRAIN_STEP: src_lengths size: {}'.format(src_lengths.size()))
            print('ORT_TRAIN_STEP: prev_output_tokens size: {}'.format(prev_output_tokens.size()))
            print('ORT_TRAIN_STEP: target size: {}'.format(target.size()))

            if (src_lengths.size(0) != 3):
                print('src_lengths incorrect size', src_lengths.size(0))
                sample = None
            '''

        if sample is None:
            # when sample is None, run forward/backward on a dummy batch
            # and ignore the resulting gradients
            sample = self._prepare_sample(self._dummy_batch)
            is_dummy_batch = True
        else:
            sample = self._prepare_sample(sample)
            is_dummy_batch = False

        # for key, value in sample.items():
        #     print('Sample key: {}'.format(key))

        '''
        # Visualize model
        model_desc = ort_supplement.bart_model_description(self.args)

        # example: {input0:{0:'batch'}, input1:{0:'batch'}}
        dynamic_axes = {}
        for input in model_desc.inputs_:
            symbolic_axis = {}
            for i, axis in enumerate(input.shape_):
                if isinstance(axis, str):
                    symbolic_axis[i] = axis
            if len(symbolic_axis):
                dynamic_axes[input.name_] = symbolic_axis
        for output in model_desc.outputs_:
            symbolic_axis = {}
            for i, axis in enumerate(output.shape_):
                if isinstance(axis, str):
                    symbolic_axis[i] = axis
            if len(symbolic_axis):
                dynamic_axes[output.name_] = symbolic_axis

        net_input = sample['net_input']
        src_tokens = net_input['src_tokens']
        src_lengths = net_input['src_lengths']
        prev_output_tokens = net_input['prev_output_tokens']
        target = sample['target']
        target = target.view(-1)

        src_tokens.cpu()
        src_lengths.cpu()
        prev_output_tokens.cpu()
        target.cpu()
        # self._model.cuda()

        input_names = [input.name_ for input in model_desc.inputs_]
        output_names = [output.name_ for output in model_desc.outputs_]

        self._model.eval()
        with torch.no_grad():
            sample_outputs = self._model(src_tokens, src_lengths, prev_output_tokens, target)
        if isinstance(sample_outputs, torch.Tensor):
            sample_outputs = [sample_outputs]
        for sample_output, output_desc in zip(sample_outputs, model_desc.outputs_):
            output_desc.dtype_ = sample_output.dtype
        self._model.train()

        import io
        f = io.BytesIO()

        # Other export options to use (this is for backward compatibility).
        other_export_options = {}
        other_export_options['training'] = True

        torch.onnx._export(
            self._model,
            tuple([src_tokens, src_lengths, prev_output_tokens, target]),
            f,
            input_names=input_names,
            output_names=output_names,
            opset_version=12,
            dynamic_axes=dynamic_axes,
            _retain_param_name=True,
            example_outputs=tuple(sample_outputs),
            do_constant_folding=False,
            **other_export_options)
        '''

        def maybe_no_sync():
            """
            Whenever *samples* contains more than one mini-batch, we
            want to accumulate gradients locally and only call
            all-reduce in the last backwards pass.
            """
            if (
                self.data_parallel_world_size > 1
                and hasattr(self.model, "no_sync")
                and i < len(samples) - 1
            ):
                return self.model.no_sync()
            else:
                return contextlib.ExitStack()  # dummy contextmanager

        try:
            with maybe_no_sync():
                # forward and backward
                loss, sample_size_i, logging_output = ort_supplement.ort_train_step(
                    self.args,
                    update_num=self.get_num_updates(),
                    model=self.model,
                    sample=sample,
                )
                del loss

            logging_outputs.append(logging_output)
            sample_size += sample_size_i

            # emptying the CUDA cache after the first step can
            # reduce the chance of OOM
            if self.cuda and self.get_num_updates() == 0:
                torch.cuda.empty_cache()
        except RuntimeError as e:
            if "out of memory" in str(e):
                self._log_oom(e)
                if raise_oom:
                    raise e
                logger.warning(
                    "attempting to recover from OOM in forward/backward pass"
                )
                ooms += 1
                self.zero_grad()
            else:
                raise e

    if is_dummy_batch:
        if torch.is_tensor(sample_size):
            sample_size.zero_()
        else:
            sample_size *= 0.

    if torch.is_tensor(sample_size):
        sample_size = sample_size.float()
    else:
        sample_size = float(sample_size)

    # gather logging outputs from all replicas
    if self._sync_stats():
        logging_outputs, (sample_size, ooms) = self._aggregate_logging_outputs(
            logging_outputs, sample_size, ooms, ignore=is_dummy_batch,
        )

    overflow = False
    '''
    try:
        # multiply gradients by (# GPUs / sample_size) since DDP
        # already normalizes by the number of GPUs. Thus we get
        # (sum_of_gradients / sample_size).
        if not self.args.use_bmuf:
            self.optimizer.multiply_grads(self.data_parallel_world_size / sample_size)
        elif sample_size > 0:  # BMUF needs to check sample size
            num = self.data_parallel_world_size if self._sync_stats() else 1
            self.optimizer.multiply_grads(num / sample_size)

        # clip grads
        grad_norm = self.clip_grad_norm(self.args.clip_norm)

        # check that grad norms are consistent across workers
        if (
            not self.args.use_bmuf
            and self.args.distributed_wrapper != 'SlowMo'
            and not self.tpu
        ):
            self._check_grad_norms(grad_norm)

        # take an optimization step
        self.optimizer.step()
    except FloatingPointError:
        # re-run the forward and backward pass with hooks attached to print
        # out where it fails
        with NanDetector(self.model):
            self.task.train_step(
                sample,
                self.model,
                self.criterion,
                self.optimizer,
                self.get_num_updates(),
                ignore_grad=False
            )
        raise
    except OverflowError as e:
        overflow = True
        logger.info("NOTE: overflow detected, " + str(e))
        grad_norm = torch.tensor(0.).cuda()
        self.zero_grad()
    except RuntimeError as e:
        if "out of memory" in str(e):
            self._log_oom(e)
            logger.error("OOM during optimization, irrecoverable")
        raise e

    # Some distributed wrappers (e.g., SlowMo) need access to the optimizer
    # after the step
    if hasattr(self.model, 'perform_additional_optimizer_actions'):
        if hasattr(self.optimizer, 'fp32_params'):
            self.model.perform_additional_optimizer_actions(
                self.optimizer.optimizer, self.optimizer.fp32_params)
        else:
            self.model.perform_additional_optimizer_actions(
                self.optimizer.optimizer)
    '''

    if not overflow or self.args.distributed_wrapper == 'SlowMo':
        self.set_num_updates(self.get_num_updates() + 1)

        # log stats
        logging_output = self._reduce_and_log_stats(logging_outputs, sample_size)

    '''
    # clear CUDA cache to reduce memory fragmentation
    if (
        self.cuda
        and self.args.empty_cache_freq > 0
        and (
            (self.get_num_updates() + self.args.empty_cache_freq - 1)
            % self.args.empty_cache_freq
        ) == 0
    ):
        torch.cuda.empty_cache()
    '''

    # if self.args.fp16:
    #     metrics.log_scalar(
    #         "loss_scale", self.optimizer.scaler.loss_scale, priority=700, round=0)

    metrics.log_stop_time("train_wall")

    return logging_output