def testBackwardGrayscaleAnalysisMode(self, datatype, nchs, nrows, ncols, mus):
    """Check the gradients of the intermediate rotation layer in 'Analysis' mode.

    The reference gradients below treat the layer as passing the first ``ps``
    channels through unchanged and multiplying the last ``pa`` channels by
    ``Un``:

    * dL/dX: copy of dL/dZ with the last ``pa`` channels replaced by
      ``Un^T @ dL/dZ_a``;
    * dL/dangle_i: ``sum(dL/dZ_a * (dUn/dangle_i @ X_a))``.
    """
    rtol, atol = 1e-3, 1e-6
    omgs = OrthonormalMatrixGenerationSystem(dtype=datatype, partial_difference=False)

    # Parameters
    nSamples = 8
    nChsTotal = sum(nchs)
    nAngles = int((nChsTotal - 2) * nChsTotal / 8)
    angles = torch.randn(nAngles, dtype=datatype)
    # nSamples x nRows x nCols x nChsTotal
    X = torch.randn(nSamples, nrows, ncols, nChsTotal, dtype=datatype, requires_grad=True)
    dLdZ = torch.randn(nSamples, nrows, ncols, nChsTotal, dtype=datatype)

    # Expected values
    ps, pa = nchs
    UnT = omgs(angles, mus).T
    # dLdX = dZdX x dLdZ
    expctddLdX = dLdZ.clone()
    Ya = dLdZ[:, :, :, ps:].view(nSamples * nrows * ncols, pa).T  # pa x n
    Za = UnT @ Ya
    expctddLdX[:, :, :, ps:] = Za.T.view(nSamples, nrows, ncols, pa)
    # dLdWi = <dLdZ,(dVdWi)X>
    expctddLdW_U = torch.zeros(nAngles, dtype=datatype)
    omgs.partial_difference = True
    # The antisymmetric input half does not depend on the angle index.
    Xa = X[:, :, :, ps:].view(-1, pa).T  # pa x n
    for iAngle in range(nAngles):
        dUn = omgs(angles, mus, index_pd_angle=iAngle)
        Za = dUn @ Xa  # pa x n
        expctddLdW_U[iAngle] = torch.sum(Ya * Za)

    # Instantiation of target class
    layer = NsoltIntermediateRotation2dLayer(number_of_channels=nchs, mode='Analysis', name='Vn')
    layer.orthTransUn.angles.data = angles
    layer.orthTransUn.mus = mus

    # Actual values. Anomaly detection is scoped to this forward/backward pass:
    # a bare set_detect_anomaly(True) call would leave it on for all later tests.
    with torch.autograd.set_detect_anomaly(True):
        Z = layer.forward(X)
        layer.zero_grad()
        Z.backward(dLdZ)
    actualdLdX = X.grad
    actualdLdW_U = layer.orthTransUn.angles.grad

    # Evaluation
    self.assertEqual(actualdLdX.dtype, datatype)
    self.assertEqual(actualdLdW_U.dtype, datatype)
    self.assertTrue(torch.allclose(actualdLdX, expctddLdX, rtol=rtol, atol=atol))
    self.assertTrue(torch.allclose(actualdLdW_U, expctddLdW_U, rtol=rtol, atol=atol))
    self.assertTrue(Z.requires_grad)
def testSetAngles(self, datatype):
    """A 2x2 matrix for angle 0 is the identity; for pi/4 it is a rotation by pi/4."""
    rtol, atol = 1e-5, 1e-8
    identity = torch.eye(2, dtype=datatype)
    omgs = OrthonormalMatrixGenerationSystem(dtype=datatype)

    # Zero angle, positive sign -> identity matrix.
    self.assertTrue(torch.allclose(omgs(angles=0, mus=1), identity, rtol=rtol, atol=atol))

    # Quarter-pi angle -> [[cos, -sin], [sin, cos]].
    theta = math.pi / 4
    cosT, sinT = math.cos(theta), math.sin(theta)
    rotation = torch.tensor([[cosT, -sinT], [sinT, cosT]], dtype=datatype)
    self.assertTrue(torch.allclose(omgs(angles=theta, mus=1), rotation, rtol=rtol, atol=atol))
def testPartialDifferenceSetAngles(self, datatype):
    """The angle derivative of a 2x2 rotation R(theta) matches R(theta + pi/2).

    For theta = 0 the expected matrix is written out exactly; for theta = pi/4
    it is built from cos/sin of pi/4 + pi/2.
    """
    rtol, atol = 1e-4, 1e-7
    omgs = OrthonormalMatrixGenerationSystem(dtype=datatype, partial_difference=True)

    shifted = math.pi / 4 + math.pi / 2
    cases = (
        (0, torch.tensor([[0., -1.], [1., 0.]], dtype=datatype)),
        (math.pi / 4,
         torch.tensor([[math.cos(shifted), -math.sin(shifted)],
                       [math.sin(shifted), math.cos(shifted)]], dtype=datatype)),
    )
    for theta, expctdM in cases:
        actualM = omgs(angles=theta, mus=1, index_pd_angle=0)
        self.assertTrue(torch.allclose(actualM, expctdM, rtol=rtol, atol=atol))
def testPartialDifference4x4RandAngPdAng1(self, datatype):
    """Compare dM/d(angle 1) for a 4x4 matrix with a hand-built central difference.

    The reference is (1/delta) * diag(mus) R5 R4 R3 R2 D1 R0. Here Rk is the
    Givens rotation for angle k in the plane ``planes[k]``, and D1 is
    R1(angle + delta/2) - R1(angle - delta/2).
    """
    rtol, atol = 1e-1, 1e-3
    mus = [-1, -1, -1, -1]
    angs = 2 * math.pi * torch.rand(6)
    pdAng = 1
    delta = 1e-3
    # (row, column) pair rotated by each angle, in angle-index order.
    planes = ((0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3))

    def givens(k, theta):
        # 4x4 rotation by theta in plane planes[k], in the default float dtype.
        i, j = planes[k]
        g = torch.eye(4)
        c, s = math.cos(theta), math.sin(theta)
        g[i, i], g[i, j] = c, -s
        g[j, i], g[j, j] = s, c
        return g

    def factor(k):
        # The factor for the differentiated angle becomes a central difference.
        if k != pdAng:
            return givens(k, angs[k])
        return givens(k, angs[k] + delta / 2) - givens(k, angs[k] - delta / 2)

    prod = (1 / delta) * torch.tensor(mus).view(-1, 1) * factor(5)
    for k in range(4, -1, -1):
        prod = prod @ factor(k)
    expctdM = torch.as_tensor(prod, dtype=datatype)

    omgs = OrthonormalMatrixGenerationSystem(dtype=datatype, partial_difference=True)
    actualM = omgs(angles=angs, mus=mus, index_pd_angle=pdAng)

    self.assertTrue(torch.allclose(actualM, expctdM, rtol=rtol, atol=atol))
def testPredictGrayscaleWithRandomAnglesNoDcLeackage(self, nchs, stride, nrows, ncols, datatype, mus):
    """Forward pass of the initial rotation layer with ``no_dc_leakage=True``.

    In the reference, the first ``ps - 1`` angles of W0 are zeroed and its
    first sign is forced to +1. The first ``ceil(nDecs/2)`` input channels then
    go through W0 and the remaining ``floor(nDecs/2)`` through U0. The layer
    itself receives the raw angles; it presumably applies the same zeroing
    internally (TODO confirm against the layer implementation).
    """
    rtol, atol = 1e-5, 1e-8
    gen = OrthonormalMatrixGenerationSystem(dtype=datatype)

    nSamples = 8
    nDecs = stride[0] * stride[1]  # math.prod(stride)
    nChsTotal = sum(nchs)
    # Input: nSamples x nRows x nCols x nDecs
    X = torch.randn(nSamples, nrows, ncols, nDecs, dtype=datatype)
    angles = torch.randn(int((nChsTotal - 2) * nChsTotal / 4), dtype=datatype)

    # Reference output: nSamples x nRows x nCols x nChs
    ps, pa = nchs
    half = int(len(angles) / 2)
    angsW = angles[:half]
    angsU = angles[half:]
    angsWNoDcLeak = angsW.clone()
    angsWNoDcLeak[:ps - 1] = torch.zeros(ps - 1, dtype=angles.dtype)
    musW = mus * torch.ones(ps, dtype=datatype)
    musU = mus * torch.ones(pa, dtype=datatype)
    musW[0] = 1
    W0 = gen(angsWNoDcLeak, musW)
    U0 = gen(angsU, musU)
    ms = int(math.ceil(nDecs / 2.))
    ma = int(math.floor(nDecs / 2.))
    Zsa = torch.zeros(nChsTotal, nrows * ncols * nSamples, dtype=datatype)
    Zsa[:ps, :] = W0[:, :ms] @ X[:, :, :, :ms].view(-1, ms).T
    if ma > 0:
        Zsa[ps:, :] = U0[:, :ma] @ X[:, :, :, ms:].view(-1, ma).T
    expctdZ = Zsa.T.view(nSamples, nrows, ncols, nChsTotal)

    layer = NsoltInitialRotation2dLayer(
        number_of_channels=nchs,
        decimation_factor=stride,
        no_dc_leakage=True,
        name='V0')
    layer.orthTransW0.angles.data = angsW
    layer.orthTransW0.mus = musW
    layer.orthTransU0.angles.data = angsU
    layer.orthTransU0.mus = musU

    with torch.no_grad():
        actualZ = layer.forward(X)

    self.assertEqual(actualZ.dtype, datatype)
    self.assertTrue(torch.allclose(actualZ, expctdZ, rtol=rtol, atol=atol))
    self.assertFalse(actualZ.requires_grad)
def testBackward8x8RandAngMusPdAng13(self, mode, ncols):
    """Check the backward pass of an 8x8 OrthonormalTransform, including angle 13.

    The reference gradients assume Z = R X for any mode other than 'Synthesis'
    and Z = R^T X for 'Synthesis':

    * dL/dX = R^T dL/dZ (analysis) or R dL/dZ (synthesis);
    * dL/dangle_13 = sum(dL/dZ * (dR @ X)) (dR transposed in synthesis).

    Here dR is a central difference of R with step ``delta`` around angle 13.
    """
    datatype = torch.double
    rtol, atol = 1e-4, 1e-7

    # Configuration
    nPoints = 8
    mus = [1, 1, 1, 1, -1, -1, -1, -1]
    angs0 = 2 * math.pi * torch.rand(28, dtype=datatype)
    angs1 = angs0.clone()
    angs2 = angs0.clone()
    pdAng = 13
    delta = 1e-4
    angs1[pdAng] = angs0[pdAng] - delta / 2.
    angs2[pdAng] = angs0[pdAng] + delta / 2.

    # Expected values
    X = torch.randn(nPoints, ncols, dtype=datatype, requires_grad=True)
    dLdZ = torch.randn(nPoints, ncols, dtype=datatype)
    omgs = OrthonormalMatrixGenerationSystem(dtype=datatype, partial_difference=False)
    R = omgs(angles=angs0, mus=mus)
    dRdW = (omgs(angles=angs2, mus=mus) - omgs(angles=angs1, mus=mus)) / delta
    if mode != 'Synthesis':
        expctddLdX = R.T @ dLdZ  # = dZdX @ dLdZ
        expctddLdW = torch.sum(dLdZ * (dRdW @ X))
    else:
        expctddLdX = R @ dLdZ  # = dZdX @ dLdZ
        expctddLdW = torch.sum(dLdZ * (dRdW.T @ X))

    # Instantiation of target class
    target = OrthonormalTransform(n=nPoints, dtype=datatype, mode=mode)
    target.angles.data = angs0
    target.mus = mus

    # Actual values. Anomaly detection is scoped to this forward/backward pass:
    # a bare set_detect_anomaly(True) call would leave it on for all later tests.
    with torch.autograd.set_detect_anomaly(True):
        Z = target.forward(X)
        target.zero_grad()
        Z.backward(dLdZ)
    actualdLdX = X.grad
    actualdLdW = target.angles.grad[pdAng]

    # Evaluation
    self.assertTrue(torch.allclose(actualdLdX, expctddLdX, rtol=rtol, atol=atol))
    self.assertTrue(torch.allclose(actualdLdW, expctddLdW, rtol=rtol, atol=atol))
def test4x4RandAngs(self, datatype):
    """Compare a 4x4 matrix built from six random angles with explicit Givens products.

    The reference is diag(mus) R5 R4 R3 R2 R1 R0, where Rk rotates the
    (row, column) plane ``planes[k]`` by ``angs[k]``.
    """
    rtol, atol = 1e-4, 1e-7
    mus = [-1, 1, -1, 1]
    angs = 2 * math.pi * torch.rand(6)
    planes = ((0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3))

    def givens(k):
        # 4x4 rotation by angs[k] in plane planes[k], in the default float dtype.
        i, j = planes[k]
        g = torch.eye(4)
        c, s = math.cos(angs[k]), math.sin(angs[k])
        g[i, i], g[i, j] = c, -s
        g[j, i], g[j, j] = s, c
        return g

    prod = torch.tensor(mus).view(-1, 1) * givens(5)
    for k in range(4, -1, -1):
        prod = prod @ givens(k)
    expctdM = torch.as_tensor(prod, dtype=datatype)

    omgs = OrthonormalMatrixGenerationSystem(dtype=datatype)
    actualM = omgs(angles=angs, mus=mus)

    self.assertTrue(torch.allclose(actualM, expctdM, rtol=rtol, atol=atol))
def testPartialDifference8x8RandAngPdAng13(self, datatype):
    """The analytic derivative w.r.t. angle 13 of an 8x8 matrix matches a central difference."""
    rtol, atol = 1e-1, 1e-3
    pdAng = 13
    delta = 1e-3
    angs0 = 2 * math.pi * torch.rand(28)
    angsLo, angsHi = angs0.clone(), angs0.clone()
    angsLo[pdAng] = angs0[pdAng] - delta / 2
    angsHi[pdAng] = angs0[pdAng] + delta / 2

    # Finite-difference reference from the plain (non-derivative) generator.
    gen = OrthonormalMatrixGenerationSystem(dtype=datatype, partial_difference=False)
    expctdM = (gen(angles=angsHi, mus=1) - gen(angles=angsLo, mus=1)) / delta

    # Same object switched to derivative mode.
    gen.partial_difference = True
    actualM = gen(angles=angs0, mus=1, index_pd_angle=pdAng)

    self.assertTrue(torch.allclose(actualM, expctdM, rtol=rtol, atol=atol))
def testConstructor(self, datatype):
    """A freshly built generator maps angle 0 with sign +1 to the 2x2 identity."""
    rtol, atol = 1e-5, 1e-8
    expctdM = torch.eye(2, dtype=datatype)
    actualM = OrthonormalMatrixGenerationSystem(dtype=datatype)(angles=0, mus=1)
    self.assertTrue(torch.allclose(actualM, expctdM, rtol=rtol, atol=atol))
def testPredictGrayscaleWithRandomAngles(self, datatype, nchs, stride, nrows, ncols):
    """Forward pass of the final rotation layer with random angles and all signs +1.

    Reference: the first ``ceil(nDecs/2)`` outputs are ``W0^T[:ms, :] @ X_s``
    and the remaining ``floor(nDecs/2)`` are ``U0^T[:ma, :] @ X_a``. Here X_s
    and X_a are the first ``ps`` and last ``pa`` input channels.
    """
    rtol, atol = 1e-3, 1e-6
    gen = OrthonormalMatrixGenerationSystem(dtype=datatype)

    nSamples = 8
    nDecs = stride[0] * stride[1]  # math.prod(stride)
    nChsTotal = sum(nchs)
    # Input: nSamples x nRows x nCols x nChs
    X = torch.randn(nSamples, nrows, ncols, nChsTotal, dtype=datatype)
    angles = torch.randn(int((nChsTotal - 2) * nChsTotal / 4), dtype=datatype)

    # Reference output: nSamples x nRows x nCols x nDecs
    ps, pa = nchs
    nAngsW = int(len(angles) / 2)
    angsW = angles[:nAngsW]
    angsU = angles[nAngsW:]
    W0T = gen(angsW).T
    U0T = gen(angsU).T
    ms = int(math.ceil(nDecs / 2.))
    ma = int(math.floor(nDecs / 2.))
    Zs = W0T[:ms, :] @ X[..., :ps].view(-1, ps).T
    Za = U0T[:ma, :] @ X[..., ps:].view(-1, pa).T
    expctdZ = torch.cat((Zs, Za), dim=0).T.view(nSamples, nrows, ncols, nDecs)

    layer = NsoltFinalRotation2dLayer(
        number_of_channels=nchs, decimation_factor=stride, name='V0~')
    layer.orthTransW0T.angles.data = angsW
    layer.orthTransW0T.mus = 1
    layer.orthTransU0T.angles.data = angsU
    layer.orthTransU0T.mus = 1

    with torch.no_grad():
        actualZ = layer.forward(X)

    self.assertEqual(actualZ.dtype, datatype)
    self.assertTrue(torch.allclose(actualZ, expctdZ, rtol=rtol, atol=atol))
    self.assertFalse(actualZ.requires_grad)
def testPredictGrayscaleAnalysisMode(self, nchs, stride, nrows, ncols, mus, datatype):
    """Forward pass of the intermediate rotation layer in 'Analysis' mode.

    Reference: the first ``ps`` channels pass through unchanged and the last
    ``pa`` channels are multiplied by ``Un``. ``X`` requires grad, so checking
    that the output does not require grad confirms that ``torch.no_grad()``
    took effect.
    """
    rtol, atol = 1e-5, 1e-8
    gen = OrthonormalMatrixGenerationSystem(dtype=datatype)

    nSamples = 8
    nChsTotal = sum(nchs)
    # Input: nSamples x nRows x nCols x nChsTotal
    X = torch.randn(nSamples, nrows, ncols, nChsTotal, dtype=datatype, requires_grad=True)
    angles = torch.randn(int((nChsTotal - 2) * nChsTotal / 8), dtype=datatype)

    # Reference output: nSamples x nRows x nCols x nChsTotal
    ps, pa = nchs
    Un = gen(angles, mus)
    rotated = Un @ X[..., ps:].view(-1, pa).T  # pa x n
    expctdZ = X.clone()
    expctdZ[..., ps:] = rotated.T.view(nSamples, nrows, ncols, pa)

    layer = NsoltIntermediateRotation2dLayer(
        number_of_channels=nchs, name='Vn', mode='Analysis')
    layer.orthTransUn.angles.data = angles
    layer.orthTransUn.mus = mus

    with torch.no_grad():
        actualZ = layer.forward(X)

    self.assertEqual(actualZ.dtype, datatype)
    self.assertTrue(torch.allclose(actualZ, expctdZ, rtol=rtol, atol=atol))
    self.assertFalse(actualZ.requires_grad)
def test8x8(self, datatype):
    """An 8x8 matrix from random angles maps a random unit vector to a unit vector."""
    rtol, atol = 1e-5, 1e-8
    expctdNorm = torch.tensor(1., dtype=datatype)

    ang = 2 * math.pi * torch.rand(28)
    omgs = OrthonormalMatrixGenerationSystem(dtype=datatype)

    v = torch.randn(8, dtype=datatype)
    v = v / v.norm()
    matrix = omgs(angles=ang, mus=1)
    actualNorm = matrix.mv(v).norm()

    message = "actualNorm=%s differs from %s" % (str(actualNorm), str(expctdNorm))
    isNormPreserved = torch.isclose(actualNorm, expctdNorm, rtol=rtol, atol=atol)
    self.assertTrue(isNormPreserved, message)
def test8x8red(self, datatype):
    """With the first 7 of 28 angles zeroed, the 8x8 matrix has 1 at its top-left."""
    rtol, atol = 1e-5, 1e-8
    expctdLeftTop = torch.tensor(1., dtype=datatype)

    nSize = 8
    ang = 2 * math.pi * torch.rand(28)
    ang[:nSize - 1] = torch.zeros(nSize - 1)

    omgs = OrthonormalMatrixGenerationSystem(dtype=datatype)
    actualLeftTop = omgs(angles=ang, mus=1)[0, 0]

    message = "actualLeftTop=%s differs from %s" % (str(actualLeftTop), str(expctdLeftTop))
    isOne = torch.isclose(actualLeftTop, expctdLeftTop, rtol=rtol, atol=atol)
    self.assertTrue(isOne, message)
def testBackwardWithRandomAnglesNoDcLeackage(self, datatype, nchs, stride, nrows, ncols, mus):
    """Check the gradients of the final rotation layer with ``no_dc_leakage=True``.

    The reference gradients below treat the forward pass as follows. The first
    ``ms = ceil(nDecs/2)`` outputs are ``W0^T[:ms, :] @ X_s`` and the last
    ``ma = floor(nDecs/2)`` outputs are ``U0^T[:ma, :] @ X_a``. W0 is built with
    its first ``ps - 1`` angles zeroed and its first sign set to +1. The layer
    receives the raw angles and a scalar ``mus``; it presumably applies the same
    zeroing and sign fix internally (TODO confirm).
    """
    rtol, atol = 1e-2, 1e-5
    omgs = OrthonormalMatrixGenerationSystem(dtype=datatype, partial_difference=False)

    # Parameters
    nSamples = 8
    nDecs = stride[0] * stride[1]  # math.prod(stride)
    nChsTotal = sum(nchs)
    nAnglesH = int((nChsTotal - 2) * nChsTotal / 8)
    anglesW = torch.randn(nAnglesH, dtype=datatype)
    anglesU = torch.randn(nAnglesH, dtype=datatype)
    # nSamples x nRows x nCols x nChs
    X = torch.randn(nSamples, nrows, ncols, nChsTotal, dtype=datatype, requires_grad=True)
    dLdZ = torch.randn(nSamples, nrows, ncols, nDecs, dtype=datatype)

    # Expected values
    ps, pa = nchs
    anglesWNoDcLeak = anglesW.clone()
    anglesWNoDcLeak[:ps - 1] = torch.zeros(ps - 1, dtype=datatype)
    musW = mus * torch.ones(ps, dtype=datatype)
    musU = mus * torch.ones(pa, dtype=datatype)
    musW[0] = 1
    W0 = omgs(anglesWNoDcLeak, musW)
    U0 = omgs(anglesU, musU)
    # dLdX = dZdX x dLdZ
    ms, ma = int(math.ceil(nDecs / 2.)), int(math.floor(nDecs / 2.))
    Ys = dLdZ[:, :, :, :ms].view(nSamples * nrows * ncols, ms).T  # ms x n
    Ya = dLdZ[:, :, :, ms:].view(nSamples * nrows * ncols, ma).T  # ma x n
    Y = torch.cat(
        (W0[:, :ms] @ Ys,    # ps x ms @ ms x n
         U0[:, :ma] @ Ya),   # pa x ma @ ma x n
        dim=0)
    # n x (ps+pa) -> N x R x C x P
    expctddLdX = Y.T.view(nSamples, nrows, ncols, nChsTotal)

    # dLdWi = <dLdZ,(dVdWi)X>
    expctddLdW_W = torch.zeros(nAnglesH, dtype=datatype)
    expctddLdW_U = torch.zeros(nAnglesH, dtype=datatype)
    omgs.partial_difference = True
    # The input halves do not depend on the angle index; extract them once.
    Xs = X[:, :, :, :ps].view(-1, ps).T  # ps x n
    Xa = X[:, :, :, ps:].view(-1, pa).T  # pa x n
    for iAngle in range(nAnglesH):
        dW0_T = omgs(anglesWNoDcLeak, musW, index_pd_angle=iAngle).T
        dU0_T = omgs(anglesU, musU, index_pd_angle=iAngle).T
        Zs = dW0_T[:ms, :] @ Xs  # ms x n
        Za = dU0_T[:ma, :] @ Xa  # ma x n
        expctddLdW_W[iAngle] = torch.sum(Ys * Zs)
        expctddLdW_U[iAngle] = torch.sum(Ya * Za)

    # Instantiation of target class
    layer = NsoltFinalRotation2dLayer(
        number_of_channels=nchs,
        decimation_factor=stride,
        no_dc_leakage=True,
        name='V0~')
    layer.orthTransW0T.angles.data = anglesW
    layer.orthTransW0T.mus = mus
    layer.orthTransU0T.angles.data = anglesU
    layer.orthTransU0T.mus = mus

    # Actual values. Anomaly detection is scoped to this forward/backward pass:
    # a bare set_detect_anomaly(True) call would leave it on for all later tests.
    with torch.autograd.set_detect_anomaly(True):
        Z = layer.forward(X)
        layer.zero_grad()
        Z.backward(dLdZ)
    actualdLdX = X.grad
    actualdLdW_W = layer.orthTransW0T.angles.grad
    actualdLdW_U = layer.orthTransU0T.angles.grad

    # Evaluation
    self.assertEqual(actualdLdX.dtype, datatype)
    self.assertEqual(actualdLdW_W.dtype, datatype)
    self.assertEqual(actualdLdW_U.dtype, datatype)
    self.assertTrue(torch.allclose(actualdLdX, expctddLdX, rtol=rtol, atol=atol))
    self.assertTrue(torch.allclose(actualdLdW_W, expctddLdW_W, rtol=rtol, atol=atol))
    self.assertTrue(torch.allclose(actualdLdW_U, expctddLdW_U, rtol=rtol, atol=atol))
    self.assertTrue(Z.requires_grad)
def testBackwardGrayscaleWithRandomAngles(self, nchs, stride, nrows, ncols, datatype):
    """Check the gradients of the initial rotation layer with random angles and signs +1.

    The reference gradients below treat the forward pass as follows. The first
    ``ps`` outputs are ``W0[:, :ms] @ X_s`` and the last ``pa`` outputs are
    ``U0[:, :ma] @ X_a``. X_s and X_a are the first ``ms = ceil(nDecs/2)`` and
    last ``ma = floor(nDecs/2)`` input channels.
    """
    rtol, atol = 1e-3, 1e-6
    omgs = OrthonormalMatrixGenerationSystem(dtype=datatype, partial_difference=False)

    # Parameters
    nSamples = 8
    nDecs = stride[0] * stride[1]  # math.prod(stride)
    nChsTotal = sum(nchs)
    nAnglesH = int((nChsTotal - 2) * nChsTotal / 8)
    anglesW = torch.randn(nAnglesH, dtype=datatype)
    anglesU = torch.randn(nAnglesH, dtype=datatype)
    mus = 1
    # nSamples x nRows x nCols x nDecs
    X = torch.randn(nSamples, nrows, ncols, nDecs, dtype=datatype, requires_grad=True)
    dLdZ = torch.randn(nSamples, nrows, ncols, nChsTotal, dtype=datatype)

    # Expected values
    ps, pa = nchs
    W0T = omgs(anglesW, mus).T
    U0T = omgs(anglesU, mus).T
    # dLdX = dZdX x dLdZ
    ms, ma = int(math.ceil(nDecs / 2.)), int(math.floor(nDecs / 2.))
    Ys = dLdZ[:, :, :, :ps].view(nSamples * nrows * ncols, ps).T  # ps x n
    Ya = dLdZ[:, :, :, ps:].view(nSamples * nrows * ncols, pa).T  # pa x n
    Y = torch.cat(
        (W0T[:ms, :] @ Ys,   # ms x ps @ ps x n
         U0T[:ma, :] @ Ya),  # ma x pa @ pa x n
        dim=0)
    expctddLdX = Y.T.view(nSamples, nrows, ncols, nDecs)  # n x (ms+ma)

    # dLdWi = <dLdZ,(dVdWi)X>
    expctddLdW_W = torch.zeros(nAnglesH, dtype=datatype)
    expctddLdW_U = torch.zeros(nAnglesH, dtype=datatype)
    omgs.partial_difference = True
    # The input halves do not depend on the angle index; extract them once.
    Xs = X[:, :, :, :ms].view(-1, ms).T  # ms x n
    Xa = X[:, :, :, ms:].view(-1, ma).T if ma > 0 else None  # ma x n
    for iAngle in range(nAnglesH):
        dW0 = omgs(anglesW, mus, index_pd_angle=iAngle)
        Zs = dW0[:, :ms] @ Xs  # ps x n
        expctddLdW_W[iAngle] = torch.sum(Ys * Zs)
        if ma > 0:
            dU0 = omgs(anglesU, mus, index_pd_angle=iAngle)
            Za = dU0[:, :ma] @ Xa  # pa x n
            expctddLdW_U[iAngle] = torch.sum(Ya * Za)

    # Instantiation of target class
    layer = NsoltInitialRotation2dLayer(
        number_of_channels=nchs,
        decimation_factor=stride,
        name='V0')
    layer.orthTransW0.angles.data = anglesW
    layer.orthTransW0.mus = mus
    layer.orthTransU0.angles.data = anglesU
    layer.orthTransU0.mus = mus

    # Actual values. Anomaly detection is scoped to this forward/backward pass:
    # a bare set_detect_anomaly(True) call would leave it on for all later tests.
    with torch.autograd.set_detect_anomaly(True):
        Z = layer.forward(X)
        layer.zero_grad()
        Z.backward(dLdZ)
    actualdLdX = X.grad
    actualdLdW_W = layer.orthTransW0.angles.grad
    actualdLdW_U = layer.orthTransU0.angles.grad

    # Evaluation
    self.assertEqual(actualdLdX.dtype, datatype)
    self.assertEqual(actualdLdW_W.dtype, datatype)
    self.assertEqual(actualdLdW_U.dtype, datatype)
    self.assertTrue(torch.allclose(actualdLdX, expctddLdX, rtol=rtol, atol=atol))
    self.assertTrue(torch.allclose(actualdLdW_W, expctddLdW_W, rtol=rtol, atol=atol))
    self.assertTrue(torch.allclose(actualdLdW_U, expctddLdW_U, rtol=rtol, atol=atol))
    self.assertTrue(Z.requires_grad)