def update(self, new_scale=None):
    """
    Updates the scale factor.

    If any optimizer steps were skipped the scale is multiplied by ``backoff_factor``
    to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
    the scale is multiplied by ``growth_factor`` to increase it.

    Passing ``new_scale`` sets the scale directly.

    Arguments:
        new_scale (float or :class:`torch.cuda.FloatTensor`, optional, default=None):  New scale factor.

    .. warning::
        :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
        been invoked for all optimizers used this iteration.
    """
    if not self._enabled:
        return

    self._check_scale_growth_tracker("update")

    if new_scale is not None:
        # Accept a new user-defined scale.
        if isinstance(new_scale, float):
            self._scale = torch.full((1,), new_scale, dtype=torch.float32, device=self._scale.device)
        else:
            reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False."
            assert isinstance(new_scale, torch.cuda.FloatTensor), reason
            assert new_scale.numel() == 1, reason
            assert new_scale.requires_grad is False, reason
            self._scale = new_scale
    else:
        # Consume shared inf/nan data collected from optimizers to update the scale.
        # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
        found_infs = [found_inf.to(device=self._scale.device, non_blocking=True)
                      for state in self._per_optimizer_states.values()
                      for found_inf in state["found_inf_per_device"].values()]

        assert len(found_infs) > 0, "No inf checks were recorded prior to update."

        found_inf_combined = found_infs[0]
        if len(found_infs) > 1:
            for i in range(1, len(found_infs)):
                found_inf_combined += found_infs[i]

        self._scale = torch._amp_update_scale(self._growth_tracker,
                                              self._scale,
                                              found_inf_combined,
                                              self._growth_factor,
                                              self._backoff_factor,
                                              self._growth_interval)

    # To prepare for next iteration, clear the data collected from optimizers this iteration.
    self._per_optimizer_states = defaultdict(lambda: {"stage": self.READY, "found_inf_per_device": {}})
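For context, the warning in the docstring assumes the standard ``torch.cuda.amp`` training loop, in which ``update()`` is the last scaler call of each iteration, after ``step()`` has run for every optimizer. The following is only a minimal, self-contained sketch of that loop; the tiny model, optimizer, and synthetic data are illustrative placeholders, not part of the scaler source above.

import torch
from torch.cuda.amp import GradScaler, autocast

def _demo_amp_loop(steps: int = 3) -> None:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    use_amp = device == "cuda"
    model = torch.nn.Linear(8, 2).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = GradScaler(enabled=use_amp)
    for _ in range(steps):
        inputs = torch.randn(4, 8, device=device)
        targets = torch.randint(0, 2, (4,), device=device)
        optimizer.zero_grad()
        with autocast(enabled=use_amp):
            loss = torch.nn.functional.cross_entropy(model(inputs), targets)
        scaler.scale(loss).backward()  # backward runs on the scaled loss
        scaler.step(optimizer)         # unscales grads; skips the step if inf/nan were found
        scaler.update()                # last call of the iteration, per the warning above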
def update(self, new_scale: Optional[Union[float, FloatTensor]] = None) -> None:
    """
    Updates the scale factor.

    If any optimizer steps were skipped the scale is multiplied by ``backoff_factor``
    to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
    the scale is multiplied by ``growth_factor`` to increase it.

    Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
    used directly, it's used to fill GradScaler's internal scale tensor. So if
    ``new_scale`` was a tensor, later in-place changes to that tensor will not further
    affect the scale GradScaler uses internally.)

    Args:
        new_scale (float or :class:`torch.cuda.FloatTensor`, optional, default=None):  New scale factor.

    .. warning::
        :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
        been invoked for all optimizers used this iteration.
    """
    if not self._enabled:
        return

    _scale, _growth_tracker = self._check_scale_growth_tracker("update")  # type: ignore

    if new_scale is not None:
        # Accept a new user-defined scale.
        if isinstance(new_scale, float):
            self._scale.fill_(new_scale)
        else:
            reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False."
            assert isinstance(new_scale, torch.cuda.FloatTensor), reason  # type: ignore[attr-defined]
            assert new_scale.numel() == 1, reason
            assert new_scale.requires_grad is False, reason
            self._scale.copy_(new_scale)
    else:
        # Consume shared inf/nan data collected from optimizers to update the scale.
        # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
        found_infs = [
            found_inf.to(device=_scale.device, non_blocking=True)
            for state in self._per_optimizer_states.values()
            for found_inf in state["found_inf_per_device"].values()
        ]

        assert len(found_infs) > 0, "No inf checks were recorded prior to update."

        found_inf_combined = found_infs[0]
        if len(found_infs) > 1:
            for i in range(1, len(found_infs)):
                found_inf_combined += found_infs[i]

        if _scale.device.type == "cpu":
            self._amp_update_scale_cpu_(found_inf_combined)  # type: ignore
        else:
            if torch_version() >= (1, 9, 0):
                torch._amp_update_scale_(  # type: ignore
                    self._scale,
                    self._growth_tracker,
                    found_inf_combined,
                    self._growth_factor,  # type: ignore
                    self._backoff_factor,  # type: ignore
                    self._growth_interval,  # type: ignore
                )
            elif torch_version() >= (1, 8, 0) and torch_version() < (1, 9, 0):
                self._scale = torch._amp_update_scale(  # type: ignore
                    self._growth_tracker,
                    _scale,
                    found_inf_combined,
                    self._growth_factor,  # type: ignore
                    self._backoff_factor,  # type: ignore
                    self._growth_interval,  # type: ignore
                )

    # To prepare for next iteration, clear the data collected from optimizers this iteration.
    self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
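The fused ``torch._amp_update_scale_`` call (and the CPU fallback referenced above) applies the backoff/growth rule described in the docstring. The stand-alone function below is only an illustrative re-statement of that rule under those documented semantics; it is not the actual fused kernel and not the ``_amp_update_scale_cpu_`` helper.

def _reference_update_scale(scale, growth_tracker, found_inf,
                            growth_factor, backoff_factor, growth_interval):
    # found_inf truthy: this iteration's optimizer step was skipped,
    # so back the scale off and restart the streak of clean iterations.
    if found_inf:
        return scale * backoff_factor, 0
    # Otherwise count another clean iteration; after growth_interval of
    # them in a row, grow the scale and reset the counter.
    growth_tracker += 1
    if growth_tracker == growth_interval:
        return scale * growth_factor, 0
    return scale, growth_tracker

# Example: a skipped step halves the scale with the default backoff_factor of 0.5,
# e.g. _reference_update_scale(65536.0, 0, True, 2.0, 0.5, 2000) returns (32768.0, 0).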