def create_transform(self, shape, split_dim=1):
    """Return a MultiscaleCompositeTransform populated with four affine scalar stages.

    Args:
        shape: shape handed to the first ``add_transform`` call; the value each
            call returns is passed as the shape for the following call.
        split_dim: forwarded unchanged to ``MultiscaleCompositeTransform``.

    Returns:
        The populated multiscale composite transform.
    """
    composite = base.MultiscaleCompositeTransform(num_transforms=4, split_dim=split_dim)
    # Stages are all constructed up front, then registered in order.
    stages = [standard.AffineScalarTransform(scale=factor) for factor in (2.0, 4.0, 0.5, 0.25)]
    current_shape = shape
    for stage in stages:
        # Feed the shape returned by add_transform into the next registration.
        current_shape = composite.add_transform(stage, current_shape)
    return composite
def test_inverse(self):
    """Inverting InverseTransform(scale=-2.0) must match inverting AffineScalarTransform(scale=-0.5)."""
    batch_size, shape = 10, [2, 3, 4]
    inputs = torch.randn(batch_size, *shape)

    wrapped = base.InverseTransform(standard.AffineScalarTransform(scale=-2.0))
    expected = standard.AffineScalarTransform(scale=-0.5)

    actual_outputs, actual_logabsdet = wrapped.inverse(inputs)
    expected_outputs, expected_logabsdet = expected.inverse(inputs)

    # Shape/sanity checks first, then exact agreement with the reference.
    self.assert_tensor_is_good(actual_outputs, [batch_size] + shape)
    self.assert_tensor_is_good(actual_logabsdet, [batch_size])
    self.assertEqual(actual_outputs, expected_outputs)
    self.assertEqual(actual_logabsdet, expected_logabsdet)
def test_inverse(self):
    """Inverting the chain (scale 2.0, identity, scale -0.25) must match inverting scale -0.5."""
    batch_size, shape = 10, [2, 3, 4]
    inputs = torch.randn(batch_size, *shape)

    # 2.0 * 1 * -0.25 == -0.5, so the chain is expected to collapse to one affine scale.
    chain = base.CompositeTransform([
        standard.AffineScalarTransform(scale=2.0),
        standard.IdentityTransform(),
        standard.AffineScalarTransform(scale=-0.25),
    ])
    single = standard.AffineScalarTransform(scale=-0.5)

    chain_outputs, chain_logabsdet = chain.inverse(inputs)
    single_outputs, single_logabsdet = single.inverse(inputs)

    self.assert_tensor_is_good(chain_outputs, [batch_size] + shape)
    self.assert_tensor_is_good(chain_logabsdet, [batch_size])
    self.assertEqual(chain_outputs, single_outputs)
    self.assertEqual(chain_logabsdet, single_logabsdet)
def test_case(scale, shift, true_outputs, true_logabsdet):
    """Check AffineScalarTransform(scale, shift).inverse against known results.

    Reads ``self``, ``inputs``, ``batch_size`` and ``shape`` from the enclosing
    scope. ``true_logabsdet`` is multiplied by ``np.prod(shape)`` (the number of
    elements per sample) and broadcast over the batch before comparison.
    """
    with self.subTest(scale=scale, shift=shift):
        affine = standard.AffineScalarTransform(scale=scale, shift=shift)
        result, log_det = affine.inverse(inputs)

        self.assert_tensor_is_good(result, [batch_size] + shape)
        self.assert_tensor_is_good(log_det, [batch_size])
        self.assertEqual(result, true_outputs)

        # Built only after the earlier checks pass, as in the original ordering.
        expected_log_det = torch.full([batch_size], true_logabsdet * np.prod(shape))
        self.assertEqual(log_det, expected_log_det)
def test_case(scale, shift):
    """Assert forward/inverse consistency of AffineScalarTransform(scale, shift).

    Uses ``self`` and ``inputs`` from the enclosing scope.
    """
    self.assert_forward_inverse_are_consistent(
        standard.AffineScalarTransform(scale=scale, shift=shift),
        inputs,
    )