def _finish_update():
    # Closure over mmf_trainer: clip gradients, record a detached copy of each
    # parameter's gradient on the enclosing test object, then run the scaled
    # optimizer step and bump the update counter.
    clip_gradients(
        mmf_trainer.model, mmf_trainer.num_updates, None, mmf_trainer.config
    )
    for param in mmf_trainer.model.parameters():
        mmf_grad = torch.clone(param.grad).detach().item()
        self.mmf_grads.append(mmf_grad)

    mmf_trainer.scaler.step(mmf_trainer.optimizer)
    mmf_trainer.scaler.update()
    mmf_trainer.num_updates += 1
def _backward(self, loss):
    self.optimizer.zero_grad()
    loss.backward()

    if self.should_clip_gradients:
        clip_gradients(self.model, self.num_updates, self.tb_writer, self.config)

    self.optimizer.step()
    self._run_scheduler()

    self.num_updates += 1
    self.profile("Backward time")
def _finish_update(self):
    if self.training_config.clip_gradients:
        clip_gradients(
            self.model,
            self.num_updates,
            self.logistics_callback.tb_writer,
            self.config,
        )
    self.optimizer.step()
    self.num_updates += 1
    self.profile("Finished update")
def _backward(self, loss: Tensor) -> None:
    self.optimizer.zero_grad()
    loss.backward()

    if self.training_config.clip_gradients:
        clip_gradients(
            self.model,
            self.num_updates,
            self.logistics_callback.tb_writer,
            self.config,
        )

    self.optimizer.step()
    self.num_updates += 1
    self.profile("Backward time")
def _finish_update(self):
    if self.training_config.clip_gradients:
        clip_gradients(
            self.model,
            self.num_updates,
            self.logistics_callback.tb_writer,
            self.config,
            scale=self.scaler.get_scale(),
        )
    if is_xla():
        import torch_xla.core.xla_model as xm

        # Assumes no model parallel
        xm.reduce_gradients(self.optimizer)

    self.scaler.step(self.optimizer)
    self.scaler.update()
    self.num_updates += 1
    self.profile("Finished update")
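For context, a minimal sketch of the backward / finish-update split that the last snippet implements, assuming only a toy nn.Linear model and PyTorch's standard torch.cuda.amp.GradScaler; here gradients are explicitly unscaled before clipping with torch.nn.utils.clip_grad_norm_, whereas the snippet above instead passes the current loss scale into clip_gradients. The training_step helper and its names are illustrative, not part of the code above.

import torch
from torch import nn

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"

model = nn.Linear(8, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

def training_step(batch, target, max_grad_norm=1.0):
    # Forward pass under autocast (mixed precision on GPU, no-op on CPU).
    with torch.autocast(device_type=device, enabled=use_amp):
        loss = nn.functional.mse_loss(model(batch), target)

    # "_backward" part: scale the loss so small fp16 gradients do not underflow.
    optimizer.zero_grad()
    scaler.scale(loss).backward()

    # "_finish_update" part: unscale before clipping so the norm threshold
    # applies to true gradient magnitudes, then do the scaled optimizer step.
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    scaler.step(optimizer)
    scaler.update()
    return loss.detach()

loss = training_step(torch.randn(4, 8, device=device), torch.randn(4, 1, device=device))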