Exemplo n.º 1
0
    def testNxNred(self, datatype, ncols, npoints, mode):
        """Check the top-left matrix entry when the leading angles are zero.

        All angles are drawn uniformly from [0, 2*pi), then the first
        ``npoints - 1`` of them are overwritten with zeros.  The transform is
        applied to the identity matrix and entry (0, 0) of the result is
        expected to equal 1.

        Args:
            datatype: torch dtype of the identity input and expected value.
            ncols: Supplied by the test parameterization; not used here.
            npoints: Size ``n`` of the transform under test.
            mode: Mode string forwarded to ``OrthonormalTransform``.
        """
        rtol, atol = 1e-5, 1e-8

        # Expected values
        expctdLeftTop = torch.tensor(1., dtype=datatype)

        # Instantiation of target class
        target = OrthonormalTransform(n=npoints, mode=mode)
        target.angles = nn.init.uniform_(target.angles, a=0.0, b=2 * math.pi)
        # Zero the first npoints-1 angles; the rest keep their random values.
        target.angles.data[:npoints - 1] = torch.zeros(npoints - 1)

        # Actual values: feeding the identity exposes the transform matrix.
        with torch.no_grad():
            matrix = target.forward(torch.eye(npoints, dtype=datatype))
        actualLeftTop = matrix[0, 0]

        # Evaluation
        message = f"actualLeftTop={actualLeftTop!s} differs from {expctdLeftTop!s}"
        self.assertTrue(
            torch.isclose(actualLeftTop, expctdLeftTop, rtol=rtol, atol=atol),
            message)
Exemplo n.º 2
0
    def testCallWithAngles(self, datatype, ncols, mode):
        """Forward pass of the default-size transform with its angle set to pi/4.

        The reference is the 2x2 plane rotation by pi/4 applied to a random
        2 x ncols input; for mode 'Synthesis' the transposed rotation is used.

        Args:
            datatype: torch dtype of the input and the reference rotation.
            ncols: Number of columns of the random input.
            mode: Mode string forwarded to ``OrthonormalTransform``.
        """
        rtol, atol = 1e-4, 1e-7
        quarter_pi = math.pi / 4

        # Expected values
        X = torch.randn(2, ncols, dtype=datatype)
        cos_q = math.cos(quarter_pi)
        sin_q = math.sin(quarter_pi)
        rotation = torch.tensor([[cos_q, -sin_q],
                                 [sin_q, cos_q]],
                                dtype=datatype)
        if mode == 'Synthesis':
            expected = rotation.T @ X
        else:
            expected = rotation @ X

        # Instantiation of target class
        transform = OrthonormalTransform(mode=mode)
        transform.angles = nn.init.constant_(transform.angles, val=quarter_pi)

        # Actual values
        with torch.no_grad():
            actual = transform.forward(X)

        # Evaluation
        matches = torch.allclose(actual, expected, rtol=rtol, atol=atol)
        self.assertTrue(matches)
Exemplo n.º 3
0
    def testBackwardMultiColumns(self, datatype, ncols, mode):
        """Backward pass of a freshly constructed 2-point transform.

        The reference gradients use R = I and dR/dW = [[0, -1], [1, 0]],
        i.e. the values of a plane rotation and its derivative at angle zero.
        For mode 'Synthesis' the transposed roles are used.

        Args:
            datatype: torch dtype of inputs and of the transform.
            ncols: Number of columns of the random input.
            mode: Mode string forwarded to ``OrthonormalTransform``.
        """
        rtol, atol = 1e-4, 1e-7
        n_points = 2

        # Expected values
        X = torch.randn(n_points, ncols, dtype=datatype, requires_grad=True)
        dLdZ = torch.randn(n_points, ncols, dtype=datatype)
        R = torch.eye(n_points, dtype=datatype)
        dRdW = torch.tensor([[0., -1.], [1., 0.]], dtype=datatype)
        if mode == 'Synthesis':
            expected_dLdX = R @ dLdZ
            expected_dLdW = torch.sum(dLdZ * (dRdW.T @ X))
        else:
            expected_dLdX = R.T @ dLdZ
            expected_dLdW = torch.sum(dLdZ * (dRdW @ X))

        # Instantiation of target class
        transform = OrthonormalTransform(n=n_points, dtype=datatype, mode=mode)

        # Actual values (anomaly detection is switched on globally here)
        torch.autograd.set_detect_anomaly(True)
        Z = transform.forward(X)
        transform.zero_grad()
        Z.backward(dLdZ)
        actual_dLdX = X.grad
        actual_dLdW = transform.angles.grad

        # Evaluation
        for actual, expected in ((actual_dLdX, expected_dLdX),
                                 (actual_dLdW, expected_dLdW)):
            self.assertTrue(
                torch.allclose(actual, expected, rtol=rtol, atol=atol))
Exemplo n.º 4
0
    def __init__(self,
                 number_of_channels=[],
                 decimation_factor=[],
                 no_dc_leakage=False,
                 name=''):
        """Set up the final-rotation layer and its two synthesis transforms.

        Args:
            number_of_channels: Pair ``(ps, pa)``; indexed with 0 and 1 and
                unpacked, so it must hold exactly two entries.
            decimation_factor: Container indexed with ``Direction.VERTICAL``
                and ``Direction.HORIZONTAL``.
            no_dc_leakage: Flag stored on the instance.
            name: Layer name stored on the instance.
        """
        super(NsoltFinalRotation2dLayer, self).__init__()
        self.name = name
        self.number_of_channels = number_of_channels
        self.decimation_factor = decimation_factor
        # "%s" applies str() to each value, exactly like explicit str() calls.
        self.description = (
            "NSOLT final rotation (ps,pa) = (%s,%s), (mv,mh) = (%s,%s)" % (
                self.number_of_channels[0],
                self.number_of_channels[1],
                self.decimation_factor[Direction.VERTICAL],
                self.decimation_factor[Direction.HORIZONTAL],
            ))

        # Instantiation of orthonormal transforms, both with zeroed angles
        ps, pa = self.number_of_channels
        self.orthTransW0T = OrthonormalTransform(n=ps, mode='Synthesis')
        w0t = self.orthTransW0T
        w0t.angles = nn.init.zeros_(w0t.angles)
        self.orthTransU0T = OrthonormalTransform(n=pa, mode='Synthesis')
        u0t = self.orthTransU0T
        u0t.angles = nn.init.zeros_(u0t.angles)

        # No DC leakage
        self.no_dc_leakage = no_dc_leakage
Exemplo n.º 5
0
    def testBackward8x8RandAngPdAng4(self, mode, ncols):
        """Backward pass of an 8-point transform, checking angle index 4.

        The reference matrix R and a central-difference estimate of dR/dW for
        angle 4 are produced with ``OrthonormalMatrixGenerationSystem``; the
        autograd results for dL/dX and dL/dW[4] are compared against them.

        Args:
            mode: Mode string forwarded to ``OrthonormalTransform``.
            ncols: Number of columns of the random input.
        """
        datatype = torch.double
        rtol, atol = 1e-4, 1e-7

        # Configuration
        n_points = 8
        angle_index = 4
        step = 1e-4
        angs0 = 2 * math.pi * torch.rand(28, dtype=datatype)
        angs_lo = angs0.clone()
        angs_hi = angs0.clone()
        angs_lo[angle_index] = angs0[angle_index] - step / 2.
        angs_hi[angle_index] = angs0[angle_index] + step / 2.

        # Expected values
        X = torch.randn(n_points, ncols, dtype=datatype, requires_grad=True)
        dLdZ = torch.randn(n_points, ncols, dtype=datatype)
        generator = OrthonormalMatrixGenerationSystem(dtype=datatype,
                                                      partial_difference=False)
        R = generator(angles=angs0, mus=1)
        R_hi = generator(angles=angs_hi, mus=1)
        R_lo = generator(angles=angs_lo, mus=1)
        dRdW = (R_hi - R_lo) / step  # central difference
        if mode == 'Synthesis':
            expected_dLdX = R @ dLdZ
            expected_dLdW = torch.sum(dLdZ * (dRdW.T @ X))
        else:
            expected_dLdX = R.T @ dLdZ
            expected_dLdW = torch.sum(dLdZ * (dRdW @ X))

        # Instantiation of target class
        transform = OrthonormalTransform(n=n_points, dtype=datatype, mode=mode)
        transform.angles.data = angs0

        # Actual values (anomaly detection is switched on globally here)
        torch.autograd.set_detect_anomaly(True)
        Z = transform.forward(X)
        transform.zero_grad()
        Z.backward(dLdZ)
        actual_dLdX = X.grad
        actual_dLdW = transform.angles.grad[angle_index]

        # Evaluation
        for actual, expected in ((actual_dLdX, expected_dLdX),
                                 (actual_dLdW, expected_dLdW)):
            self.assertTrue(
                torch.allclose(actual, expected, rtol=rtol, atol=atol))
Exemplo n.º 6
0
    def __init__(self,
        number_of_channels=[],
        mode='Synthesis',
        name=''):
        """Set up the intermediate-rotation layer and its single transform.

        Args:
            number_of_channels: Pair ``(ps, pa)``; indexed with 0 and 1 and
                unpacked, so it must hold exactly two entries.
            mode: Mode string; prefixed to the description and forwarded to
                ``OrthonormalTransform``.
            name: Layer name stored on the instance.
        """
        super(NsoltIntermediateRotation2dLayer, self).__init__()
        self.name = name
        self.number_of_channels = number_of_channels
        # "%s" applies str() to each value, exactly like explicit str() calls.
        self.description = mode + (
            " NSOLT intermediate rotation (ps,pa) = (%s,%s)" % (
                self.number_of_channels[0],
                self.number_of_channels[1],
            ))

        # Instantiation of the orthonormal transform with zeroed angles;
        # its size is the second channel count (pa).
        ps, pa = self.number_of_channels
        self.orthTransUn = OrthonormalTransform(n=pa, mode=mode)
        un = self.orthTransUn
        un.angles = nn.init.zeros_(un.angles)
Exemplo n.º 7
0
class NsoltIntermediateRotation2dLayer(nn.Module):
    """
    NSOLTINTERMEDIATEROTATION2DLAYER

       Input per component (nComponents):
          nSamples x nRows x nCols x nChs

       Output per component (nComponents):
          nSamples x nRows x nCols x nChs

    The channels from index ``ps`` onwards are passed through an
    ``OrthonormalTransform`` of size ``pa``; the first ``ps`` channels are
    copied unchanged.

    Requirements: Python 3.7.x, PyTorch 1.7.x

    Copyright (c) 2020-2021, Shogo MURAMATSU

    All rights reserved.

    Contact address: Shogo MURAMATSU,
        Faculty of Engineering, Niigata University,
        8050 2-no-cho Ikarashi, Nishi-ku,
        Niigata, 950-2181, JAPAN

        http://msiplab.eng.niigata-u.ac.jp/
    """

    def __init__(self,
        number_of_channels=[],
        mode='Synthesis',
        name=''):
        """Store the configuration and create the zero-angle transform.

        Args:
            number_of_channels: Pair ``(ps, pa)`` with exactly two entries.
            mode: Mode string; prefixed to the description and forwarded to
                ``OrthonormalTransform``.
            name: Layer name stored on the instance.
        """
        super().__init__()
        self.name = name
        self.number_of_channels = number_of_channels
        # "%s" applies str() to each value, exactly like explicit str() calls.
        self.description = mode + (
            " NSOLT intermediate rotation (ps,pa) = (%s,%s)" % (
                self.number_of_channels[0],
                self.number_of_channels[1],
            ))

        # Instantiation of the orthonormal transform with zeroed angles
        ps, pa = self.number_of_channels
        self.orthTransUn = OrthonormalTransform(n=pa, mode=mode)
        un = self.orthTransUn
        un.angles = nn.init.zeros_(un.angles)

    def forward(self, X):
        """Transform channels ``ps:`` of ``X`` and return the result as a copy.

        Args:
            X: 4-D tensor indexed as samples x rows x cols x channels.

        Returns:
            A clone of ``X`` whose channels ``ps:`` are replaced by the
            transform output.
        """
        n_samples = X.size(0)
        n_rows = X.size(1)
        n_cols = X.size(2)
        ps, pa = self.number_of_channels

        # Flatten every pixel into a column: pa x (samples*rows*cols)
        lower_in = X[:, :, :, ps:].view(-1, pa).T
        out = X.clone()
        lower_out = self.orthTransUn.forward(lower_in)
        # Back to samples x rows x cols x pa and into the output copy
        out[:, :, :, ps:] = lower_out.T.view(n_samples, n_rows, n_cols, pa)
        return out

    @property
    def mode(self):
        """Mode reported by the wrapped transform."""
        return self.orthTransUn.mode
Exemplo n.º 8
0
    def test8x8(self, datatype, ncols, mode):
        """Check that an 8-point transform with random angles preserves the norm.

        A random 8 x ncols input is scaled to unit norm, transformed, and the
        norm of the output is compared against 1.

        Args:
            datatype: torch dtype of the angles, input and expected value.
            ncols: Number of columns of the random input.
            mode: Mode string forwarded to ``OrthonormalTransform``.
        """
        rtol, atol = 1e-5, 1e-8

        # Expected values
        expected_norm = torch.tensor(1., dtype=datatype)

        # Instantiation of target class
        transform = OrthonormalTransform(n=8, mode=mode)
        transform.angles.data = torch.randn(28, dtype=datatype)

        # Actual values
        unit_input = torch.randn(8, ncols, dtype=datatype)
        unit_input /= unit_input.norm()
        with torch.no_grad():
            output = transform.forward(unit_input)
            actual_norm = output.norm()

        # Evaluation
        message = f"actualNorm={actual_norm!s} differs from {expected_norm!s}"
        is_unit = torch.isclose(actual_norm, expected_norm,
                                rtol=rtol, atol=atol)
        self.assertTrue(is_unit, message)
Exemplo n.º 9
0
    def testConstructor(self, datatype, ncols):
        """Check the defaults of a transform constructed without arguments.

        Expected: it is an ``nn.Module``, leaves a random 2 x ncols input
        unchanged, its first parameter has length 1 and its mode is
        'Analysis'.

        Args:
            datatype: torch dtype of the random input.
            ncols: Number of columns of the random input.
        """
        rtol, atol = 1e-5, 1e-8

        # Expected values
        X = torch.randn(2, ncols, dtype=datatype)
        expected_output = X
        expected_n_params = 1
        expected_mode = 'Analysis'

        # Instantiation of target class
        transform = OrthonormalTransform()

        # Actual values
        with torch.no_grad():
            actual_output = transform.forward(X)
        first_parameter = next(transform.parameters())
        actual_n_params = len(first_parameter)
        actual_mode = transform.mode

        # Evaluation
        self.assertTrue(isinstance(transform, nn.Module))
        self.assertTrue(
            torch.allclose(actual_output, expected_output,
                           rtol=rtol, atol=atol))
        self.assertEqual(actual_n_params, expected_n_params)
        self.assertEqual(actual_mode, expected_mode)
Exemplo n.º 10
0
    def testBackwardAngsAndMus(self, datatype, mode, ncols):
        """Backward pass of a 2-point transform with a random angle and mus [1, -1].

        The reference matrix is the plane rotation by the drawn angle with its
        second row negated, and ``dRdW`` is the element-wise derivative of
        that matrix with respect to the angle.

        Args:
            datatype: torch dtype of inputs and of the transform.
            mode: Mode string forwarded to ``OrthonormalTransform``.
            ncols: Number of columns of the random input.
        """
        rtol, atol = 1e-4, 1e-7

        # Configuration
        n_points = 2
        mus = [1, -1]

        # Expected values
        X = torch.randn(n_points, ncols, dtype=datatype, requires_grad=True)
        dLdZ = torch.randn(n_points, ncols, dtype=datatype)
        angle = 2. * math.pi * gauss(mu=0., sigma=1.)
        cos_a = math.cos(angle)
        sin_a = math.sin(angle)
        R = torch.tensor([[cos_a, -sin_a],
                          [-sin_a, -cos_a]],
                         dtype=datatype)
        dRdW = torch.tensor([[-sin_a, -cos_a],
                             [-cos_a, sin_a]],
                            dtype=datatype)
        if mode == 'Synthesis':
            expected_dLdX = R @ dLdZ
            expected_dLdW = torch.sum(dLdZ * (dRdW.T @ X))
        else:
            expected_dLdX = R.T @ dLdZ
            expected_dLdW = torch.sum(dLdZ * (dRdW @ X))

        # Instantiation of target class
        transform = OrthonormalTransform(n=n_points, dtype=datatype, mode=mode)
        transform.angles = nn.init.constant_(transform.angles, val=angle)
        transform.mus = torch.tensor(mus, dtype=datatype)

        # Actual values (anomaly detection is switched on globally here)
        torch.autograd.set_detect_anomaly(True)
        Z = transform.forward(X)
        transform.zero_grad()
        Z.backward(dLdZ)
        actual_dLdX = X.grad
        actual_dLdW = transform.angles.grad

        # Evaluation
        for actual, expected in ((actual_dLdX, expected_dLdX),
                                 (actual_dLdW, expected_dLdW)):
            self.assertTrue(
                torch.allclose(actual, expected, rtol=rtol, atol=atol))
Exemplo n.º 11
0
    def testGradCheckNxNRandAngMus(self, mode, ncols, npoints):
        """Run ``torch.autograd.gradcheck`` on a transform with random angles and signs.

        Args:
            mode: Mode string forwarded to ``OrthonormalTransform``.
            ncols: Number of columns of the random input.
            npoints: Size ``n`` of the transform under test.
        """
        # Configuration
        datatype = torch.double
        n_points = npoints
        n_angles = int(n_points * (n_points - 1) / 2.)
        signs = (-1)**torch.randint(high=2, size=(n_points, ))  # entries +/-1
        angles = 2. * math.pi * torch.randn(n_angles, dtype=datatype)

        # Random input; the second draw is not used by the check and is kept
        # only so that the random number stream is consumed as before.
        X = torch.randn(n_points, ncols, dtype=datatype, requires_grad=True)
        dLdZ = torch.randn(n_points, ncols, dtype=datatype)

        # Instantiation of target class
        transform = OrthonormalTransform(n=n_points, dtype=datatype, mode=mode)
        transform.angles.data = angles
        transform.mus = signs
        torch.autograd.set_detect_anomaly(True)
        Z = transform.forward(X)
        transform.zero_grad()

        # Evaluation
        passed = torch.autograd.gradcheck(transform, (X, ))
        self.assertTrue(passed)
Exemplo n.º 12
0
    def testInstantiationWithInvalidMode(self):
        """Constructing the transform with mode 'Invalid' must raise ``InvalidMode``."""
        invalid_mode = 'Invalid'

        # Instantiation of target class is expected to fail
        with self.assertRaises(InvalidMode):
            OrthonormalTransform(mode=invalid_mode)
Exemplo n.º 13
0
    def testBackword8x8RandAngMusPdAng7(self, mode, ncols):
        """Backward pass of an 8-point transform with alternating mus, angle index 7.

        The expected dL/dW[7] is a central-difference estimate: two transforms
        whose angle 7 is shifted by -delta/2 and +delta/2 are applied to the
        same input, and the scaled output difference is contracted with dL/dZ.
        The actual value is the autograd gradient of a third transform holding
        the unshifted angles.

        Args:
            mode: Mode string forwarded to ``OrthonormalTransform``.
            ncols: Number of columns of the random input.
        """
        datatype = torch.double
        rtol, atol = 1e-4, 1e-7

        # Configuration
        nPoints = 8
        mus = [1, -1, 1, -1, 1, -1, 1, -1]
        angs0 = 2 * math.pi * torch.rand(28, dtype=datatype)
        angs1 = angs0.clone()
        angs2 = angs0.clone()
        pdAng = 7
        delta = 1e-4
        angs1[pdAng] = angs0[pdAng] - delta / 2.
        angs2[pdAng] = angs0[pdAng] + delta / 2.

        # Input and upstream gradient
        X = torch.randn(nPoints, ncols, dtype=datatype, requires_grad=False)
        dLdZ = torch.randn(nPoints, ncols, dtype=datatype)

        # Instantiation of target class: unshifted, lower and upper angles
        target0 = OrthonormalTransform(n=nPoints, dtype=datatype, mode=mode)
        target0.angles.data = angs0
        target0.mus = mus
        target1 = OrthonormalTransform(n=nPoints, dtype=datatype, mode=mode)
        target1.angles.data = angs1
        target1.mus = mus
        target2 = OrthonormalTransform(n=nPoints, dtype=datatype, mode=mode)
        target2.angles.data = angs2
        target2.mus = mus

        # Expected values
        torch.autograd.set_detect_anomaly(True)
        dZdW = (target2.forward(X) - target1.forward(X)) / delta  # ~ d(R*X)/dW
        expctddLdW = torch.sum(dLdZ * dZdW)  # ~ dLdW

        # Actual values: X is a leaf tensor, so the flag can be flipped directly
        X.requires_grad = True
        Z = target0.forward(X)
        target0.zero_grad()
        Z.backward(dLdZ)
        actualdLdW = target0.angles.grad[pdAng]

        # Evaluation
        self.assertTrue(
            torch.allclose(actualdLdW, expctddLdW, rtol=rtol, atol=atol))
Exemplo n.º 14
0
    def testBackward4x4RandAngPdAng1(self, mode, ncols):
        """Backward pass of a 4-point transform with all mus -1, angle index 1.

        The reference matrix ``R`` is written out as a product of six 4x4
        plane-rotation matrices (one per angle, the factor for ``angs[5]``
        leftmost and the factor for ``angs[0]`` rightmost) whose leading
        factor has its rows scaled by ``mus``.  ``dRdW`` is the same product
        with the factor for ``angs[1]`` replaced by a central difference of
        step ``delta``.  Autograd results for dL/dX and dL/dW[1] are compared
        against values derived from these matrices.

        Args:
            mode: Mode string forwarded to ``OrthonormalTransform``.
            ncols: Number of columns of the random input.
        """
        datatype = torch.double
        rtol, atol = 1e-2, 1e-5

        # Configuration
        nPoints = 4
        mus = [-1, -1, -1, -1]
        angs = 2. * math.pi * torch.randn(6, dtype=datatype)
        pdAng = 1
        delta = 1e-4

        # Expected values
        X = torch.randn(nPoints, ncols, dtype=datatype, requires_grad=True)
        dLdZ = torch.randn(nPoints, ncols, dtype=datatype)
        # `*` and `@` share precedence and associate left to right, so the
        # mus column scales the rows of the first factor before the products.
        # NOTE(review): the inner torch.tensor(...) literals carry no dtype
        # argument; only the final product is converted to `datatype`.  This
        # may be why the tolerances above are looser than elsewhere - confirm.
        R = torch.as_tensor(
            torch.tensor(mus).view(-1,1) * \
            torch.tensor(
                [ [1, 0, 0, 0. ],
                 [0, 1, 0, 0. ],
                 [0, 0, math.cos(angs[5]), -math.sin(angs[5]) ],
                 [0, 0, math.sin(angs[5]), math.cos(angs[5]) ] ]
            ) @ torch.tensor(
                [ [1, 0, 0, 0 ],
                 [0, math.cos(angs[4]), 0, -math.sin(angs[4]) ],
                 [0, 0, 1, 0 ],
                 [0, math.sin(angs[4]), 0, math.cos(angs[4]) ] ]
            ) @ torch.tensor(
                [ [1, 0, 0, 0 ],
                 [0, math.cos(angs[3]), -math.sin(angs[3]), 0 ],
                 [0, math.sin(angs[3]), math.cos(angs[3]), 0 ],
                 [0, 0, 0, 1 ] ]
            ) @ torch.tensor(
                [ [ math.cos(angs[2]), 0, 0, -math.sin(angs[2]) ],
                 [0, 1, 0, 0 ],
                 [0, 0, 1, 0 ],
                 [ math.sin(angs[2]), 0, 0, math.cos(angs[2]) ] ]
            ) @ torch.tensor(
               [ [math.cos(angs[1]), 0, -math.sin(angs[1]), 0 ],
                 [0, 1, 0, 0 ],
                 [math.sin(angs[1]), 0, math.cos(angs[1]), 0 ],
                 [0, 0, 0, 1 ] ]
            ) @ torch.tensor(
               [ [ math.cos(angs[0]), -math.sin(angs[0]), 0, 0 ],
                 [ math.sin(angs[0]), math.cos(angs[0]), 0, 0 ],
                 [ 0, 0, 1, 0 ],
                 [ 0, 0, 0, 1 ] ]
            ),dtype=datatype)
        # Same product as R, except that the angs[1] factor is replaced by
        # (factor(angs[1] + delta/2) - factor(angs[1] - delta/2)) and the whole
        # expression is scaled by 1/delta: a central-difference dR/d(angs[1]).
        dRdW = torch.as_tensor(
            (1./delta) * torch.tensor(mus).view(-1,1) * \
            torch.tensor(
                [ [1, 0, 0, 0. ],
                 [0, 1, 0., 0. ],
                 [0, 0, math.cos(angs[5]), -math.sin(angs[5]) ],
                 [0., 0, math.sin(angs[5]), math.cos(angs[5]) ] ]
            ) @ torch.tensor(
                [ [1, 0, 0, 0 ],
                 [0, math.cos(angs[4]), 0, -math.sin(angs[4]) ],
                 [0, 0, 1, 0 ],
                 [0, math.sin(angs[4]), 0, math.cos(angs[4]) ] ]
            ) @ torch.tensor(
                [ [1, 0, 0, 0 ],
                 [0, math.cos(angs[3]), -math.sin(angs[3]), 0 ],
                 [0, math.sin(angs[3]), math.cos(angs[3]), 0 ],
                 [0, 0, 0, 1 ] ]
            ) @ torch.tensor(
                [ [ math.cos(angs[2]), 0, 0, -math.sin(angs[2]) ],
                 [0, 1, 0, 0 ],
                 [0, 0, 1, 0 ],
                 [ math.sin(angs[2]), 0, 0, math.cos(angs[2]) ] ]
            ) @ (
                torch.tensor(
               [ [math.cos(angs[1]+delta/2.), 0, -math.sin(angs[1]+delta/2.), 0 ],
                 [0, 1, 0, 0 ],
                 [math.sin(angs[1]+delta/2.), 0, math.cos(angs[1]+delta/2.), 0 ],
                 [0, 0, 0, 1 ] ] ) - \
                torch.tensor(
               [ [math.cos(angs[1]-delta/2.), 0, -math.sin(angs[1]-delta/2.), 0 ],
                 [0, 1, 0, 0 ],
                 [math.sin(angs[1]-delta/2.), 0, math.cos(angs[1]-delta/2.), 0 ],
                 [0, 0, 0, 1 ] ] )
            ) @ torch.tensor(
               [ [ math.cos(angs[0]), -math.sin(angs[0]), 0, 0 ],
                 [ math.sin(angs[0]), math.cos(angs[0]), 0, 0 ],
                 [ 0, 0, 1, 0 ],
                 [ 0, 0, 0, 1 ] ]
            ),dtype=datatype)
        # For every mode other than 'Synthesis' the expected input gradient
        # uses R transposed; for 'Synthesis' the roles are swapped.
        if mode != 'Synthesis':
            expctddLdX = R.T @ dLdZ  # = dZdX @ dLdZ
            expctddLdW = torch.sum(dLdZ * (dRdW @ X))
        else:
            expctddLdX = R @ dLdZ  # = dZdX @ dLdZ
            expctddLdW = torch.sum(dLdZ * (dRdW.T @ X))

        # Instantiation of target class
        target = OrthonormalTransform(n=nPoints, dtype=datatype, mode=mode)
        target.angles.data = angs
        target.mus = mus

        # Actual values (anomaly detection is switched on globally here)
        torch.autograd.set_detect_anomaly(True)
        Z = target.forward(X)
        target.zero_grad()
        Z.backward(dLdZ)
        actualdLdX = X.grad
        # Only the gradient entry of the perturbed angle is compared.
        actualdLdW = target.angles.grad[pdAng]

        # Evaluation
        self.assertTrue(
            torch.allclose(actualdLdX, expctddLdX, rtol=rtol, atol=atol))
        self.assertTrue(
            torch.allclose(actualdLdW, expctddLdW, rtol=rtol, atol=atol))
Exemplo n.º 15
0
    def testForward4x4RandAngs(self, datatype, mode, ncols):
        """Forward pass of a 4-point transform with random angles and mus [-1, 1, -1, 1].

        The reference matrix ``R`` is written out as a product of six 4x4
        plane-rotation matrices (one per angle, the factor for ``angs[5]``
        leftmost and the factor for ``angs[0]`` rightmost) whose leading
        factor has its rows scaled by ``mus``.  The expected output is
        ``R @ X``, or ``R.T @ X`` when mode is 'Synthesis'.

        Args:
            datatype: torch dtype of the angles, input and transform.
            mode: Mode string forwarded to ``OrthonormalTransform``.
            ncols: Number of columns of the random input.
        """
        rtol, atol = 1e-4, 1e-7

        # Configuration
        nPoints = 4
        mus = [-1, 1, -1, 1]
        angs = 2. * math.pi * torch.randn(6, dtype=datatype)

        # Expected values
        X = torch.randn(nPoints, ncols, dtype=datatype)
        # `*` and `@` share precedence and associate left to right, so the
        # mus column scales the rows of the first factor before the products.
        # NOTE(review): the inner torch.tensor(...) literals carry no dtype
        # argument; only the final product is converted to `datatype`.
        R = torch.as_tensor(
            torch.tensor(mus).view(-1,1) * \
            torch.tensor(
                [ [1, 0, 0, 0. ],
                 [0, 1, 0, 0. ],
                 [0, 0, math.cos(angs[5]), -math.sin(angs[5]) ],
                 [0, 0, math.sin(angs[5]), math.cos(angs[5]) ] ]
            ) @ torch.tensor(
                [ [1, 0, 0, 0 ],
                 [0, math.cos(angs[4]), 0, -math.sin(angs[4]) ],
                 [0, 0, 1, 0 ],
                 [0, math.sin(angs[4]), 0, math.cos(angs[4]) ] ]
            ) @ torch.tensor(
                [ [1, 0, 0, 0 ],
                 [0, math.cos(angs[3]), -math.sin(angs[3]), 0 ],
                 [0, math.sin(angs[3]), math.cos(angs[3]), 0 ],
                 [0, 0, 0, 1 ] ]
            ) @ torch.tensor(
                [ [ math.cos(angs[2]), 0, 0, -math.sin(angs[2]) ],
                 [0, 1, 0, 0 ],
                 [0, 0, 1, 0 ],
                 [ math.sin(angs[2]), 0, 0, math.cos(angs[2]) ] ]
            ) @ torch.tensor(
               [ [math.cos(angs[1]), 0, -math.sin(angs[1]), 0 ],
                 [0, 1, 0, 0 ],
                 [math.sin(angs[1]), 0, math.cos(angs[1]), 0 ],
                 [0, 0, 0, 1 ] ]
            ) @ torch.tensor(
               [ [ math.cos(angs[0]), -math.sin(angs[0]), 0, 0 ],
                 [ math.sin(angs[0]), math.cos(angs[0]), 0, 0 ],
                 [ 0, 0, 1, 0 ],
                 [ 0, 0, 0, 1 ] ]
            ),dtype=datatype)
        # Every mode other than 'Synthesis' applies R; 'Synthesis' applies
        # its transpose.
        if mode != 'Synthesis':
            expctdZ = R @ X
        else:
            expctdZ = R.T @ X

        # Instantiation of target class
        target = OrthonormalTransform(n=nPoints, dtype=datatype, mode=mode)
        target.angles.data = angs
        target.mus = mus

        # Actual values (no gradient tracking needed for a forward check)
        with torch.no_grad():
            actualZ = target.forward(X)

        # Evaluation
        self.assertTrue(torch.allclose(actualZ, expctdZ, rtol=rtol, atol=atol))
Exemplo n.º 16
0
 def testSetInvalidMus(self):
     """Assigning the value 2 to ``mus`` must raise ``InvalidMus``.

     The transform is constructed outside the ``assertRaises`` block so that
     only the assignment can satisfy the expectation; an exception raised by
     the constructor surfaces as a test error instead of a false pass.
     """
     mus = 2
     target = OrthonormalTransform()
     with self.assertRaises(InvalidMus):
         target.mus = mus
Exemplo n.º 17
0
 def testInstantiationWithInvalidMus(self):
     """Constructing the transform with ``mus=2`` must raise ``InvalidMus``."""
     invalid_mus = 2
     with self.assertRaises(InvalidMus):
         OrthonormalTransform(mus=invalid_mus)
Exemplo n.º 18
0
 def testSetInvalidMode(self):
     """Assigning 'InvalidMode' to ``mode`` must raise ``InvalidMode``.

     The transform is constructed outside the ``assertRaises`` block so that
     only the assignment can satisfy the expectation; an exception raised by
     the constructor surfaces as a test error instead of a false pass.
     """
     mode = 'InvalidMode'
     target = OrthonormalTransform()
     with self.assertRaises(InvalidMode):
         target.mode = mode
Exemplo n.º 19
0
class NsoltFinalRotation2dLayer(nn.Module):
    """
    NSOLTFINALROTATION2DLAYER

       Input per component (nComponents):
          nSamples x nRows x nCols x nChs

       Output per component (nComponents):
          nSamples x nRows x nCols x nDecs

    The first ``ps`` channels and the remaining ``pa`` channels are each
    passed through their own synthesis-mode ``OrthonormalTransform``; the
    leading rows of both results are stacked to form ``nDecs`` outputs.

    Requirements: Python 3.7.x, PyTorch 1.7.x

    Copyright (c) 2020-2021, Shogo MURAMATSU

    All rights reserved.

    Contact address: Shogo MURAMATSU,
        Faculty of Engineering, Niigata University,
        8050 2-no-cho Ikarashi, Nishi-ku,
        Niigata, 950-2181, JAPAN

        http://msiplab.eng.niigata-u.ac.jp/
    """
    def __init__(self,
                 number_of_channels=[],
                 decimation_factor=[],
                 no_dc_leakage=False,
                 name=''):
        """Store the configuration and create the two zero-angle transforms.

        Args:
            number_of_channels: Pair ``(ps, pa)`` with exactly two entries.
            decimation_factor: Container indexed with ``Direction.VERTICAL``
                and ``Direction.HORIZONTAL``.
            no_dc_leakage: When true, ``forward`` resets the first sign and
                the first ``ps - 1`` angles of the ``ps``-sized transform.
            name: Layer name stored on the instance.
        """
        super().__init__()
        self.name = name
        self.number_of_channels = number_of_channels
        self.decimation_factor = decimation_factor
        # "%s" applies str() to each value, exactly like explicit str() calls.
        self.description = (
            "NSOLT final rotation (ps,pa) = (%s,%s), (mv,mh) = (%s,%s)" % (
                self.number_of_channels[0],
                self.number_of_channels[1],
                self.decimation_factor[Direction.VERTICAL],
                self.decimation_factor[Direction.HORIZONTAL],
            ))

        # Instantiation of orthonormal transforms, both with zeroed angles
        ps, pa = self.number_of_channels
        self.orthTransW0T = OrthonormalTransform(n=ps, mode='Synthesis')
        w0t = self.orthTransW0T
        w0t.angles = nn.init.zeros_(w0t.angles)
        self.orthTransU0T = OrthonormalTransform(n=pa, mode='Synthesis')
        u0t = self.orthTransU0T
        u0t.angles = nn.init.zeros_(u0t.angles)

        # No DC leakage
        self.no_dc_leakage = no_dc_leakage

    def forward(self, X):
        """Apply both transforms and assemble the ``nDecs`` output channels.

        Args:
            X: 4-D tensor indexed as samples x rows x cols x channels, with
                ``ps + pa`` channels.

        Returns:
            Tensor of shape samples x rows x cols x nDecs, where ``nDecs`` is
            the product of the first two decimation-factor entries.
        """
        n_samples = X.size(0)
        n_rows = X.size(1)
        n_cols = X.size(2)
        ps, pa = self.number_of_channels
        decimation = self.decimation_factor
        nDecs = decimation[0] * decimation[1]

        # No DC leakage: force the first sign to 1 and the first ps-1
        # angles of the ps-sized transform to zero before every pass.
        if self.no_dc_leakage:
            w0t = self.orthTransW0T
            w0t.mus[0] = 1
            w0t.angles.data[:ps - 1] = torch.zeros(
                ps - 1, dtype=w0t.angles.data.dtype)

        # Flatten every pixel into a column: channels x (samples*rows*cols)
        upper_in = X[:, :, :, :ps].view(-1, ps).T
        lower_in = X[:, :, :, ps:].view(-1, pa).T
        # Rows kept from each half: ceil(nDecs/2) and floor(nDecs/2)
        n_upper = int(math.ceil(nDecs / 2.))
        n_lower = int(math.floor(nDecs / 2.))
        upper_out = self.orthTransW0T.forward(upper_in)[:n_upper, :]
        lower_out = self.orthTransU0T.forward(lower_in)[:n_lower, :]
        stacked = torch.cat((upper_out, lower_out), dim=0)
        return stacked.T.view(n_samples, n_rows, n_cols, nDecs)