Code Example #1
    def test_clip_by_norms(self):
        # torch.ones([5]) has squared norm 5.0; after clipping with
        # clip_norm=1.0 the squared norm becomes 1.0, and in_place=True
        # means the original ``tensor`` is modified as well.
        tensor = torch.ones([5])
        clipped_tensor = tensor_utils.clip_by_norms(tensor,
                                                    clip_norm=1.0,
                                                    in_place=True)
        self.assertTensorClose(self._sqr_norm(clipped_tensor),
                               torch.as_tensor(1.0))
        self.assertTensorClose(self._sqr_norm(tensor), torch.as_tensor(1.0))

        # The same call accepts a list of tensors; each tensor is clipped
        # independently of the others.
        tensors = [torch.randn([3, 4, 5]) for _ in range(10)]
        tensor_utils.clip_by_norms(tensors, clip_norm=1.0, in_place=True)
        for t in tensors:
            self.assertTensorClose(self._sqr_norm(t), torch.as_tensor(1.0))
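For reference, _sqr_norm, assertTensorClose, and assertTensorNotClose are helpers of the surrounding test class, which is not shown here. Below is a minimal sketch of such a harness, assuming _sqr_norm returns the sum of squared elements and the assertions wrap torch.allclose; the class name and import path are illustrative, not taken from the original file.

import unittest

import torch

from alf.utils import tensor_utils  # assumed import path


class TensorUtilsTest(unittest.TestCase):
    """Hypothetical harness for the test methods shown in these examples."""

    def _sqr_norm(self, x):
        # Sum of squared elements, i.e. the squared L2 norm.
        return torch.sum(x ** 2)

    def assertTensorClose(self, t1, t2):
        self.assertTrue(torch.allclose(t1, t2, atol=1e-6))

    def assertTensorNotClose(self, t1, t2):
        self.assertFalse(torch.allclose(t1, t2, atol=1e-6))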
Code Example #2
    def step(self, closure=None):
        """This function first clips the gradients if needed, then calls the
        parent's ``step()`` function.
        """
        # Optionally update the learning rate from an external scheduler.
        if self._lr_scheduler is not None:
            lr = float(self._lr_scheduler())
            for param_group in self.param_groups:
                param_group['lr'] = lr
        if self._gradient_clipping is not None:
            # Gather the gradients of all parameters across all param groups.
            params = []
            for param_group in self.param_groups:
                params.extend(param_group["params"])
            grads = alf.nest.map_structure(lambda p: p.grad, params)
            if self._clip_by_global_norm:
                # Rescale all gradients together so that their combined
                # (global) norm does not exceed the clipping threshold.
                _, global_norm = tensor_utils.clip_by_global_norm(
                    grads, self._gradient_clipping, in_place=True)
                if alf.summary.should_record_summaries():
                    alf.summary.scalar("global_grad_norm/%s" % self.name,
                                       global_norm)
            else:
                # Clip the norm of each gradient tensor independently.
                tensor_utils.clip_by_norms(grads,
                                           self._gradient_clipping,
                                           in_place=True)
        super(NewCls, self).step(closure=closure)
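The two branches implement different clipping semantics: global-norm clipping rescales every gradient by a single factor based on the norm of the concatenated gradient vector, while per-tensor clipping rescales each gradient on its own. The sketch below illustrates both behaviors in plain PyTorch; the helper names are illustrative and not part of the ALF API.

import torch


def clip_each(grads, clip_norm):
    # Per-tensor clipping: each gradient is rescaled so that its own
    # L2 norm is at most clip_norm.
    for g in grads:
        norm = g.norm()
        if norm > clip_norm:
            g.mul_(clip_norm / norm)


def clip_global(grads, clip_norm):
    # Global-norm clipping: all gradients are scaled by the same factor
    # when the norm of the concatenated gradient vector exceeds clip_norm.
    global_norm = torch.sqrt(sum((g ** 2).sum() for g in grads))
    if global_norm > clip_norm:
        scale = clip_norm / global_norm
        for g in grads:
            g.mul_(scale)
    return global_norm

Global-norm clipping preserves the relative scale of the gradients across parameters; PyTorch's torch.nn.utils.clip_grad_norm_ offers the same global behavior for parameter lists.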
Code Example #3
    def test_no_clip_by_norms(self):
        # The norm of torch.ones([5]) is sqrt(5), well below clip_norm=100.0,
        # so clip_by_norms leaves the tensor unchanged.
        tensor = torch.ones([5])
        tensor_utils.clip_by_norms(tensor, clip_norm=100.0, in_place=True)
        self.assertTensorNotClose(self._sqr_norm(tensor),
                                  torch.as_tensor(100.0))