Example #1
0
 def test_permutation2d(self):
     """Check that both channel-permutation modes of Permutation2d invert exactly.

     One layer is built with the default settings and one with
     ``shuffle=True``. A random batch is pushed forward through each layer
     and then back with ``reverse=True``; the original batch must be
     recovered in both cases.
     """
     inputs = torch.Tensor(np.random.rand(2, 16, 4, 4))
     layers = (
         Permutation2d(num_channels=16),
         Permutation2d(num_channels=16, shuffle=True),
     )
     # The layer output is used directly (no tuple unpacking), then fed
     # straight back through the same layer in reverse mode.
     round_trips = []
     for layer in layers:
         permuted = layer(inputs)
         round_trips.append(layer(permuted, reverse=True))
     for restored in round_trips:
         self.assertTrue(ops.tensor_equal(inputs, restored))
Example #2
0
 def test_split_channel(self):
     """Check ``ops.split_channel`` in its 'simple' and 'cross' modes.

     Every element of the input has a distinct value. This matters: with a
     constant tensor every channel compares equal to every other channel,
     so a channel placed in the wrong half or position would go unnoticed.

     Expected layout, for ``nc`` input channels:
       * 'simple': ``x1`` holds channels ``[0, nc/2)`` and ``x2`` holds
         ``[nc/2, nc)``.
       * 'cross': ``x1`` holds the even channels and ``x2`` the odd ones.
     """
     x = torch.arange(2 * 4 * 16 * 16, dtype=torch.float32).view(2, 4, 16, 16)
     nc = x.shape[1]
     half = nc // 2
     half_shape = (x.shape[0], half, x.shape[2], x.shape[3])

     # simple splitting: first half -> x1, second half -> x2
     x1, x2 = ops.split_channel(x, 'simple')
     # The per-channel loop only inspects indices < half; pin the shapes so
     # extra channels in either output are also caught.
     self.assertEqual(tuple(x1.shape), half_shape)
     self.assertEqual(tuple(x2.shape), half_shape)
     for c in range(half):
         self.assertTrue(ops.tensor_equal(x1[:, c, :, :], x[:, c, :, :]))
         self.assertTrue(
             ops.tensor_equal(x2[:, c, :, :], x[:, half + c, :, :]))

     # cross splitting: even channels -> x1, odd channels -> x2
     x1, x2 = ops.split_channel(x, 'cross')
     self.assertEqual(tuple(x1.shape), half_shape)
     self.assertEqual(tuple(x2.shape), half_shape)
     for c in range(half):
         self.assertTrue(
             ops.tensor_equal(x1[:, c, :, :], x[:, 2 * c, :, :]))
         self.assertTrue(
             ops.tensor_equal(x2[:, c, :, :], x[:, 2 * c + 1, :, :]))
Example #3
0
 def test_squeeze2d(self):
     """Check that Squeeze2d(factor=2) followed by its reverse is the identity."""
     sample = torch.Tensor(np.random.rand(2, 16, 4, 4))
     layer = Squeeze2d(factor=2)
     # Each call yields a pair; only the first element (the tensor) is checked.
     (squeezed, _) = layer(sample)
     (restored, _) = layer(squeezed, reverse=True)
     self.assertTrue(ops.tensor_equal(sample, restored))
Example #4
0
 def test_actnorm(self):
     """Check that an ActNorm layer over 16 channels round-trips its input."""
     sample = torch.Tensor(np.random.rand(2, 16, 4, 4))
     layer = ActNorm(num_channels=16)
     # Forward, then reverse on the forward output; the second element of
     # each returned pair is not used by this test.
     (normed, _) = layer(sample)
     (restored, _) = layer(normed, reverse=True)
     self.assertTrue(ops.tensor_equal(sample, restored))
Example #5
0
 def test_invertible_1x1_conv(self):
     """Check that Invertible1x1Conv keeps the input shape and inverts exactly."""
     sample = torch.Tensor(np.random.rand(2, 16, 4, 4))
     conv = Invertible1x1Conv(num_channels=16)
     (mixed, _) = conv(sample)
     (restored, _) = conv(mixed, reverse=True)
     # The forward output must have the same shape as the input...
     self.assertEqual(sample.shape, mixed.shape)
     # ...and reversing it must give back the original values.
     self.assertTrue(ops.tensor_equal(sample, restored))
Example #6
0
 def test_split2d(self):
     """Check that Split2d's reverse recovers the leading half of the channels.

     Only the first half of the channels is compared after the round trip.
     NOTE(review): presumably the other half is not reconstructed exactly
     on reverse — confirm against the Split2d implementation.
     """
     sample = torch.Tensor(np.random.rand(2, 16, 4, 4))
     layer = Split2d(num_channels=16)
     (kept, _) = layer(sample, 0, reverse=False)
     (rebuilt, _) = layer(kept, 0, reverse=True)
     head_in = sample[:, :sample.shape[1] // 2, :, :]
     head_out = rebuilt[:, :rebuilt.shape[1] // 2, :, :]
     self.assertTrue(ops.tensor_equal(head_in, head_out))
Example #7
0
    def test_flow_step(self):
        """Round-trip a FlowStep for every permutation/coupling combination.

        For each of the 3 x 2 configurations a fresh random batch and a fresh
        FlowStep are created. The step is run forward from an initial
        accumulator of 0, then in reverse starting from the accumulator the
        forward pass returned. The input must be recovered exactly.
        """
        permutations = ('invconv', 'reverse', 'shuffle')
        couplings = ('additive', 'affine')
        configs = [(perm, coup) for perm in permutations for coup in couplings]

        for permutation, coupling in configs:
            sample = torch.Tensor(np.random.rand(2, 16, 4, 4))
            step = FlowStep(
                in_channels=16,
                hidden_channels=256,
                permutation=permutation,
                coupling=coupling,
                actnorm_scale=1.,
                lu_decomposition=False
            )
            # Second positional argument / second return value is an
            # accumulated determinant term (presumably log-det — confirm).
            out, det_acc = step(sample, 0, reverse=False)
            restored, _ = step(out, det_acc, reverse=True)
            self.assertTrue(ops.tensor_equal(sample, restored))
Example #8
0
 def test_tensor_equal(self):
     """Check that ops.tensor_equal accepts identity and rejects a 1e-4 shift."""
     base = torch.Tensor(np.random.rand(2, 16, 4, 4))
     shifted = base + 1e-4
     # A tensor must compare equal to itself...
     self.assertTrue(ops.tensor_equal(base, base))
     # ...but not to a copy offset by 1e-4 everywhere.
     self.assertFalse(ops.tensor_equal(base, shifted))
Example #9
0
 def test_cat_channel(self):
     """Check that ``ops.cat_channel`` undoes a 'simple' ``ops.split_channel``.

     Every element has a distinct value. With a constant input, any
     same-shaped concatenation would compare equal, so halves joined in the
     wrong order (or otherwise mangled) would not be detected.
     """
     x = torch.arange(2 * 4 * 16 * 16, dtype=torch.float32).view(2, 4, 16, 16)
     x1, x2 = ops.split_channel(x, 'simple')
     self.assertTrue(ops.tensor_equal(ops.cat_channel(x1, x2), x))
Example #10
0
 def test_reduce_sum(self):
     """Check ``ops.reduce_sum`` over dims 1, 2 and 3 of an all-ones tensor.

     For each of the 2 batch entries, the sum must equal the number of
     elements in those dims (3 * 16 * 16).
     """
     x = torch.ones(2, 3, 16, 16)
     # Named ``total`` rather than ``sum`` so the builtin is not shadowed.
     total = ops.reduce_sum(x, dim=[1, 2, 3])
     expected = float(x.shape[1] * x.shape[2] * x.shape[3])
     self.assertTrue(
         ops.tensor_equal(torch.Tensor([expected, expected]), total))
Example #11
0
 def test_reduce_mean(self):
     """Check that ops.reduce_mean over dims 1-3 of an all-ones tensor yields ones."""
     ones = torch.ones(2, 3, 16, 16)
     averaged = ops.reduce_mean(ones, dim=[1, 2, 3])
     # One mean per batch entry, each exactly 1.
     expected = torch.ones(2)
     self.assertTrue(ops.tensor_equal(expected, averaged))