def test_many_tensors(self):
    """Check ``l2_regularization`` when several tensors are passed as var-args.

    The k-th tensor (counting from 1) is all ones scaled by ``k``, so it
    contributes ``numel * k ** 2`` to the expected value. That is the sum of
    its squared entries. The per-tensor contributions are added together.
    """
    shapes = [
        (1, 2, 3),
        (2, 3, 4),
        (3, 4, 5),
    ]
    tensors = []
    expected = 0.
    for scale, shape in enumerate(shapes, start=1):
        tensor = torch.ones(*shape) * scale
        tensors.append(tensor)
        # Every entry equals `scale`, so the squared entries sum to numel * scale**2.
        expected += numpy.prod(tensor.shape) * scale**2
    result = l2_regularization(*tensors)
    self.assertAlmostEqual(float(result), expected)
def test_one_tensor(self):
    """Check ``l2_regularization`` on a single all-ones tensor.

    Every entry is 1, so the expected value equals the tensor's element count.
    """
    tensor = torch.ones(1, 2, 3, 4)
    result = l2_regularization(tensor)
    element_count = float(numpy.prod(tensor.shape))
    self.assertAlmostEqual(float(result), element_count)