def test_permutation2d(self):
    """Check that both Permutation2d variants round-trip their input.

    Two layers are exercised: the default one (constructed without
    ``shuffle``) and one built with ``shuffle=True``. For each, the forward
    output is fed back with ``reverse=True``; the result must equal the input.
    """
    inputs = torch.Tensor(np.random.rand(2, 16, 4, 4))
    default_layer = Permutation2d(num_channels=16)
    shuffle_layer = Permutation2d(num_channels=16, shuffle=True)

    # Forward through each layer, then immediately undo it.
    restored_default = default_layer(default_layer(inputs), reverse=True)
    restored_shuffle = shuffle_layer(shuffle_layer(inputs), reverse=True)

    for restored in (restored_default, restored_shuffle):
        self.assertTrue(ops.tensor_equal(inputs, restored))
def test_split_channel(self):
    """Verify which input channels each split_channel mode routes to each half.

    * ``'simple'``: the first half of the channels goes to ``x1``, the second
      half to ``x2``.
    * ``'cross'``: even-indexed channels go to ``x1``, odd-indexed to ``x2``.
    """
    # Every element gets a distinct value. With an all-ones tensor all
    # channels are identical, so these assertions would pass even if the
    # wrong channels were selected. Integers up to 2047 are exact in float32.
    x = torch.arange(2 * 4 * 16 * 16, dtype=torch.float32).reshape(2, 4, 16, 16)
    nc = x.shape[1]

    # simple splitting: contiguous halves
    x1, x2 = ops.split_channel(x, 'simple')
    for c in range(nc // 2):
        self.assertTrue(ops.tensor_equal(x1[:, c, :, :], x[:, c, :, :]))
        self.assertTrue(
            ops.tensor_equal(x2[:, c, :, :], x[:, nc // 2 + c, :, :]))

    # cross splitting: interleaved even/odd channels
    x1, x2 = ops.split_channel(x, 'cross')
    for c in range(nc // 2):
        self.assertTrue(ops.tensor_equal(x1[:, c, :, :], x[:, 2 * c, :, :]))
        self.assertTrue(
            ops.tensor_equal(x2[:, c, :, :], x[:, 2 * c + 1, :, :]))
def test_squeeze2d(self):
    """Squeeze2d(factor=2) followed by its reverse must reproduce the input.

    Only the first element of the layer's returned pair is used.
    """
    original = torch.Tensor(np.random.rand(2, 16, 4, 4))
    layer = Squeeze2d(factor=2)

    squeezed, _aux = layer(original)
    recovered, _aux = layer(squeezed, reverse=True)

    self.assertTrue(ops.tensor_equal(original, recovered))
def test_actnorm(self):
    """ActNorm's reverse pass must undo its forward pass on a random batch.

    Only the first element of the layer's returned pair is used.
    """
    original = torch.Tensor(np.random.rand(2, 16, 4, 4))
    norm_layer = ActNorm(num_channels=16)

    normalized, _aux = norm_layer(original)
    recovered, _aux = norm_layer(normalized, reverse=True)

    self.assertTrue(ops.tensor_equal(original, recovered))
def test_invertible_1x1_conv(self):
    """Invertible1x1Conv must preserve shape and invert exactly.

    Checks that the forward output has the same shape as the input, and that
    the reverse pass recovers the input.
    """
    original = torch.Tensor(np.random.rand(2, 16, 4, 4))
    conv = Invertible1x1Conv(num_channels=16)

    mixed, _aux = conv(original)
    recovered, _aux = conv(mixed, reverse=True)

    self.assertEqual(original.shape, mixed.shape)
    self.assertTrue(ops.tensor_equal(original, recovered))
def test_split2d(self):
    """Split2d's forward/reverse round trip must preserve the first channel half.

    Only the leading half of the channels is compared.
    NOTE(review): presumably the other half is not reproduced exactly by the
    reverse pass; confirm against Split2d.
    """
    inputs = torch.Tensor(np.random.rand(2, 16, 4, 4))
    split = Split2d(num_channels=16)

    kept, _aux = split(inputs, 0, reverse=False)
    rebuilt, _aux = split(kept, 0, reverse=True)

    leading_in = inputs[:, :inputs.shape[1] // 2]
    leading_out = rebuilt[:, :rebuilt.shape[1] // 2]
    self.assertTrue(ops.tensor_equal(leading_in, leading_out))
def test_flow_step(self):
    """Every permutation/coupling combination of FlowStep must invert exactly.

    Each configuration runs inside ``self.subTest``. A failure therefore
    reports which (permutation, coupling) pair broke, and the remaining
    pairs still run instead of stopping at the first failure.
    """
    flow_permutation = ['invconv', 'reverse', 'shuffle']
    flow_coupling = ['additive', 'affine']
    for permutation in flow_permutation:
        for coupling in flow_coupling:
            with self.subTest(permutation=permutation, coupling=coupling):
                # fresh random input and freshly-constructed step per config
                x = torch.Tensor(np.random.rand(2, 16, 4, 4))
                flow_step = FlowStep(
                    in_channels=16,
                    hidden_channels=256,
                    permutation=permutation,
                    coupling=coupling,
                    actnorm_scale=1.,
                    lu_decomposition=False,
                )
                # The forward pass starts from 0 and returns `det`, which is
                # fed into the reverse pass. The reverse pass's second output
                # is not checked here.
                y, det = flow_step(x, 0, reverse=False)
                x_, _ = flow_step(y, det, reverse=True)
                self.assertTrue(ops.tensor_equal(x, x_))
def test_tensor_equal(self):
    """ops.tensor_equal must accept identical tensors and reject a 1e-4 shift."""
    base = torch.Tensor(np.random.rand(2, 16, 4, 4))
    shifted = base + 1e-4

    # A tensor compared with itself is equal.
    self.assertTrue(ops.tensor_equal(base, base))
    # A uniform offset of 1e-4 must be reported as unequal.
    self.assertFalse(ops.tensor_equal(base, shifted))
def test_cat_channel(self):
    """cat_channel must exactly undo split_channel(..., 'simple'), channel order included."""
    # Every element gets a distinct value. With an all-ones tensor,
    # concatenating the halves in the wrong order would still compare equal.
    # Integers up to 2047 are exact in float32.
    x = torch.arange(2 * 4 * 16 * 16, dtype=torch.float32).reshape(2, 4, 16, 16)
    x1, x2 = ops.split_channel(x, 'simple')
    self.assertTrue(ops.tensor_equal(ops.cat_channel(x1, x2), x))
def test_reduce_sum(self):
    """reduce_sum over dims [1, 2, 3] must give one sum per batch sample."""
    x = torch.ones(2, 3, 16, 16)
    # Scale the second sample so the two per-sample sums differ. This checks
    # that each output entry is the sum of its own sample.
    x[1] *= 2.
    per_sample_count = float(x.shape[1] * x.shape[2] * x.shape[3])

    # named `total` rather than `sum` to avoid shadowing the builtin
    total = ops.reduce_sum(x, dim=[1, 2, 3])
    expected = torch.Tensor([per_sample_count, 2. * per_sample_count])
    self.assertTrue(ops.tensor_equal(expected, total))
def test_reduce_mean(self):
    """reduce_mean over dims [1, 2, 3] must give one mean per batch sample."""
    x = torch.ones(2, 3, 16, 16)
    # Scale the second sample so the two per-sample means differ (1 vs 2).
    # This checks that each output entry is the mean of its own sample.
    x[1] *= 2.
    mean = ops.reduce_mean(x, dim=[1, 2, 3])
    self.assertTrue(ops.tensor_equal(torch.Tensor([1., 2.]), mean))