def test_log_prob(self):
    """log_prob yields one value per input row, both unconditioned and with a context."""
    batch_size = 10
    input_shape = [2, 3, 4]
    context_shape = [5, 6]
    flow = base.Flow(
        transform=transforms.AffineScalarTransform(scale=2.0),
        distribution=distributions.StandardNormal(input_shape),
    )
    x = torch.randn(batch_size, *input_shape)
    # One context row per input row (shared leading dimension).
    ctx = torch.randn(batch_size, *context_shape)
    expected_shape = torch.Size([batch_size])
    for context in (None, ctx):
        with self.subTest(context=context):
            log_density = flow.log_prob(x, context=context)
            self.assertIsInstance(log_density, torch.Tensor)
            self.assertEqual(log_density.shape, expected_shape)
def test_transform_to_noise(self):
    """transform_to_noise preserves the input shape, with and without a context.

    The context tensor shares its leading (batch) dimension with ``inputs`` so
    that every input row is paired with exactly one context row, matching the
    setup used by ``test_log_prob``.
    """
    batch_size = 10
    shape = [2, 3, 4]
    context_shape = [5, 6]
    flow = base.Flow(
        transform=transforms.AffineScalarTransform(scale=2.0),
        distribution=distributions.StandardNormal(shape),
    )
    inputs = torch.randn(batch_size, *shape)
    # Same leading size as `inputs`: a context batch of a different size would
    # be invalid for any transform that actually consumes the context.
    maybe_context = torch.randn(batch_size, *context_shape)
    for context in [None, maybe_context]:
        with self.subTest(context=context):
            noise = flow.transform_to_noise(inputs, context=context)
            self.assertIsInstance(noise, torch.Tensor)
            self.assertEqual(noise.shape, torch.Size([batch_size] + shape))
def test_sample_and_log_prob(self):
    """sample_and_log_prob agrees with calling log_prob on the returned samples."""
    num_samples = 10
    input_shape = [2, 3, 4]
    flow = base.Flow(
        transform=transforms.AffineScalarTransform(scale=2.0),
        distribution=distributions.StandardNormal(input_shape),
    )
    samples, log_prob_1 = flow.sample_and_log_prob(num_samples)
    log_prob_2 = flow.log_prob(samples)
    self.assertIsInstance(samples, torch.Tensor)
    self.assertIsInstance(log_prob_1, torch.Tensor)
    self.assertIsInstance(log_prob_2, torch.Tensor)
    self.assertEqual(samples.shape, torch.Size([num_samples] + input_shape))
    self.assertEqual(log_prob_1.shape, torch.Size([num_samples]))
    self.assertEqual(log_prob_2.shape, torch.Size([num_samples]))
    # The two log-densities come from separate computations, so compare them
    # with a floating-point tolerance. Exact element-wise equality is fragile,
    # and assertEqual on multi-element tensors only works with a tensor-aware
    # TestCase base class.
    self.assertTrue(torch.allclose(log_prob_1, log_prob_2))
def test_sample_and_log_prob_with_context(self):
    """With a batch of contexts, outputs are indexed by [context, sample, ...]."""
    num_samples = 10
    context_size = 20
    input_shape = [2, 3, 4]
    context_shape = [5, 6]
    flow = base.Flow(
        transform=transforms.AffineScalarTransform(scale=2.0),
        distribution=distributions.StandardNormal(input_shape),
    )
    context = torch.randn(context_size, *context_shape)
    samples, log_prob = flow.sample_and_log_prob(num_samples, context=context)

    for produced in (samples, log_prob):
        self.assertIsInstance(produced, torch.Tensor)

    leading_dims = [context_size, num_samples]
    self.assertEqual(samples.shape, torch.Size(leading_dims + input_shape))
    self.assertEqual(log_prob.shape, torch.Size(leading_dims))
def test_sample(self):
    """sample returns [num_samples, *input_shape], with a leading context axis when conditioned."""
    num_samples = 10
    context_size = 20
    input_shape = [2, 3, 4]
    context_shape = [5, 6]
    flow = base.Flow(
        transform=transforms.AffineScalarTransform(scale=2.0),
        distribution=distributions.StandardNormal(input_shape),
    )
    ctx = torch.randn(context_size, *context_shape)
    for context in (None, ctx):
        with self.subTest(context=context):
            drawn = flow.sample(num_samples, context=context)
            self.assertIsInstance(drawn, torch.Tensor)
            leading_dims = (
                [num_samples] if context is None else [context_size, num_samples]
            )
            self.assertEqual(drawn.shape, torch.Size(leading_dims + input_shape))
def _create_transform_level(c, h, w, level_hidden_channels, steps_per_level):
    """Build one level: a squeeze, ``steps_per_level`` flow steps, then a 1x1 conv.

    Args:
        c, h, w: Shape of the level's input (channels, height, width).
        level_hidden_channels: Value forwarded to ``create_transform_step``.
        steps_per_level: Number of ``create_transform_step`` steps in the level.

    Returns:
        A pair ``(transform_level, (c, h, w))`` where ``(c, h, w)`` is the
        shape reported by the squeeze, i.e. the shape the level's steps and
        1x1 convolution are built for.
    """
    squeeze_transform = transforms.SqueezeTransform()
    c, h, w = squeeze_transform.get_output_shape(c, h, w)
    transform_level = transforms.CompositeTransform(
        [squeeze_transform]
        + [
            create_transform_step(c, level_hidden_channels)
            for _ in range(steps_per_level)
        ]
        # End each level with a linear transformation.
        + [transforms.OneByOneConvolution(c)]
    )
    return transform_level, (c, h, w)


def create_transform(c, h, w, levels, hidden_channels, steps_per_level, alpha,
                     num_bits, preprocessing, multi_scale):
    """Build the full image transform: preprocessing followed by ``levels`` levels.

    Args:
        c, h, w: Input shape (channels, height, width).
        levels: Number of levels; each level starts with a squeeze.
        hidden_channels: Either a single value used for every level, or a list
            with (at least) one entry per level. Each entry is forwarded to
            ``create_transform_step``.
        steps_per_level: Number of ``create_transform_step`` steps per level.
        alpha: Parameter of the 'realnvp' / 'realnvp_2alpha' preprocessing.
        num_bits: Inputs are expected to lie in [0, 2 ** num_bits].
        preprocessing: One of 'glow', 'realnvp', 'realnvp_2alpha'.
        multi_scale: If True, levels are combined with a
            ``MultiscaleCompositeTransform``. Otherwise they are chained and
            the result is flattened to ``(c * h * w,)``.

    Returns:
        ``transforms.CompositeTransform([preprocess_transform, levels_transform])``.

    Raises:
        RuntimeError: If ``preprocessing`` is not a known type.
        ValueError: If ``hidden_channels`` is a list with fewer than ``levels``
            entries.
    """
    # Validate cheap arguments before constructing the (parameterized) levels.
    if preprocessing not in ('glow', 'realnvp', 'realnvp_2alpha'):
        raise RuntimeError(
            'Unknown preprocessing type: {}'.format(preprocessing))
    if not isinstance(hidden_channels, list):
        hidden_channels = [hidden_channels] * levels
    elif len(hidden_channels) < levels:
        # zip() below would otherwise silently build fewer levels than requested.
        raise ValueError(
            'hidden_channels has {} entries but levels={}'.format(
                len(hidden_channels), levels))

    if multi_scale:
        mct = transforms.MultiscaleCompositeTransform(num_transforms=levels)
        for _, level_hidden_channels in zip(range(levels), hidden_channels):
            transform_level, (c, h, w) = _create_transform_level(
                c, h, w, level_hidden_channels, steps_per_level)
            new_shape = mct.add_transform(transform_level, (c, h, w))
            if new_shape:  # If not last layer
                c, h, w = new_shape
    else:
        all_transforms = []
        for _, level_hidden_channels in zip(range(levels), hidden_channels):
            transform_level, (c, h, w) = _create_transform_level(
                c, h, w, level_hidden_channels, steps_per_level)
            all_transforms.append(transform_level)
        all_transforms.append(
            transforms.ReshapeTransform(input_shape=(c, h, w),
                                        output_shape=(c * h * w, )))
        mct = transforms.CompositeTransform(all_transforms)

    # Inputs to the model in [0, 2 ** num_bits]
    if preprocessing == 'glow':
        # Map to [-0.5,0.5]
        preprocess_transform = transforms.AffineScalarTransform(
            scale=(1. / 2**num_bits), shift=-0.5)
    elif preprocessing == 'realnvp':
        preprocess_transform = transforms.CompositeTransform([
            # Map to [0,1]
            transforms.AffineScalarTransform(scale=(1. / 2**num_bits)),
            # Map into unconstrained space as done in RealNVP.
            # NOTE(review): assuming AffineScalarTransform computes
            # x * scale + shift, this maps [0, 1] to [alpha, 1], not the
            # symmetric [alpha, 1 - alpha] of 'realnvp_2alpha'. Kept as-is
            # for compatibility with existing configurations.
            transforms.AffineScalarTransform(shift=alpha, scale=(1 - alpha)),
            transforms.Logit()
        ])
    else:  # 'realnvp_2alpha' (membership validated at the top)
        preprocess_transform = transforms.CompositeTransform([
            transforms.AffineScalarTransform(scale=(1. / 2**num_bits)),
            transforms.AffineScalarTransform(shift=alpha,
                                             scale=(1 - 2. * alpha)),
            transforms.Logit()
        ])

    return transforms.CompositeTransform([preprocess_transform, mct])