Ejemplo n.º 1
0
 def test_many_tensors(self):
     """Check ``l2_regularization`` when several tensors are passed as var-args.

     Three constant tensors of different shapes are built. The tensor at
     1-based position ``scale`` is filled with the value ``scale``. Each
     tensor contributes ``num_elements * scale**2`` to the expected result,
     which is the sum of the squares of all its entries.
     """
     shapes = [(1, 2, 3), (2, 3, 4), (3, 4, 5)]
     tensors = [
         torch.ones(*shape) * scale
         for scale, shape in enumerate(shapes, start=1)
     ]
     # Start from a float so the accumulated expectation stays floating point.
     expected = sum(
         (numpy.prod(tensor.shape) * scale**2
          for scale, tensor in enumerate(tensors, start=1)),
         0.,
     )
     result = l2_regularization(*tensors)
     self.assertAlmostEqual(float(result), expected)
Ejemplo n.º 2
0
 def test_one_tensor(self):
     """Check ``l2_regularization`` on a single all-ones tensor.

     Every entry equals 1, so the expected value is the element count.
     """
     ones = torch.ones(1, 2, 3, 4)
     result = l2_regularization(ones)
     expected = float(numpy.prod(ones.shape))
     self.assertAlmostEqual(float(result), expected)