def on_batch_end(self, state):
    # @TODO: check this one
    if not state.is_train:
        return

    self.accumulation_counter += 1
    if not self.fp16:
        # for _, value in state._optimizer.items():
        #     value.zero_grad()
        # @TODO: check this one
        if len(state._optimizer) > 0:
            for key, value in state.loss.items():
                value.backward()

            # gradient accumulation: step only every `accumulation_steps` batches
            if (self.accumulation_counter + 1) \
                    % self.accumulation_steps == 0:
                self.grad_step(state._optimizer)
                state.model.zero_grad()
                self.accumulation_counter = 0
    else:
        state.model.zero_grad()
        if len(state._optimizer) > 0:
            assert len(state._optimizer) == 1, \
                "fp16 mode works only with one optimizer for now"
            for key, value in state.loss.items():
                # scale the loss to keep fp16 gradients from underflowing
                scaled_loss = self.fp16_grad_scale * value.float()
                scaled_loss.backward()

            # fp32 master copies live in the optimizer; the model holds fp16 weights
            master_params = list(
                state._optimizer["main"].param_groups[0]["params"])
            model_params = list(
                filter(lambda p: p.requires_grad,
                       state.model.parameters()))

            copy_grads(source=model_params, target=master_params)

            # undo the loss scaling before the optimizer step
            for param in master_params:
                param.grad.data.mul_(1. / self.fp16_grad_scale)

            self.grad_step(state._optimizer)

            # push the updated fp32 master weights back into the fp16 model
            copy_params(source=master_params, target=model_params)
            torch.cuda.synchronize()
def on_batch_end(self, state):
    if not state.is_train:
        return

    self.accumulation_counter += 1
    if not self.fp16:
        model = state.model
        optimizer = state.get_key(
            key="optimizer", inner_key=self.optimizer_key)
        loss = state.get_key(key="loss", inner_key=self.loss_key)

        loss.backward()

        # gradient accumulation: step only every `accumulation_steps` batches
        if (self.accumulation_counter + 1) % self.accumulation_steps == 0:
            self.grad_step(
                optimizer=optimizer,
                optimizer_wd=self.optimizer_wd,
                grad_clip_fn=self.grad_clip_fn)
            model.zero_grad()
            self.accumulation_counter = 0
    else:
        model = state.model
        model.zero_grad()
        optimizer = state.get_key(
            key="optimizer", inner_key=self.optimizer_key)
        loss = state.get_key(key="loss", inner_key=self.loss_key)

        # scale the loss to keep fp16 gradients from underflowing
        scaled_loss = self.fp16_grad_scale * loss.float()
        scaled_loss.backward()

        # fp32 master copies live in the optimizer; the model holds fp16 weights
        master_params = list(optimizer.param_groups[0]["params"])
        model_params = list(
            filter(lambda p: p.requires_grad, model.parameters()))
        copy_grads(source=model_params, target=master_params)

        # undo the loss scaling before the optimizer step
        for param in master_params:
            param.grad.data.mul_(1. / self.fp16_grad_scale)

        self.grad_step(
            optimizer=optimizer,
            optimizer_wd=self.optimizer_wd,
            grad_clip_fn=self.grad_clip_fn)
        # push the updated fp32 master weights back into the fp16 model
        copy_params(source=master_params, target=model_params)
        torch.cuda.synchronize()
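# Both versions above rely on `copy_grads` and `copy_params` to shuttle
# gradients and weights between the fp16 model parameters and the fp32 master
# parameters held by the optimizer. A minimal sketch of how such helpers could
# look is given below, assuming unflattened master params that mirror the model
# params one-to-one; this is an illustration of the pattern, not the library's
# actual implementation.
def copy_grads(source, target):
    # copy (and upcast) gradients from the fp16 model params into the fp32 master params
    for source_param, target_param in zip(source, target):
        if source_param.grad is None:
            continue
        if target_param.grad is None:
            target_param.grad = source_param.grad.detach().clone().float()
        else:
            target_param.grad.data.copy_(source_param.grad.data)


def copy_params(source, target):
    # copy (and downcast) the updated fp32 master params back into the fp16 model params
    for source_param, target_param in zip(source, target):
        target_param.data.copy_(source_param.data)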