def _merge_grad_update(self,
                       gradients: Union[Tensor, List[Tensor]],
                       deferred: Optional[Dict[str, List[Callable[[], None]]]] = None) -> None:
    """Accumulate gradients and update the model once every `merge_grad` invocations.

    Args:
        gradients: Input gradients.
        deferred: A dictionary in which deferred model update functions are stored.
    """
    # Add the current gradients to the cumulative gradients
    for gs, g in zip(self.grad_sum, gradients):
        self._assign_add(gs, g)
    self._assign_add(self.step, 1)
    if self.step % self.merge_grad == 0:
        average_grad = [gs / self.merge_grad for gs in self.grad_sum]
        update_model(model=self.model, gradients=average_grad, defer=self.defer, deferred=deferred)
        # Zero the accumulated gradients in place
        for gs in self.grad_sum:
            self._assign_add(gs, -gs)
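# For context, a minimal standalone sketch of the same accumulate-then-average
# pattern in eager TensorFlow. `MERGE_GRAD`, `model`, and `optimizer` are
# illustrative assumptions, not part of the op above.
import tensorflow as tf

MERGE_GRAD = 4
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(8,))])
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
grad_sum = [tf.Variable(tf.zeros_like(w), trainable=False) for w in model.trainable_variables]
step = tf.Variable(0, trainable=False)

def accumulate_and_maybe_apply(gradients):
    # Accumulate this step's gradients into the running sums
    for gs, g in zip(grad_sum, gradients):
        gs.assign_add(g)
    step.assign_add(1)
    if step % MERGE_GRAD == 0:
        # Apply the averaged gradients, then zero the accumulators in place
        average_grad = [gs / MERGE_GRAD for gs in grad_sum]
        optimizer.apply_gradients(zip(average_grad, model.trainable_variables))
        for gs in grad_sum:
            gs.assign(tf.zeros_like(gs))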
def forward(self, data: Union[Tensor, List[Tensor]], state: Dict[str, Any]) -> None:
    if not state["warmup"]:
        if self.weight_decay:
            # Fold the model's regularization losses (e.g. l2 weight decay) into the loss
            data = data + tf.reduce_sum(self.model.losses)
        update_model(self.model,
                     data,
                     tape=state["tape"],
                     retain_graph=self.retain_graph,
                     scaler=state["scaler"],
                     defer=self.defer,
                     deferred=state["deferred"])
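# For reference, a simplified sketch of what an `update_model` call like the
# one above could do with a GradientTape. This is an assumption-laden stand-in
# (the real helper also handles `retain_graph`, mixed-precision `scaler`
# unscaling, and framework dispatch); the `optimizer` argument is hypothetical.
from typing import Callable, Dict, List, Optional

import tensorflow as tf

def update_model_sketch(model: tf.keras.Model,
                        loss: tf.Tensor,
                        tape: tf.GradientTape,
                        optimizer: tf.keras.optimizers.Optimizer,
                        defer: bool = False,
                        deferred: Optional[Dict[str, List[Callable[[], None]]]] = None) -> None:
    # Differentiate the loss w.r.t. the trainable variables
    gradients = tape.gradient(loss, model.trainable_variables)

    def apply_update() -> None:
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    if defer and deferred is not None:
        # Store the update so the caller can execute it later
        deferred.setdefault(model.name, []).append(apply_update)
    else:
        apply_update()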
def forward(self, data: Union[Tensor, List[Tensor]], state: Dict[str, Any]) -> None:
    if state["warmup"]:
        return
    if self.gradients is None:
        # data is loss
        loss = self._loss_preprocess(data)
        gradients = self._get_gradient(loss, state["tape"])
    else:
        # data is gradients
        gradients = data
    gradients = self._gradient_postprocess(gradients)
    if self.merge_grad > 1:
        self._merge_grad_update(gradients, deferred=state["deferred"])
    else:
        update_model(model=self.model,
                     gradients=gradients,
                     defer=self.defer,
                     deferred=state["deferred"])
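# Sketch of the loss-to-gradient path taken when `data` is a loss, assuming
# mixed-precision loss scaling is in play. `get_scaled_gradients` is a
# hypothetical stand-in for the private `_loss_preprocess` / `_get_gradient`
# pair, not the actual implementation.
from typing import List, Optional

import tensorflow as tf

def get_scaled_gradients(loss: tf.Tensor,
                         tape: tf.GradientTape,
                         model: tf.keras.Model,
                         scaler: Optional[tf.keras.mixed_precision.LossScaleOptimizer] = None
                         ) -> List[tf.Tensor]:
    if scaler is not None:
        loss = scaler.get_scaled_loss(loss)  # scale up to avoid fp16 underflow
    gradients = tape.gradient(loss, model.trainable_variables)
    if scaler is not None:
        gradients = scaler.get_unscaled_gradients(gradients)  # undo the scaling
    return gradients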