Example #1
    def _merge_grad_update(
        self,
        gradients: Union[Tensor, List[Tensor]],
        deferred: Optional[Dict[str, List[Callable[[],
                                                   None]]]] = None) -> None:
        """Accumulate gradients and update the model at certain frequency of invocation.

        Args:
            gradients: Input gradients.
            deferred: A dictionary in which model update functions are stored.
        """

        # add current gradient to the cumulative gradient
        for gs, g in zip(self.grad_sum, gradients):
            self._assign_add(gs, g)

        self._assign_add(self.step, 1)

        if self.step % self.merge_grad == 0:
            average_grad = [gs / self.merge_grad for gs in self.grad_sum]
            update_model(model=self.model,
                         gradients=average_grad,
                         defer=self.defer,
                         deferred=deferred)
            for gs in self.grad_sum:
                self._assign_add(gs, -gs)  # zero the gradient in place
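
For reference, here is a minimal standalone sketch of the same accumulate-then-average pattern using plain tf.keras objects instead of the update_model helper above; the model, optimizer, and merge_grad value are illustrative assumptions, not part of the original code.

import tensorflow as tf

merge_grad = 4
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(8,))])
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
grad_sum = [tf.Variable(tf.zeros_like(v), trainable=False)
            for v in model.trainable_variables]
step = tf.Variable(0, trainable=False)

def accumulate_and_maybe_update(gradients):
    # Add the current gradients to the running sums.
    for gs, g in zip(grad_sum, gradients):
        gs.assign_add(g)
    step.assign_add(1)
    # Every merge_grad calls, apply the averaged gradients and reset the sums.
    if step % merge_grad == 0:
        average_grad = [gs / merge_grad for gs in grad_sum]
        optimizer.apply_gradients(zip(average_grad, model.trainable_variables))
        for gs in grad_sum:
            gs.assign(tf.zeros_like(gs))

Calling such a helper once per micro-batch applies a single update every merge_grad calls using the averaged gradients, which approximates training with an effective batch size of merge_grad times the per-call batch size.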
Example #2
    def forward(self, data: Union[Tensor, List[Tensor]], state: Dict[str, Any]) -> None:
        if not state["warmup"]:
            if self.weight_decay:
                data = data + tf.reduce_sum(self.model.losses)
            update_model(self.model,
                         data,
                         tape=state["tape"],
                         retain_graph=self.retain_graph,
                         scaler=state["scaler"],
                         defer=self.defer,
                         deferred=state["deferred"])
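
The weight_decay branch above relies on Keras collecting regularization penalties in model.losses. A minimal sketch of where those terms come from, assuming a layer built with a kernel_regularizer (the model, inputs, and L2 factor below are illustrative only):

import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(4, input_shape=(8,),
                          kernel_regularizer=tf.keras.regularizers.L2(1e-4))
])
x = tf.random.normal((2, 8))
y = tf.zeros((2, 4))
with tf.GradientTape() as tape:
    data_loss = tf.reduce_mean(tf.square(model(x) - y))
    # Adding the collected L2 penalties to the data loss applies weight decay.
    total_loss = data_loss + tf.reduce_sum(model.losses)
gradients = tape.gradient(total_loss, model.trainable_variables)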
Example #3
    def forward(self, data: Union[Tensor, List[Tensor]],
                state: Dict[str, Any]) -> None:
        if state["warmup"]:
            return

        if self.gradients is None:  # data is loss
            loss = self._loss_preprocess(data)
            gradients = self._get_gradient(loss, state["tape"])
        else:  # data is gradients
            gradients = data
        gradients = self._gradient_postprocess(gradients)

        if self.merge_grad > 1:
            self._merge_grad_update(gradients, deferred=state["deferred"])
        else:
            update_model(model=self.model,
                         gradients=gradients,
                         defer=self.defer,
                         deferred=state["deferred"])
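
As a rough illustration of the two input modes dispatched above (a loss when self.gradients is None, otherwise precomputed gradients), here is a standalone sketch in plain TensorFlow; the model, optimizer, and data are assumptions made only for the example:

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
x, y = tf.random.normal((8, 4)), tf.zeros((8, 1))

# Mode 1: the op receives a loss and differentiates it against the tape itself.
with tf.GradientTape() as tape:
    loss = tf.reduce_mean(tf.square(model(x) - y))
gradients = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))

# Mode 2: the op receives gradients computed upstream and only applies them,
# optionally accumulating first when merge_grad > 1, as in Example #1.
precomputed = [tf.zeros_like(v) for v in model.trainable_variables]
optimizer.apply_gradients(zip(precomputed, model.trainable_variables))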