def _finish_update(self):
    if self.training_config.clip_gradients:
        clip_gradients(
            self.model,
            self.num_updates,
            self.logistics_callback.tb_writer,
            self.config,
            scale=self.scaler.get_scale(),
        )
    if is_xla():
        import torch_xla.core.xla_model as xm

        # Assumes no model parallel
        xm.reduce_gradients(self.optimizer)

    self.scaler.step(self.optimizer)
    self.scaler.update()
    self.num_updates += 1
    self.profile("Finished update")
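

# Illustrative sketch, not part of the trainer code above: the standard
# torch.cuda.amp pattern that _finish_update's scaler.step()/scaler.update()
# calls belong to. `model`, `optimizer`, `data_loader`, and `max_grad_norm`
# are hypothetical names used only in this example.
def _amp_update_sketch(model, optimizer, data_loader, max_grad_norm=1.0):
    import torch

    scaler = torch.cuda.amp.GradScaler()
    for inputs, targets in data_loader:
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            loss = torch.nn.functional.cross_entropy(model(inputs), targets)
        # Backward on the scaled loss, then unscale before clipping so the
        # norm is measured on the true (unscaled) gradients.
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        scaler.step(optimizer)  # internally skipped if inf/NaN grads were found
        scaler.update()         # adjust the loss scale for the next iteration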
def train_step(self, samples, dummy_batch=False, raise_oom=False):
    """Do forward, backward and parameter update."""
    if self._dummy_batch is None:
        self._dummy_batch = samples[0]

    self._set_seed()
    self.model.train()
    self.criterion.train()
    self.zero_grad()

    if not dummy_batch:
        self.meters['train_wall'].start()

    # forward and backward pass
    logging_outputs, sample_sizes, ooms = [], [], 0
    for i, sample in enumerate(samples):
        sample = self._prepare_sample(sample)
        if sample is None:
            # when sample is None, run forward/backward on a dummy batch
            # and ignore the resulting gradients
            sample = self._prepare_sample(self._dummy_batch)
            ignore_grad = True
        else:
            ignore_grad = False

        def maybe_no_sync():
            """
            Whenever *samples* contains more than one mini-batch, we
            want to accumulate gradients locally and only call
            all-reduce in the last backwards pass.
            """
            if (
                self.args.distributed_world_size > 1
                and hasattr(self.model, 'no_sync')
                and i < len(samples) - 1
            ):
                return self.model.no_sync()
            else:
                return contextlib.ExitStack()  # dummy contextmanager

        try:
            with maybe_no_sync():
                # forward and backward
                loss, sample_size, logging_output = self.task.train_step(
                    sample, self.model, self.criterion, self.optimizer,
                    ignore_grad
                )

            if not ignore_grad:
                logging_outputs.append(logging_output)
                sample_sizes.append(sample_size)

                if self.fast_stat_sync:
                    self._all_reduce_list[0] += sample_size
                    self._all_reduce_list[1] += logging_output.get(
                        'nsentences', 0.0)
                    self._all_reduce_list[2] += logging_output.get(
                        'loss', 0.0)
                    self._all_reduce_list[3] += logging_output.get(
                        'nll_loss', 0.0)
                    self._all_reduce_list[4] += logging_output.get(
                        'ntokens', 0.0)
        except RuntimeError as e:
            if 'out of memory' in str(e):
                msg = (
                    '| WARNING: ran out of memory with exception: '
                    + '{};'.format(e)
                    + '\n Skipping batch'
                )
                # TODO: print should really go to logger, this print goes
                # to stderr, which is buffered, which in many cases is not
                # printed out if another exception happens.
                # NB(jerry): added a flush to mitigate this
                print(msg, file=sys.stderr)
                if torch.cuda.is_available() and hasattr(
                        torch.cuda, "memory_summary"):
                    for device_idx in range(torch.cuda.device_count()):
                        print(torch.cuda.memory_summary(device=device_idx),
                              file=sys.stderr)
                sys.stderr.flush()

                if raise_oom:
                    raise ValueError(msg)
                ooms += 1
                self.zero_grad()
            else:
                raise e

        if self.xla and len(samples) > 1:
            # tpu-comment: every xla operation before marking step is
            # appended to the IR graph, and processing too many batches
            # before marking step can lead to OOM errors.
            # To handle gradient accumulation use case, we explicitly
            # mark step here for every forward pass if we accumulate grads
            xm.mark_step()

    if self.fast_stat_sync:
        self._all_reduce_list[5] += ooms

    if ooms > 0 and self._oom_batch is not None:
        self.handle_ooms(ooms)

    if dummy_batch:
        return None

    # gather logging outputs from all replicas
    if self.fast_stat_sync:
        # rework all_gather_list
        all_reduce_list_tensor = torch.cuda.DoubleTensor(self._all_reduce_list)
        if self._sync_stats():
            torch.distributed.all_reduce(all_reduce_list_tensor)
        # Normalize loss and nll_loss by "sample_size"
        # and convert to log base 2
        all_reduce_list_tensor[2:4].div_(
            all_reduce_list_tensor[0:1]
            * torch.log(torch.cuda.DoubleTensor([2]))
        )
        self._all_reduce_list = all_reduce_list_tensor.tolist()
        logging_output = {}
        [
            sample_size,
            logging_output['nsentences'],
            logging_output['loss'],
            logging_output['nll_loss'],
            logging_output['ntokens'],
            ooms,
        ] = self._all_reduce_list
    elif self._sync_stats():
        logging_outputs, sample_sizes, ooms, prev_norms = \
            zip(*distributed_utils.all_gather_list(
                [logging_outputs, sample_sizes, ooms, self._prev_grad_norm],
            ))
        logging_outputs = list(chain.from_iterable(logging_outputs))
        sample_sizes = list(chain.from_iterable(sample_sizes))
        ooms = sum(ooms)

        if not self.args.use_bmuf:
            assert (
                all(norm == prev_norms[0] for norm in prev_norms)
                or all(math.isnan(norm) or math.isinf(norm)
                       for norm in prev_norms)
            ), 'Fatal error: gradients are inconsistent between workers'

    self.meters['oom'].update(ooms, len(samples))
    if ooms == self.args.distributed_world_size * len(samples):
        print('| WARNING: OOM in all workers, skipping update')
        self.zero_grad()
        return None

    if not self.fast_stat_sync:
        # aggregate logging outputs and sample sizes
        logging_output = self.task.aggregate_logging_outputs(
            logging_outputs, self.get_criterion()
        )
        sample_size = self.task.grad_denom(sample_sizes, self.get_criterion())

    if not all(k in logging_output for k in ['ntokens', 'nsentences']):
        raise Exception(
            ('Please update the {}.aggregate_logging_outputs() method to '
             'return ntokens and nsentences').format(
                self.task.__class__.__name__)
        )

    try:
        # normalize grads by sample size
        if sample_size > 0:
            self.optimizer.multiply_grads(
                self.args.distributed_world_size / float(sample_size)
            )

        if self.xla:
            # tpu-comment: for tpu training, we need to explicitly reduce
            # gradients here
            xm.reduce_gradients(self.optimizer)

        # clip grads
        grad_norm = self.optimizer.clip_grad_norm(self.args.clip_norm)
        self._prev_grad_norm = grad_norm

        # take an optimization step
        self.optimizer.step()
        self.set_num_updates(self.get_num_updates() + 1)

        # task specific update per step
        self.task.update_step(self._num_updates)

        # update meters
        ntokens = logging_output.get('ntokens', 0)
        nsentences = logging_output.get('nsentences', 0)
        self.meters['wps'].update(ntokens)
        self.meters['ups'].update(1.)
        self.meters['wpb'].update(ntokens)
        self.meters['bsz'].update(nsentences)
        self.meters['gnorm'].update(grad_norm)
        # tpu-comment: the comparison below introduces too many .item()
        # calls and slows down tpu
        self.meters['clip'].update(
            0.
            # 1. if grad_norm > self.args.clip_norm and self.args.clip_norm > 0 else 0.
        )
        self.meters['train_loss'].update(
            logging_output.get('loss', 0), sample_size)
        if 'train_acc' in self.meters:
            self.meters['train_acc'].update(
                logging_output.get('acc', 0), sample_size)
        if 'nll_loss' in logging_output:
            self.meters['train_nll_loss'].update(
                logging_output.get('nll_loss', 0), ntokens)

        # clear CUDA cache to reduce memory fragmentation
        if (
            self.args.empty_cache_freq > 0
            and (
                (self.get_num_updates() + self.args.empty_cache_freq - 1)
                % self.args.empty_cache_freq
            ) == 0
            and torch.cuda.is_available()
            and not self.args.cpu
        ):
            torch.cuda.empty_cache()
    except OverflowError as e:
        print('| WARNING: overflow detected, ' + str(e))
        self.zero_grad()
        logging_output = None

    if self.args.fp16:
        self.meters['loss_scale'].reset()
        self.meters['loss_scale'].update(self.optimizer.scaler.loss_scale)

    self.clear_buffered_stats()
    self.meters['train_wall'].stop()

    return logging_output
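

# Illustrative sketch, not part of the trainer code above: train_step's
# maybe_no_sync() helper relies on DistributedDataParallel's no_sync()
# context manager, so gradients from all but the last micro-batch are
# accumulated locally and the all-reduce runs only on the final backward.
# `ddp_model`, `optimizer`, `micro_batches`, and `loss_fn` are hypothetical
# names used only in this example.
def _grad_accumulation_sketch(ddp_model, optimizer, micro_batches, loss_fn):
    import contextlib

    optimizer.zero_grad()
    for i, (inputs, targets) in enumerate(micro_batches):
        last = (i == len(micro_batches) - 1)
        # Skip gradient synchronization on all but the last micro-batch.
        ctx = contextlib.nullcontext() if last else ddp_model.no_sync()
        with ctx:
            loss = loss_fn(ddp_model(inputs), targets)
            loss.backward()
    optimizer.step()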