def test_forward_inverse_are_consistent(self):
    """Check that BatchNorm's inverse undoes its forward pass after ``eval()``.

    One BatchNorm is built for ``affine=True`` and one for ``affine=False``.
    Both are evaluated on the same random batch.
    """
    num_features, num_samples = 100, 50
    x = torch.randn(num_samples, num_features)

    candidates = []
    for use_affine in (True, False):
        candidates.append(norm.BatchNorm(features=num_features, affine=use_affine))

    # Tolerance for the consistency check; presumably read by the
    # assertion helpers of the enclosing test case (confirm in base class).
    self.eps = 1e-6

    for bn in candidates:
        with self.subTest(transform=bn):
            # The inverse is only exercised in eval mode.
            bn.eval()
            self.assert_forward_inverse_are_consistent(bn, x)
def test_forward(self):
    """Compare BatchNorm's forward pass against a hand-computed reference.

    Two cases are checked for each ``affine`` setting (``True`` and ``False``):

    * First, the transform is checked before ``eval()`` is called. The
      expected output uses the batch mean and (unbiased) variance of
      ``inputs``.
    * Second, the transform is checked after ``transform.eval()``. The
      expected output uses ``transform.running_mean`` and
      ``transform.running_var``.

    In both cases the log-determinant is the same scalar for every sample.
    It is compared as a length-``batch_size`` tensor.
    """
    features = 100
    batch_size = 50
    bn_eps = 1e-5
    # Tolerance for the tensor comparisons below; presumably read by the
    # overridden assertEqual of the enclosing test case (confirm in base class).
    self.eps = 1e-4

    def check_against_reference(transform, affine, inputs, outputs, logabsdet, mean, var):
        """Assert ``outputs``/``logabsdet`` match batch norm computed by hand.

        The reference normalizes ``inputs`` with the given ``mean``/``var``.
        When ``affine`` is true it also applies ``transform.weight`` and
        ``transform.bias``.
        """
        self.assert_tensor_is_good(outputs, [batch_size, features])
        self.assert_tensor_is_good(logabsdet, [batch_size])

        outputs_ref = (inputs - mean) / torch.sqrt(var + bn_eps)
        # The Jacobian is diagonal and does not depend on the sample, so the
        # log|det| is one scalar repeated across the batch.
        logabsdet_ref = torch.sum(torch.log(1.0 / torch.sqrt(var + bn_eps)))
        logabsdet_ref = torch.full([batch_size], logabsdet_ref.item())
        if affine:
            outputs_ref *= transform.weight
            outputs_ref += transform.bias
            logabsdet_ref += torch.sum(torch.log(transform.weight))

        self.assert_tensor_is_good(outputs_ref, [batch_size, features])
        self.assert_tensor_is_good(logabsdet_ref, [batch_size])
        self.assertEqual(outputs, outputs_ref)
        self.assertEqual(logabsdet, logabsdet_ref)

    # Previously iterated over [True, True], so affine=False was never tested.
    for affine in [True, False]:
        with self.subTest(affine=affine):
            inputs = torch.randn(batch_size, features)
            transform = norm.BatchNorm(features=features, affine=affine, eps=bn_eps)

            # Before eval(): expected to normalize with this batch's statistics.
            outputs, logabsdet = transform(inputs)
            mean, var = inputs.mean(0), inputs.var(0)
            check_against_reference(
                transform, affine, inputs, outputs, logabsdet, mean, var
            )

            # After eval(): expected to normalize with the running statistics.
            # They are read after the forward call, matching the original order.
            transform.eval()
            outputs, logabsdet = transform(inputs)
            check_against_reference(
                transform,
                affine,
                inputs,
                outputs,
                logabsdet,
                transform.running_mean,
                transform.running_var,
            )
def test_inverse(self):
    """Check that BatchNorm's inverse is unavailable before ``eval()`` but works after it.

    Both ``affine=True`` and ``affine=False`` are tested. After ``eval()``,
    only the output shapes and tensor sanity are checked.
    """
    num_features = 100
    num_samples = 50
    x = torch.randn(num_samples, num_features)

    for use_affine in (True, False):
        with self.subTest(affine=use_affine):
            bn = norm.BatchNorm(features=num_features, affine=use_affine)

            # Before eval() is called, inverse() must raise.
            with self.assertRaises(base.InverseNotAvailable):
                bn.inverse(x)

            bn.eval()
            y, log_det = bn.inverse(x)
            self.assert_tensor_is_good(y, [num_samples, num_features])
            self.assert_tensor_is_good(log_det, [num_samples])