def all_reduce(self, overflow_buf, accum=1):
    """Flatten, average, and redistribute master gradients across the local group.

    The master grads are packed into one fp32 buffer, pre-divided by
    (team_size * accum) so the subsequent sum-allreduce yields an average,
    then unpacked back into the individual grad tensors.

    Args:
        overflow_buf: int tensor used by ``multi_tensor_scale`` to flag Inf/NaN.
        accum: number of accumulation steps folded into the averaging divisor.
    """
    # Identity scaler (scale == 1.0): loss_scale() contributes nothing to the
    # math below but keeps the structure parallel to the fp16 code paths.
    unit_scaler = amp.scaler.LossScaler(1.0)

    # Gather every master gradient that actually exists.
    grads = [
        p.grad
        for p in amp.master_params(self.optimizer)
        if p.grad is not None
    ]
    total_elems = sum(g.numel() for g in grads)

    # One flat fp32 staging buffer for the whole gradient set.
    flat_buffer = torch.empty(total_elems, device='cuda', dtype=torch.float32)
    grad_views = apex_C.unflatten(flat_buffer, grads)

    # Pack grads into the flat buffer, pre-dividing so the allreduce averages.
    divisor = self.team_size * accum
    overflow_buf.zero_()
    amp_C.multi_tensor_scale(
        65536,
        overflow_buf,
        [grads, grad_views],
        unit_scaler.loss_scale() / divisor)

    # Sum across ranks of the local group; predivision makes this a mean.
    torch.distributed.all_reduce(flat_buffer, group=self.local_group)

    # Unpack the averaged buffer back into the individual grad tensors.
    overflow_buf.zero_()
    amp_C.multi_tensor_scale(
        65536,
        overflow_buf,
        [grad_views, grads],
        1. / unit_scaler.loss_scale())
def take_optimizer_step(args, optimizer, model, overflow_buf, global_step):
    """Apply one optimizer step, optionally with a manual post-accumulation allreduce.

    When ``args.allreduce_post_accumulation`` is set, gradients are flattened,
    pre-divided, summed across ranks (averaging them), unscaled, and checked
    for Inf/NaN before stepping; on overflow the step is skipped and the loss
    scale is reduced. Otherwise the optimizer steps directly (DDP has already
    averaged the gradients).

    Args:
        args: run configuration (fp16, allreduce, accumulation flags).
        optimizer: apex-amp wrapped optimizer.
        model: the model whose ``.grad`` fields are cleared after the step.
        overflow_buf: int CUDA tensor used by ``multi_tensor_scale`` to flag Inf/NaN.
        global_step: step counter, incremented only on successful steps.

    Returns:
        The updated ``global_step``.
    """
    global skipped_steps
    if args.allreduce_post_accumulation:
        # Manually allreduce gradients after all accumulation steps,
        # checking for Inf/NaN along the way.
        # 1. allocate an uninitialized buffer for the flattened gradient
        loss_scale = _amp_state.loss_scalers[0].loss_scale() if args.fp16 else 1
        master_grads = [p.grad for p in amp.master_params(optimizer) if p.grad is not None]
        flat_grad_size = sum(p.numel() for p in master_grads)
        allreduce_dtype = torch.float16 if args.allreduce_post_accumulation_fp16 else torch.float32
        flat_raw = torch.empty(flat_grad_size, device='cuda', dtype=allreduce_dtype)
        # 2. combine unflattening and predivision of unscaled 'raw' gradient
        allreduced_views = apex_C.unflatten(flat_raw, master_grads)
        overflow_buf.zero_()
        amp_C.multi_tensor_scale(
            65536,
            overflow_buf,
            [master_grads, allreduced_views],
            loss_scale / (get_world_size() * args.gradient_accumulation_steps))
        # 3. sum gradient across ranks. Because of the predivision, this
        # averages the gradient.
        torch.distributed.all_reduce(flat_raw)
        # 4. combine unscaling and unflattening of allreduced gradient
        overflow_buf.zero_()
        amp_C.multi_tensor_scale(
            65536,
            overflow_buf,
            [allreduced_views, master_grads],
            1. / loss_scale)
        # 5. update loss scale using our overflow buffer, then restore the
        # scaler's own buffer.
        if args.fp16:
            scaler = _amp_state.loss_scalers[0]
            old_overflow_buf = scaler._overflow_buf
            scaler._overflow_buf = overflow_buf
            had_overflow = scaler.update_scale()
            # BUGFIX: was `scaler._overfloat_buf = old_overflow_buf` — the typo
            # meant the scaler's original buffer was never restored and the
            # scaler kept aliasing the caller's overflow_buf.
            scaler._overflow_buf = old_overflow_buf
        else:
            had_overflow = 0
        # 6. call optimizer step function
        if had_overflow == 0:
            optimizer.step()
            global_step += 1
        else:
            # Overflow detected: skip the step, log the reduced loss scale,
            # and drop the stale fp32 master gradients.
            skipped_steps += 1
            if is_main_process():
                scaler = _amp_state.loss_scalers[0]
                dllogger.log(step="PARAMETER", data={"loss_scale": scaler.loss_scale()})
            if _amp_state.opt_properties.master_weights:
                for param in optimizer._amp_stash.all_fp32_from_fp16_params:
                    param.grad = None
        # Clear model grads whether or not the step was taken.
        for param in model.parameters():
            param.grad = None
    else:
        # DDP already averaged the gradients; just step (when enabled) and clear.
        if args.apply_optimizer > 0:
            optimizer.step()
        for param in model.parameters():
            param.grad = None
        global_step += 1
    return global_step
def _step_distributed_fp16(self) -> None:
    """Manually allreduce fp16 gradients after accumulation, then step.

    Flattens the master grads into an fp16 buffer pre-divided by
    (world_size * gradient_accumulation_steps), sum-allreduces it (yielding an
    average), unscales back into the master grads, updates the amp loss scale,
    and either steps the optimizer or — on Inf/NaN overflow — skips the step
    and clears all gradients.
    """
    # check for Inf/NaN while combining flatten + predivide + unscale
    # 1. allocate an uninitialized buffer for the flattened gradient
    scaler = _amp_state.loss_scalers[0]
    master_grads = [
        p.grad for p in amp.master_params(self.optimizer) if p.grad is not None
    ]
    flat_grad_size = sum(p.numel() for p in master_grads)
    # fp16 allreduce halves communication volume vs fp32.
    allreduce_dtype = torch.float16
    flat_raw = torch.empty(flat_grad_size, device='cuda', dtype=allreduce_dtype)
    # 2. combine unflattening and predivision of unscaled 'raw' gradient
    allreduced_views = apex_C.unflatten(flat_raw, master_grads)
    self._overflow_buf.zero_()
    amp_C.multi_tensor_scale(
        65536, self._overflow_buf,
        [master_grads, allreduced_views],
        scaler.loss_scale() /
        (torch.distributed.get_world_size() * self.gradient_accumulation_steps))
    # 3. sum gradient across ranks. Because of the predivision, this averages
    # the gradient.
    torch.distributed.all_reduce(flat_raw)
    # 4. combine unscaling and unflattening of allreduced gradient
    self._overflow_buf.zero_()
    amp_C.multi_tensor_scale(
        65536, self._overflow_buf,
        [allreduced_views, master_grads],
        1. / scaler.loss_scale())
    # 5. update loss scale using our overflow buffer, then restore the
    # scaler's own buffer.
    old_overflow_buf = scaler._overflow_buf
    scaler._overflow_buf = self._overflow_buf
    had_overflow = scaler.update_scale()
    # BUGFIX: was `scaler._overfloat_buf = old_overflow_buf` — the typo meant
    # the scaler's original buffer was never restored and the scaler kept
    # aliasing self._overflow_buf.
    scaler._overflow_buf = old_overflow_buf
    # 6. call optimizer step function
    if had_overflow == 0:
        self._step()
    else:
        # Overflow detected: skip the step and drop the now-garbage gradients.
        logger.info(
            f"Gradient overflow. Skipping step, reducing loss scale to "
            f"{scaler.loss_scale()}")
        if _amp_state.opt_properties.master_weights:
            for param in self.optimizer._amp_stash.all_fp32_from_fp16_params:
                param.grad = None
        for param in self.model.parameters():
            param.grad = None
def take_optimizer_step(args, optimizer, grad_scaler, model, overflow_buf, global_step):
    """Step the optimizer via ``grad_scaler``, optionally allreducing manually first.

    With ``args.allreduce_post_accumulation`` the still-scaled gradients are
    flattened, pre-divided by (world_size * accumulation steps), sum-allreduced
    (i.e. averaged), copied back, and checked for Inf/NaN; an overflow skips
    the counted step. In every path ``grad_scaler.step`` performs the actual
    (unscaled) parameter update and ``grad_scaler.update`` adjusts the scale.

    Returns:
        The updated ``global_step``.
    """
    global skipped_steps
    if args.allreduce_post_accumulation:
        # Manual allreduce of all accumulated gradients, with Inf/NaN check.
        # -- 1. collect grads and size one flat staging buffer -------------
        scale = grad_scaler._get_scale_async() if args.fp16 else 1.
        grads = [p.grad for p in model.parameters() if p.grad is not None]
        total_elems = sum(g.numel() for g in grads)
        buffer_dtype = (torch.float16 if args.allreduce_post_accumulation_fp16
                        else torch.float32)
        flat_buffer = torch.empty(total_elems, device='cuda', dtype=buffer_dtype)
        # -- 2. pack grads into the buffer, pre-dividing for the average ---
        grad_views = apex_C.unflatten(flat_buffer, grads)
        overflow_buf.zero_()
        divisor = get_world_size() * args.gradient_accumulation_steps
        amp_C.multi_tensor_scale(
            65536,
            overflow_buf,
            [grads, grad_views],
            scale / divisor)
        # -- 3. sum across ranks; predivision turns this into a mean -------
        torch.distributed.all_reduce(flat_buffer)
        # -- 4. copy the averaged (still scaled) grads back -----------------
        overflow_buf.zero_()
        amp_C.multi_tensor_scale(
            65536,
            overflow_buf,
            [grad_views, grads],
            1.)
        # -- 5. read the overflow flag --------------------------------------
        had_overflow = overflow_buf.item() if args.fp16 else 0
        # -- 6. count the step only if the gradients were finite ------------
        if had_overflow == 0:
            global_step += 1
        else:
            # Overflow detected: skip this step.
            skipped_steps += 1
    else:
        global_step += 1
    # grad_scaler unscales, steps (skipping on its own inf check), and
    # updates the loss scale.
    grad_scaler.step(optimizer)
    grad_scaler.update()
    optimizer.zero_grad(set_to_none=True)
    return global_step