def testNxNred(self, datatype, ncols, npoints, mode):
    """Zeroing the leading npoints-1 angles must make entry [0, 0] equal 1.

    All angles are first drawn uniformly from [0, 2*pi).  Only the first
    npoints-1 are then forced to zero.  The transform is applied to the
    identity matrix so that its output is the transform matrix itself.
    """
    rtol, atol = 1e-5, 1e-8
    expctdLeftTop = torch.tensor(1., dtype=datatype)

    # Random angles everywhere, then clear the leading block.
    target = OrthonormalTransform(n=npoints, mode=mode)
    target.angles = nn.init.uniform_(target.angles, a=0.0, b=2 * math.pi)
    nLeading = npoints - 1
    target.angles.data[:nLeading] = torch.zeros(nLeading)

    with torch.no_grad():
        identity = torch.eye(npoints, dtype=datatype)
        actualLeftTop = target.forward(identity)[0, 0]

    message = "actualLeftTop=%s differs from %s" % (
        str(actualLeftTop), str(expctdLeftTop))
    isClose = torch.isclose(actualLeftTop, expctdLeftTop, rtol=rtol, atol=atol)
    self.assertTrue(isClose, message)
def testCallWithAngles(self, datatype, ncols, mode):
    """A 2-point transform with every angle at pi/4 is a fixed rotation.

    'Synthesis' mode is expected to yield R(pi/4).T @ X.  Any other mode is
    expected to yield R(pi/4) @ X.
    """
    rtol, atol = 1e-4, 1e-7
    theta = math.pi / 4
    c, s = math.cos(theta), math.sin(theta)

    X = torch.randn(2, ncols, dtype=datatype)
    R = torch.tensor([[c, -s], [s, c]], dtype=datatype)
    expctdZ = R.T @ X if mode == 'Synthesis' else R @ X

    target = OrthonormalTransform(mode=mode)
    target.angles = nn.init.constant_(target.angles, val=theta)

    with torch.no_grad():
        actualZ = target.forward(X)

    self.assertTrue(torch.allclose(actualZ, expctdZ, rtol=rtol, atol=atol))
def testBackwardMultiColumns(self, datatype, ncols, mode):
    """Check dL/dX and dL/dW of a freshly constructed 2-point transform.

    The expected values assume the new transform has a zero angle.  In that
    case R is the identity and dR/dW = [[0, -1], [1, 0]].  The expected
    gradients follow from Z = R @ X, or Z = R.T @ X in 'Synthesis' mode.
    """
    rtol, atol = 1e-4, 1e-7
    nPoints = 2

    X = torch.randn(nPoints, ncols, dtype=datatype, requires_grad=True)
    dLdZ = torch.randn(nPoints, ncols, dtype=datatype)
    R = torch.eye(nPoints, dtype=datatype)
    dRdW = torch.tensor([[0., -1.], [1., 0.]], dtype=datatype)
    if mode != 'Synthesis':
        expctddLdX = R.T @ dLdZ  # = dZdX @ dLdZ
        expctddLdW = torch.sum(dLdZ * (dRdW @ X))
    else:
        expctddLdX = R @ dLdZ  # = dZdX @ dLdZ
        expctddLdW = torch.sum(dLdZ * (dRdW.T @ X))

    target = OrthonormalTransform(n=nPoints, dtype=datatype, mode=mode)

    # Enable anomaly detection only around this forward/backward pass.
    # Calling set_detect_anomaly(True) as a plain function would leave it on
    # (and slow) for every test that runs afterwards.
    with torch.autograd.set_detect_anomaly(True):
        Z = target.forward(X)
        target.zero_grad()
        Z.backward(dLdZ)
    actualdLdX = X.grad
    actualdLdW = target.angles.grad

    self.assertTrue(
        torch.allclose(actualdLdX, expctddLdX, rtol=rtol, atol=atol))
    self.assertTrue(
        torch.allclose(actualdLdW, expctddLdW, rtol=rtol, atol=atol))
def __init__(self,
             number_of_channels=[],
             decimation_factor=[],
             no_dc_leakage=False,
             name=''):
    """Build the final-rotation layer with two zero-initialised transforms.

    Args:
        number_of_channels: pair (ps, pa).  These are the sizes of the
            transforms stored as ``orthTransW0T`` and ``orthTransU0T``.
        decimation_factor: indexed with ``Direction.VERTICAL`` and
            ``Direction.HORIZONTAL``; only used for ``description`` here.
        no_dc_leakage: flag stored on the layer.
        name: layer name.
    """
    super(NsoltFinalRotation2dLayer, self).__init__()
    self.name = name
    self.number_of_channels = number_of_channels
    self.decimation_factor = decimation_factor

    nch, dec = self.number_of_channels, self.decimation_factor
    self.description = (
        f"NSOLT final rotation (ps,pa) = ({nch[0]!s},{nch[1]!s}), "
        f"(mv,mh) = ({dec[Direction.VERTICAL]!s},"
        f"{dec[Direction.HORIZONTAL]!s})")

    # Both transforms run in 'Synthesis' mode and start from zero angles.
    ps, pa = self.number_of_channels
    for attrName, size in (('orthTransW0T', ps), ('orthTransU0T', pa)):
        transform = OrthonormalTransform(n=size, mode='Synthesis')
        transform.angles = nn.init.zeros_(transform.angles)
        setattr(self, attrName, transform)

    self.no_dc_leakage = no_dc_leakage
def testBackward8x8RandAngPdAng4(self, mode, ncols):
    """Check dL/dX and dL/d(angle[4]) of an 8-point transform.

    The reference matrix R and its finite-difference derivative dR/dW
    (central difference with step ``delta`` on angle ``pdAng``) come from
    OrthonormalMatrixGenerationSystem.
    """
    datatype = torch.double
    rtol, atol = 1e-4, 1e-7
    nPoints = 8
    angs0 = 2 * math.pi * torch.rand(28, dtype=datatype)  # 8*7/2 angles
    angs1 = angs0.clone()
    angs2 = angs0.clone()
    pdAng = 4
    delta = 1e-4
    angs1[pdAng] = angs0[pdAng] - delta / 2.
    angs2[pdAng] = angs0[pdAng] + delta / 2.

    X = torch.randn(nPoints, ncols, dtype=datatype, requires_grad=True)
    dLdZ = torch.randn(nPoints, ncols, dtype=datatype)
    omgs = OrthonormalMatrixGenerationSystem(dtype=datatype,
                                             partial_difference=False)
    R = omgs(angles=angs0, mus=1)
    dRdW = (omgs(angles=angs2, mus=1) - omgs(angles=angs1, mus=1)) / delta
    if mode != 'Synthesis':
        expctddLdX = R.T @ dLdZ  # = dZdX @ dLdZ
        expctddLdW = torch.sum(dLdZ * (dRdW @ X))
    else:
        expctddLdX = R @ dLdZ  # = dZdX @ dLdZ
        expctddLdW = torch.sum(dLdZ * (dRdW.T @ X))

    target = OrthonormalTransform(n=nPoints, dtype=datatype, mode=mode)
    target.angles.data = angs0

    # Scope anomaly detection to this pass instead of enabling it globally.
    with torch.autograd.set_detect_anomaly(True):
        Z = target.forward(X)
        target.zero_grad()
        Z.backward(dLdZ)
    actualdLdX = X.grad
    actualdLdW = target.angles.grad[pdAng]

    self.assertTrue(
        torch.allclose(actualdLdX, expctddLdX, rtol=rtol, atol=atol))
    self.assertTrue(
        torch.allclose(actualdLdW, expctddLdW, rtol=rtol, atol=atol))
def __init__(self, number_of_channels=[], mode='Synthesis', name=''):
    """Build the intermediate-rotation layer.

    Args:
        number_of_channels: pair (ps, pa).  Only pa sizes the rotation
            ``orthTransUn``; the pair is still unpacked, so it must have
            exactly two entries.
        mode: mode passed to the transform and prefixed to ``description``.
        name: layer name.
    """
    super(NsoltIntermediateRotation2dLayer, self).__init__()
    self.name = name
    self.number_of_channels = number_of_channels
    nch = self.number_of_channels
    self.description = mode + " NSOLT intermediate rotation " \
        + "(ps,pa) = (%s,%s)" % (nch[0], nch[1])

    # Zero-initialised rotation acting on the pa channels.
    _, pa = self.number_of_channels
    self.orthTransUn = OrthonormalTransform(n=pa, mode=mode)
    self.orthTransUn.angles = nn.init.zeros_(self.orthTransUn.angles)
class NsoltIntermediateRotation2dLayer(nn.Module):
    """
    NSOLTINTERMEDIATEROTATION2DLAYER

    Input per component (nComponents): nSamples x nRows x nCols x nChs
    Output per component (nComponents): nSamples x nRows x nCols x nChs

    The first ps channels pass through unchanged.  The last pa channels are
    rotated by ``orthTransUn``.

    Requirements: Python 3.7.x, PyTorch 1.7.x

    Copyright (c) 2020-2021, Shogo MURAMATSU

    All rights reserved.

    Contact address: Shogo MURAMATSU,
        Faculty of Engineering, Niigata University,
        8050 2-no-cho Ikarashi, Nishi-ku,
        Niigata, 950-2181, JAPAN

    http://msiplab.eng.niigata-u.ac.jp/
    """

    def __init__(self, number_of_channels=[], mode='Synthesis', name=''):
        """Args:
            number_of_channels: pair (ps, pa); must have exactly two entries.
            mode: mode of the internal transform, also prefixed to
                ``description``.
            name: layer name.
        """
        super().__init__()
        self.name = name
        self.number_of_channels = number_of_channels
        self.description = mode \
            + " NSOLT intermediate rotation " \
            + "(ps,pa) = (" \
            + str(self.number_of_channels[0]) + "," \
            + str(self.number_of_channels[1]) + ")"

        # Instantiation of orthonormal transform (zero initial angles)
        _, pa = self.number_of_channels
        self.orthTransUn = OrthonormalTransform(n=pa, mode=mode)
        self.orthTransUn.angles = nn.init.zeros_(self.orthTransUn.angles)

    def forward(self, X):
        """Rotate the last pa channels of X; the rest are copied through.

        Args:
            X: 4-D tensor indexed as (sample, row, col, channel).
        Returns:
            Tensor of the same shape as X.
        """
        nSamples, nrows, ncols = X.size(0), X.size(1), X.size(2)
        ps, pa = self.number_of_channels

        Z = X.clone()
        # reshape (not view): it returns a view whenever view would succeed
        # and copies otherwise.  That lets non-contiguous inputs (e.g. a
        # permuted tensor) through instead of raising.
        Ya = X[:, :, :, ps:].reshape(-1, pa).T
        Za = self.orthTransUn.forward(Ya)
        Z[:, :, :, ps:] = Za.T.reshape(nSamples, nrows, ncols, pa)
        return Z

    @property
    def mode(self):
        """Mode of the underlying orthonormal transform."""
        return self.orthTransUn.mode
def test8x8(self, datatype, ncols, mode):
    """An 8-point transform with random angles must preserve the norm."""
    rtol, atol = 1e-5, 1e-8
    expctdNorm = torch.tensor(1., dtype=datatype)

    target = OrthonormalTransform(n=8, mode=mode)
    target.angles.data = torch.randn(28, dtype=datatype)  # 8*7/2 angles

    # Random input scaled to unit (Frobenius) norm.
    unitvec = torch.randn(8, ncols, dtype=datatype)
    unitvec = unitvec / unitvec.norm()

    with torch.no_grad():
        actualNorm = target.forward(unitvec).norm()

    message = "actualNorm=%s differs from %s" % (str(actualNorm),
                                                 str(expctdNorm))
    isClose = torch.isclose(actualNorm, expctdNorm, rtol=rtol, atol=atol)
    self.assertTrue(isClose, message)
def testConstructor(self, datatype, ncols):
    """The default transform is an identity nn.Module in 'Analysis' mode.

    It should also have a first parameter with a single entry.
    """
    rtol, atol = 1e-5, 1e-8
    X = torch.randn(2, ncols, dtype=datatype)
    expctdZ = X
    expctdNParams = 1
    expctdMode = 'Analysis'

    target = OrthonormalTransform()

    with torch.no_grad():
        actualZ = target.forward(X)
    # Length of the first registered parameter (presumably the angles).
    actualNParams = len(next(target.parameters()))
    actualMode = target.mode

    self.assertTrue(isinstance(target, nn.Module))
    self.assertTrue(torch.allclose(actualZ, expctdZ, rtol=rtol, atol=atol))
    self.assertEqual(actualNParams, expctdNParams)
    self.assertEqual(actualMode, expctdMode)
def testBackwardAngsAndMus(self, datatype, mode, ncols):
    """Check gradients of a 2-point transform with a random angle, mus=[1, -1].

    The reference is R = diag(1, -1) @ [[cos, -sin], [sin, cos]], and dR/dW
    is its analytic derivative with respect to the angle.
    """
    rtol, atol = 1e-4, 1e-7
    nPoints = 2
    mus = [1, -1]

    X = torch.randn(nPoints, ncols, dtype=datatype, requires_grad=True)
    dLdZ = torch.randn(nPoints, ncols, dtype=datatype)
    angle = 2. * math.pi * gauss(mu=0., sigma=1.)
    c, s = math.cos(angle), math.sin(angle)
    R = torch.tensor([[c, -s], [-s, -c]], dtype=datatype)
    dRdW = torch.tensor([[-s, -c], [-c, s]], dtype=datatype)
    if mode != 'Synthesis':
        expctddLdX = R.T @ dLdZ  # = dZdX @ dLdZ
        expctddLdW = torch.sum(dLdZ * (dRdW @ X))
    else:
        expctddLdX = R @ dLdZ  # = dZdX @ dLdZ
        expctddLdW = torch.sum(dLdZ * (dRdW.T @ X))

    target = OrthonormalTransform(n=nPoints, dtype=datatype, mode=mode)
    target.angles = nn.init.constant_(target.angles, val=angle)
    target.mus = torch.tensor(mus, dtype=datatype)

    # Scope anomaly detection to this pass instead of enabling it globally.
    with torch.autograd.set_detect_anomaly(True):
        Z = target.forward(X)
        target.zero_grad()
        Z.backward(dLdZ)
    actualdLdX = X.grad
    actualdLdW = target.angles.grad

    self.assertTrue(
        torch.allclose(actualdLdX, expctddLdX, rtol=rtol, atol=atol))
    self.assertTrue(
        torch.allclose(actualdLdW, expctddLdW, rtol=rtol, atol=atol))
def testGradCheckNxNRandAngMus(self, mode, ncols, npoints):
    """Run torch.autograd.gradcheck on an NxN transform with random angles/mus.

    NOTE: gradcheck only checks gradients with respect to the tensors passed
    as inputs, i.e. X here.  The angle parameters are not numerically
    checked by this test.
    """
    datatype = torch.double  # gradcheck needs double precision
    nPoints = npoints
    nAngs = nPoints * (nPoints - 1) // 2
    mus = (-1)**torch.randint(high=2, size=(nPoints, ))  # random +/-1
    angs = 2. * math.pi * torch.randn(nAngs, dtype=datatype)

    X = torch.randn(nPoints, ncols, dtype=datatype, requires_grad=True)

    target = OrthonormalTransform(n=nPoints, dtype=datatype, mode=mode)
    target.angles.data = angs
    target.mus = mus

    # Scope anomaly detection to the check instead of enabling it globally.
    with torch.autograd.set_detect_anomaly(True):
        self.assertTrue(torch.autograd.gradcheck(target, (X, )))
def testInstantiationWithInvalidMode(self):
    """Constructing with an unknown mode string must raise InvalidMode."""
    invalidMode = 'Invalid'
    with self.assertRaises(InvalidMode):
        OrthonormalTransform(mode=invalidMode)
def testBackword8x8RandAngMusPdAng7(self, mode, ncols):
    """Check dL/d(angle[7]) of an 8-point transform with alternating mus.

    The expected gradient is a central finite difference of the forward
    output, with step ``delta`` around angle ``pdAng``.
    """
    datatype = torch.double
    rtol, atol = 1e-4, 1e-7
    nPoints = 8
    mus = [1, -1, 1, -1, 1, -1, 1, -1]
    angs0 = 2 * math.pi * torch.rand(28, dtype=datatype)  # 8*7/2 angles
    angs1 = angs0.clone()
    angs2 = angs0.clone()
    pdAng = 7
    delta = 1e-4
    angs1[pdAng] = angs0[pdAng] - delta / 2.
    angs2[pdAng] = angs0[pdAng] + delta / 2.

    X = torch.randn(nPoints, ncols, dtype=datatype)
    dLdZ = torch.randn(nPoints, ncols, dtype=datatype)

    def makeTarget(angs):
        # Transform configured with the given angles and the shared mus.
        target = OrthonormalTransform(n=nPoints, dtype=datatype, mode=mode)
        target.angles.data = angs
        target.mus = mus
        return target

    target0 = makeTarget(angs0)
    target1 = makeTarget(angs1)
    target2 = makeTarget(angs2)

    # Expected value: finite difference only, so no autograd graph is needed.
    with torch.no_grad():
        dZdW = (target2.forward(X) - target1.forward(X)) / delta  # ~ d(R*X)/dW
    expctddLdW = torch.sum(dLdZ * dZdW)  # ~ dLdW

    # Actual value via autograd, with anomaly detection scoped to this pass.
    X.requires_grad_(True)
    with torch.autograd.set_detect_anomaly(True):
        Z = target0.forward(X)
        target0.zero_grad()
        Z.backward(dLdZ)
    actualdLdW = target0.angles.grad[pdAng]

    self.assertTrue(
        torch.allclose(actualdLdW, expctddLdW, rtol=rtol, atol=atol))
def testBackward4x4RandAngPdAng1(self, mode, ncols):
    """Check dL/dX and dL/d(angle[1]) of a 4-point transform with mus = -1.

    The reference matrix is R = diag(mus) @ G5 @ G4 @ G3 @ G2 @ G1 @ G0.
    Gk is the rotation by angs[k] in the k-th coordinate plane, taken in
    the order (0,1), (0,2), (0,3), (1,2), (1,3), (2,3).  dR/dW is the same
    product with G1 replaced by its central finite difference (step
    ``delta``).
    """
    datatype = torch.double
    rtol, atol = 1e-2, 1e-5
    nPoints = 4
    mus = [-1, -1, -1, -1]
    angs = 2. * math.pi * torch.randn(6, dtype=datatype)
    pdAng = 1
    delta = 1e-4

    X = torch.randn(nPoints, ncols, dtype=datatype, requires_grad=True)
    dLdZ = torch.randn(nPoints, ncols, dtype=datatype)

    planes = [(i, j) for i in range(nPoints) for j in range(i + 1, nPoints)]

    def givens(plane, theta):
        # Rotation by theta in `plane`.  It is built directly in `datatype`
        # so the finite difference below is not limited by float32 rounding.
        i, j = plane
        c, s = math.cos(theta), math.sin(theta)
        g = torch.eye(nPoints, dtype=datatype)
        g[i, i], g[i, j] = c, -s
        g[j, i], g[j, j] = s, c
        return g

    def compose(factors):
        # diag(mus) @ factors[-1] @ ... @ factors[0]
        prod = torch.eye(nPoints, dtype=datatype)
        for factor in reversed(factors):
            prod = prod @ factor
        return torch.tensor(mus, dtype=datatype).view(-1, 1) * prod

    factors = [givens(p, angs[k]) for k, p in enumerate(planes)]
    R = compose(factors)
    dFactors = list(factors)
    theta = angs[pdAng]
    dFactors[pdAng] = (givens(planes[pdAng], theta + delta / 2.)
                       - givens(planes[pdAng], theta - delta / 2.)) / delta
    dRdW = compose(dFactors)

    if mode != 'Synthesis':
        expctddLdX = R.T @ dLdZ  # = dZdX @ dLdZ
        expctddLdW = torch.sum(dLdZ * (dRdW @ X))
    else:
        expctddLdX = R @ dLdZ  # = dZdX @ dLdZ
        expctddLdW = torch.sum(dLdZ * (dRdW.T @ X))

    target = OrthonormalTransform(n=nPoints, dtype=datatype, mode=mode)
    target.angles.data = angs
    target.mus = mus

    # Scope anomaly detection to this pass instead of enabling it globally.
    with torch.autograd.set_detect_anomaly(True):
        Z = target.forward(X)
        target.zero_grad()
        Z.backward(dLdZ)
    actualdLdX = X.grad
    actualdLdW = target.angles.grad[pdAng]

    self.assertTrue(
        torch.allclose(actualdLdX, expctddLdX, rtol=rtol, atol=atol))
    self.assertTrue(
        torch.allclose(actualdLdW, expctddLdW, rtol=rtol, atol=atol))
def testForward4x4RandAngs(self, datatype, mode, ncols):
    """Check the forward output of a 4-point transform with random angles.

    The reference matrix is R = diag(mus) @ G5 @ G4 @ G3 @ G2 @ G1 @ G0.
    Gk is the rotation by angs[k] in the k-th coordinate plane, taken in
    the order (0,1), (0,2), (0,3), (1,2), (1,3), (2,3).  The reference is
    built in `datatype` so double-precision runs are checked against a
    double-precision reference.
    """
    rtol, atol = 1e-4, 1e-7
    nPoints = 4
    mus = [-1, 1, -1, 1]
    angs = 2. * math.pi * torch.randn(6, dtype=datatype)

    X = torch.randn(nPoints, ncols, dtype=datatype)

    planes = [(i, j) for i in range(nPoints) for j in range(i + 1, nPoints)]

    def givens(plane, theta):
        # Rotation by theta in `plane`, in the test's dtype.
        i, j = plane
        c, s = math.cos(theta), math.sin(theta)
        g = torch.eye(nPoints, dtype=datatype)
        g[i, i], g[i, j] = c, -s
        g[j, i], g[j, j] = s, c
        return g

    prod = torch.eye(nPoints, dtype=datatype)
    for k in reversed(range(len(planes))):
        prod = prod @ givens(planes[k], angs[k])
    R = torch.tensor(mus, dtype=datatype).view(-1, 1) * prod  # diag(mus) @ prod

    if mode != 'Synthesis':
        expctdZ = R @ X
    else:
        expctdZ = R.T @ X

    target = OrthonormalTransform(n=nPoints, dtype=datatype, mode=mode)
    target.angles.data = angs
    target.mus = mus

    with torch.no_grad():
        actualZ = target.forward(X)

    self.assertTrue(torch.allclose(actualZ, expctdZ, rtol=rtol, atol=atol))
def testSetInvalidMus(self):
    """Assigning mus=2 to an existing transform must raise InvalidMus.

    The transform is built outside ``assertRaises``.  Otherwise an
    InvalidMus raised by the constructor would wrongly make the test pass.
    """
    mus = 2
    target = OrthonormalTransform()
    with self.assertRaises(InvalidMus):
        target.mus = mus
def testInstantiationWithInvalidMus(self):
    """Passing mus=2 to the constructor must raise InvalidMus."""
    invalidMus = 2
    with self.assertRaises(InvalidMus):
        OrthonormalTransform(mus=invalidMus)
def testSetInvalidMode(self):
    """Assigning an unknown mode to an existing transform must raise InvalidMode.

    The transform is built outside ``assertRaises`` so that only the
    assignment can satisfy the check.  The declared ``mode`` variable is
    used rather than a separate literal.
    """
    mode = 'Invalid'
    target = OrthonormalTransform()
    with self.assertRaises(InvalidMode):
        target.mode = mode
class NsoltFinalRotation2dLayer(nn.Module):
    """
    NSOLTFINALROTATION2DLAYER

    Input per component (nComponents): nSamples x nRows x nCols x nChs
    Output per component (nComponents): nSamples x nRows x nCols x nDecs

    The first ps channels go through ``orthTransW0T`` and the last pa
    channels through ``orthTransU0T``.  The first ceil(nDecs/2) and
    floor(nDecs/2) rows of the two results are concatenated.

    Requirements: Python 3.7.x, PyTorch 1.7.x

    Copyright (c) 2020-2021, Shogo MURAMATSU

    All rights reserved.

    Contact address: Shogo MURAMATSU,
        Faculty of Engineering, Niigata University,
        8050 2-no-cho Ikarashi, Nishi-ku,
        Niigata, 950-2181, JAPAN

    http://msiplab.eng.niigata-u.ac.jp/
    """

    def __init__(self,
                 number_of_channels=[],
                 decimation_factor=[],
                 no_dc_leakage=False,
                 name=''):
        """Args:
            number_of_channels: pair (ps, pa).
            decimation_factor: decimation factors, indexed with
                ``Direction.VERTICAL`` / ``Direction.HORIZONTAL``.
            no_dc_leakage: if True, ``forward`` resets mus[0] and the first
                ps-1 angles of ``orthTransW0T`` before use.
            name: layer name.
        """
        super().__init__()
        self.name = name
        self.number_of_channels = number_of_channels
        self.decimation_factor = decimation_factor
        self.description = "NSOLT final rotation " \
            + "(ps,pa) = (" \
            + str(self.number_of_channels[0]) + "," \
            + str(self.number_of_channels[1]) + "), " \
            + "(mv,mh) = (" \
            + str(self.decimation_factor[Direction.VERTICAL]) + "," \
            + str(self.decimation_factor[Direction.HORIZONTAL]) + ")"

        # Instantiation of orthonormal transforms (zero initial angles)
        ps, pa = self.number_of_channels
        self.orthTransW0T = OrthonormalTransform(n=ps, mode='Synthesis')
        self.orthTransW0T.angles = nn.init.zeros_(self.orthTransW0T.angles)
        self.orthTransU0T = OrthonormalTransform(n=pa, mode='Synthesis')
        self.orthTransU0T.angles = nn.init.zeros_(self.orthTransU0T.angles)

        # No DC leakage
        self.no_dc_leakage = no_dc_leakage

    def forward(self, X):
        """Apply the final rotations and keep nDecs output channels.

        Args:
            X: 4-D tensor indexed as (sample, row, col, channel).  The
                channel axis holds ps + pa entries.
        Returns:
            Tensor of shape (nSamples, nrows, ncols, nDecs), where
            nDecs = decimation_factor[0] * decimation_factor[1].
        """
        nSamples, nrows, ncols = X.size(0), X.size(1), X.size(2)
        ps, pa = self.number_of_channels
        stride = self.decimation_factor
        nDecs = stride[0] * stride[1]  # math.prod(stride)

        # No DC leakage: force mus[0] = 1 and zero the first ps-1 angles of
        # W0^T.  The zeros are allocated on the angles' own device and dtype.
        if self.no_dc_leakage:
            self.orthTransW0T.mus[0] = 1
            angles = self.orthTransW0T.angles.data
            angles[:ps-1] = torch.zeros(ps-1, dtype=angles.dtype,
                                        device=angles.device)

        # reshape (not view): it returns a view whenever view would succeed
        # and copies otherwise.  That lets non-contiguous inputs (e.g. a
        # permuted tensor) through instead of raising.
        Ys = X[:, :, :, :ps].reshape(-1, ps).T
        Ya = X[:, :, :, ps:].reshape(-1, pa).T
        ms = int(math.ceil(nDecs / 2.))
        ma = int(math.floor(nDecs / 2.))
        Zsa = torch.cat((self.orthTransW0T.forward(Ys)[:ms, :],
                         self.orthTransU0T.forward(Ya)[:ma, :]), dim=0)
        return Zsa.T.reshape(nSamples, nrows, ncols, nDecs)