def _fast_stat_sync_sum(
    self,
    logging_outputs: List[Dict[str, Any]],
    *extra_stats_to_sum,
    ignore=False,
):
    """
    Sync logging outputs across workers. fast_stat_sync_sum is
    faster than all_gather_list_sync, but is only suitable when
    logging outputs are scalars and can be summed. Note that
    *logging_outputs* cannot contain any nested dicts/lists.
    """
    data = {}
    for i, stat in enumerate(extra_stats_to_sum):
        data['extra_stats_' + str(i)] = stat
    if len(logging_outputs) > 0:
        log_keys = list(logging_outputs[0].keys())
        for k in log_keys:
            if not ignore:
                v = sum(log[k] for log in logging_outputs if k in log)
            else:
                v = logging_outputs[0][k]
                v = torch.zeros_like(v) if torch.is_tensor(v) else 0
            data['logging_outputs_' + k] = v
    else:
        log_keys = None

    data = distributed_utils.all_reduce_dict(
        data, device=self.device, group=self.data_parallel_process_group
    )

    extra_stats_to_sum = [
        data['extra_stats_' + str(i)] for i in range(len(extra_stats_to_sum))
    ]
    if log_keys is not None:
        logging_outputs = [{
            k: data['logging_outputs_' + k] for k in log_keys
        }]
    else:
        logging_outputs = []
    return logging_outputs, extra_stats_to_sum
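
# Illustrative sketch (not part of the original code), assuming two data-parallel
# workers that each produced one logging output for the same step:
#
#   worker 0: logging_outputs = [{'loss': 2.0, 'ntokens': 10}]
#   worker 1: logging_outputs = [{'loss': 4.0, 'ntokens': 30}]
#
# After _fast_stat_sync_sum, every worker holds the same summed dict
# (values may come back as 0-dim tensors from all_reduce_dict):
#
#   logging_outputs == [{'loss': 6.0, 'ntokens': 40}]
#
# Each key is reduced independently with a plain sum across the group, which is
# why nested dicts/lists are not supported here.
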
def train_step(self, samples, raise_oom=False):
    """Do forward, backward and parameter update."""
    if self._dummy_batch == "DUMMY":
        self._dummy_batch = samples[0]

    self._set_seed()
    self.model.train()
    self.criterion.train()
    self.zero_grad()

    metrics.log_start_time("train_wall", priority=800, round=0)

    # forward and backward pass
    logging_outputs, sample_size, ooms = [], 0, 0
    accumulate_step = len(samples)
    sample_size_mask = 0
    sample_size_decode = 0
    if hasattr(self.task, 'count_sample_size'):
        # count masked-LM and decoder target tokens across all local
        # mini-batches, then all-reduce so every worker normalizes by the
        # same global token counts
        for i, sample in enumerate(samples):
            if sample is None or len(sample) == 0:
                sample = self._dummy_batch
            masked_tokens = sample['target'].ne(self.criterion.padding_idx)
            sample_size_mask_t = masked_tokens.int().sum()
            decode_tokens = sample['decode_target'].ne(self.criterion.padding_idx)
            sample_size_decode_t = decode_tokens.int().sum()
            sample_size_mask += sample_size_mask_t
            sample_size_decode += sample_size_decode_t
        data = {}
        data['sample_size_mask'] = sample_size_mask
        data['sample_size_decode'] = sample_size_decode
        data = distributed_utils.all_reduce_dict(
            data, device=self.device, group=self.data_parallel_process_group
        )
        sample_size_mask = data['sample_size_mask']
        sample_size_decode = data['sample_size_decode']

    for i, sample in enumerate(samples):
        sample = self._prepare_sample(sample)
        if sample is None:
            # when sample is None, run forward/backward on a dummy batch
            # and ignore the resulting gradients
            sample = self._prepare_sample(self._dummy_batch)
            is_dummy_batch = True
        else:
            is_dummy_batch = False
        sample['sample_size_mask'] = sample_size_mask
        sample['sample_size_decode'] = sample_size_decode
        sample['accumulate_step'] = accumulate_step * self.data_parallel_world_size

        def maybe_no_sync():
            """
            Whenever *samples* contains more than one mini-batch, we
            want to accumulate gradients locally and only call
            all-reduce in the last backwards pass.
            """
            if (
                self.data_parallel_world_size > 1
                and hasattr(self.model, "no_sync")
                and i < len(samples) - 1
            ):
                return self.model.no_sync()
            else:
                return contextlib.ExitStack()  # dummy contextmanager

        try:
            with maybe_no_sync():
                # forward and backward
                loss, sample_size_i, logging_output = self.task.train_step(
                    sample=sample,
                    model=self.model,
                    criterion=self.criterion,
                    optimizer=self.optimizer,
                    update_num=self.get_num_updates(),
                    ignore_grad=is_dummy_batch,
                )
                del loss

            logging_outputs.append(logging_output)
            sample_size += sample_size_i

            # emptying the CUDA cache after the first step can
            # reduce the chance of OOM
            if self.cuda and self.get_num_updates() == 0:
                torch.cuda.empty_cache()
        except RuntimeError as e:
            if "out of memory" in str(e):
                self._log_oom(e)
                if raise_oom:
                    raise e
                logger.warning(
                    "attempting to recover from OOM in forward/backward pass"
                )
                ooms += 1
                self.zero_grad()
                if self.cuda:
                    torch.cuda.empty_cache()
                if self.args.distributed_world_size == 1:
                    return None
            else:
                raise e

        if self.tpu and i < len(samples) - 1:
            # tpu-comment: every XLA operation before marking step is
            # appended to the IR graph, and processing too many batches
            # before marking step can lead to OOM errors.
            # To handle gradient accumulation use case, we explicitly
            # mark step here for every forward pass without a backward pass
            import torch_xla.core.xla_model as xm
            xm.mark_step()

    if is_dummy_batch:
        if torch.is_tensor(sample_size):
            sample_size.zero_()
        else:
            sample_size *= 0.
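    # Note (added comment, not in the original code): zeroing sample_size after a
    # dummy batch keeps the dummy forward/backward from contributing to the
    # gradient normalization below, while still letting this worker participate
    # in the collective all-reduce calls alongside its peers.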
    if torch.is_tensor(sample_size):
        sample_size = sample_size.float()
    else:
        sample_size = float(sample_size)

    # gather logging outputs from all replicas
    if self._sync_stats():
        train_time = self._local_cumulative_training_time()
        logging_outputs, (
            sample_size, ooms, total_train_time
        ) = self._aggregate_logging_outputs(
            logging_outputs, sample_size, ooms, train_time, ignore=is_dummy_batch,
        )
        self._cumulative_training_time = (
            total_train_time / self.data_parallel_world_size
        )

    overflow = False
    try:
        if self.tpu and self.data_parallel_world_size > 1:
            import torch_xla.core.xla_model as xm
            gradients = xm._fetch_gradients(self.optimizer.optimizer)
            xm.all_reduce(
                'sum', gradients, scale=1.0 / self.data_parallel_world_size
            )

        with torch.autograd.profiler.record_function("multiply-grads"):
            # multiply gradients by (# GPUs / sample_size) since DDP
            # already normalizes by the number of GPUs. Thus we get
            # (sum_of_gradients / sample_size).
            if not self.args.use_bmuf:
                self.optimizer.multiply_grads(
                    self.data_parallel_world_size / sample_size
                )
            elif sample_size > 0:  # BMUF needs to check sample size
                num = self.data_parallel_world_size if self._sync_stats() else 1
                self.optimizer.multiply_grads(num / sample_size)

        with torch.autograd.profiler.record_function("clip-grads"):
            # clip grads
            grad_norm = self.clip_grad_norm(self.args.clip_norm)

        # check that grad norms are consistent across workers
        if (
            not self.args.use_bmuf
            and self.args.distributed_wrapper != 'SlowMo'
            and not self.tpu
        ):
            self._check_grad_norms(grad_norm)

        with torch.autograd.profiler.record_function("optimizer"):
            # take an optimization step
            self.optimizer.step()
    except FloatingPointError:
        # re-run the forward and backward pass with hooks attached to print
        # out where it fails
        with NanDetector(self.model):
            self.task.train_step(
                sample,
                self.model,
                self.criterion,
                self.optimizer,
                self.get_num_updates(),
                ignore_grad=False,
            )
        raise
    except OverflowError as e:
        overflow = True
        logger.info("NOTE: overflow detected, " + str(e))
        grad_norm = torch.tensor(0.).cuda()
        self.zero_grad()
    except RuntimeError as e:
        if "out of memory" in str(e):
            self._log_oom(e)
            logger.error("OOM during optimization, irrecoverable")
        raise e

    # Some distributed wrappers (e.g., SlowMo) need access to the optimizer
    # after the step
    if hasattr(self.model, 'perform_additional_optimizer_actions'):
        if hasattr(self.optimizer, 'fp32_params'):
            self.model.perform_additional_optimizer_actions(
                self.optimizer.optimizer, self.optimizer.fp32_params
            )
        else:
            self.model.perform_additional_optimizer_actions(
                self.optimizer.optimizer
            )

    if not overflow or self.args.distributed_wrapper == 'SlowMo':
        self.set_num_updates(self.get_num_updates() + 1)

        if self.tpu:
            # mark step on TPUs
            import torch_xla.core.xla_model as xm
            xm.mark_step()

            # only log stats every log_interval steps
            # this causes wps to be misreported when log_interval > 1
            logging_output = {}
            if self.get_num_updates() % self.args.log_interval == 0:
                logging_output = self._reduce_and_log_stats(
                    logging_outputs, sample_size, grad_norm,
                )

            # log whenever there's an XLA compilation, since these
            # slow down training and may indicate opportunities for
            # optimization
            self._check_xla_compilation()
        else:
            # log stats
            logging_output = self._reduce_and_log_stats(
                logging_outputs, sample_size, grad_norm,
            )

            # clear CUDA cache to reduce memory fragmentation
            if (
                self.cuda
                and self.args.empty_cache_freq > 0
                and (
                    (self.get_num_updates() + self.args.empty_cache_freq - 1)
                    % self.args.empty_cache_freq
                ) == 0
            ):
                torch.cuda.empty_cache()

    if self.args.fp16:
        metrics.log_scalar(
            "loss_scale",
            self.optimizer.scaler.loss_scale,
            priority=700,
            round=0,
        )

    metrics.log_stop_time("train_wall")

    return logging_output
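
# Worked example (an illustration, not part of the original file) of the
# "multiply-grads" scaling inside train_step: with data_parallel_world_size = 8
# and an aggregated sample_size of 4000 target tokens, DDP has already averaged
# gradients across the 8 workers, so multiply_grads(8 / 4000) first undoes that
# averaging (restoring the summed gradient) and then divides by the total token
# count, leaving the gradient of the average per-token loss.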