def update(self, new_scale=None): """ Updates the scale factor. If any optimizer steps were skipped the scale is multiplied by ``backoff_factor`` to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively, the scale is multiplied by ``growth_factor`` to increase it. Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not used directly, it's used to fill GradScaler's internal scale tensor. So if ``new_scale`` was a tensor, later in-place changes to that tensor will not further affect the scale GradScaler uses internally.) Args: new_scale (float or :class:`torch.cuda.FloatTensor`, optional, default=None): New scale factor. .. warning:: :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has been invoked for all optimizers used this iteration. """ if not self._enabled: return _scale, _growth_tracker = self._check_scale_growth_tracker("update") if new_scale is not None: # Accept a new user-defined scale. if isinstance(new_scale, float): self._scale.fill_(new_scale) # type: ignore[union-attr] else: reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False." assert isinstance(new_scale, torch.cuda.FloatTensor ), reason # type: ignore[attr-defined] assert new_scale.numel() == 1, reason assert new_scale.requires_grad is False, reason self._scale.copy_(new_scale) # type: ignore[union-attr] else: # Consume shared inf/nan data collected from optimizers to update the scale. # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous. found_infs = [ found_inf.to(device=_scale.device, non_blocking=True) for state in self._per_optimizer_states.values() for found_inf in state["found_inf_per_device"].values() ] assert len( found_infs) > 0, "No inf checks were recorded prior to update." found_inf_combined = found_infs[0] if len(found_infs) > 1: for i in range(1, len(found_infs)): found_inf_combined += found_infs[i] torch._amp_update_scale_(_scale, _growth_tracker, found_inf_combined, self._growth_factor, self._backoff_factor, self._growth_interval) # To prepare for next iteration, clear the data collected from optimizers this iteration. self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
def update(self): """ Update the scale factor and clear information about infinities. This method should be called after each optimization step. """ if not self._enabled: return torch._amp_update_scale_( self._scale, self._growth_tracker, self._found_inf, self._growth_factor, self._backoff_factor, self._growth_interval, ) # Clear infinity found status self._found_inf = torch.zeros_like(self._found_inf)
def update(self, new_scale=None): """ Updates to native grad scaler update function. 1. Check inf across model-parallel ranks. 2. Update hysteresis tracker. 3. Apply hysteresis to grad scale update. """ if not self._enabled: return _scale, _growth_tracker = self._check_scale_growth_tracker("update") if new_scale is not None: # Accept a new user-defined scale. if isinstance(new_scale, float): self._scale.fill_(new_scale) # type: ignore[union-attr] else: reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False." assert isinstance(new_scale, torch.cuda.FloatTensor ), reason # type: ignore[attr-defined] assert new_scale.numel() == 1, reason assert new_scale.requires_grad is False, reason self._scale.copy_(new_scale) # type: ignore[union-attr] else: # Consume shared inf/nan data collected from optimizers to update the scale. # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous. found_infs = [ found_inf.to(device=_scale.device, non_blocking=True) for state in self._per_optimizer_states.values() for found_inf in state["found_inf_per_device"].values() ] assert len( found_infs) > 0, "No inf checks were recorded prior to update." found_inf_combined = found_infs[0] # Update across all model parallel instances. torch.distributed.all_reduce( found_inf_combined, op=torch.distributed.ReduceOp.MAX, group=parallel_state.get_model_parallel_group()) if len(found_infs) > 1: for i in range(1, len(found_infs)): found_inf = found_infs[i] # Update across all model parallel instances. torch.distributed.all_reduce( found_inf, op=torch.distributed.ReduceOp.MAX, group=parallel_state.get_model_parallel_group()) found_inf_combined += found_inf if found_inf_combined > 0: self._hysteresis_tracker -= 1 if self._hysteresis_tracker <= 0: # When hysteresis becomes zero, follow the native grad scale update rule. # Increase scale and reset growth tracker torch._amp_update_scale_( _scale, _growth_tracker, found_inf_combined, self._growth_factor, self._backoff_factor, self._growth_interval, ) else: # Only reset the growth tracker when hysteresis is larger than zero _growth_tracker.fill_(0.0) else: # When no inf found, follow the native grad scale update rule. # Increment growth_tracker, update scale when growth tracker reaches the interval, and # reset the hysteresis tracker. torch._amp_update_scale_( _scale, _growth_tracker, found_inf_combined, self._growth_factor, self._backoff_factor, self._growth_interval, ) self._hysteresis_tracker = self.hysteresis # To prepare for next iteration, clear the data collected from optimizers this iteration. self._per_optimizer_states = defaultdict( torch.cuda.amp.grad_scaler._refresh_per_optimizer_state)