def test_clip_by_global_norm(self):
    """Out-of-place clipping returns rescaled copies and leaves inputs intact."""
    originals = [torch.randn([3, 4, 5]) for _ in range(10)]
    norm_before = self._sqr_norm(originals)
    clipped, _ = tensor_utils.clip_by_global_norm(
        originals, clip_norm=1.0, in_place=False)
    # With in_place=False the input tensors must be left unmodified ...
    self.assertTensorNotClose(self._sqr_norm(originals), torch.as_tensor(1.0))
    self.assertTensorClose(self._sqr_norm(originals), norm_before)
    # ... while the returned copies are scaled to the target norm.
    self.assertTensorClose(self._sqr_norm(clipped), torch.as_tensor(1.0))
def step(self, closure=None):
    """Clip gradients (if configured) before delegating to the parent's
    ``step()``.

    If an LR scheduler is set, every param group's learning rate is first
    refreshed from it. If gradient clipping is set, gradients are clipped
    in place — either jointly by global norm (with an optional summary of
    that norm) or per-tensor by norm — before the parent update runs.
    """
    if self._lr_scheduler is not None:
        lr = float(self._lr_scheduler())
        for param_group in self.param_groups:
            param_group['lr'] = lr
    if self._gradient_clipping is not None:
        # Flatten all parameters across groups so clipping sees every grad.
        all_params = [
            p for group in self.param_groups for p in group["params"]
        ]
        grads = alf.nest.map_structure(lambda p: p.grad, all_params)
        if self._clip_by_global_norm:
            _, global_norm = tensor_utils.clip_by_global_norm(
                grads, self._gradient_clipping, in_place=True)
            if alf.summary.should_record_summaries():
                alf.summary.scalar("global_grad_norm/%s" % self.name,
                                   global_norm)
        else:
            tensor_utils.clip_by_norms(
                grads, self._gradient_clipping, in_place=True)
    super(NewCls, self).step(closure=closure)
def test_clip_by_global_norm_in_place(self):
    """In-place clipping rescales the input tensors themselves."""
    tensors = [torch.randn([3, 4, 5]) for _ in range(10)]
    tensor_utils.clip_by_global_norm(tensors, clip_norm=1.0, in_place=True)
    # Use the shared ``_sqr_norm`` helper instead of re-deriving the squared
    # global norm inline, keeping this test consistent with its sibling
    # ``test_clip_by_global_norm``, which checks the same quantity.
    self.assertTensorClose(self._sqr_norm(tensors), torch.as_tensor(1.0))