Esempio n. 1
0
 def test_log_prob(self):
     """Check that ``Flow.log_prob`` yields one value per input row.

     The check runs twice: once with no context, and once with a random
     context whose leading dimension matches the input batch.
     """
     batch_size = 10
     input_shape = [2, 3, 4]
     context_shape = [5, 6]
     expected_shape = torch.Size([batch_size])

     flow = base.Flow(
         transform=transforms.AffineScalarTransform(scale=2.0),
         distribution=distributions.StandardNormal(input_shape),
     )
     inputs = torch.randn(batch_size, *input_shape)
     context_tensor = torch.randn(batch_size, *context_shape)

     for ctx in (None, context_tensor):
         with self.subTest(context=ctx):
             result = flow.log_prob(inputs, context=ctx)
             self.assertIsInstance(result, torch.Tensor)
             self.assertEqual(result.shape, expected_shape)
Esempio n. 2
0
 def test_transform_to_noise(self):
     """Check that ``Flow.transform_to_noise`` preserves the input shape.

     The check runs with no context and with a random context. The context
     has one row per input, so its leading dimension equals the batch size of
     ``inputs``.
     """
     batch_size = 10
     shape = [2, 3, 4]
     context_shape = [5, 6]
     flow = base.Flow(
         transform=transforms.AffineScalarTransform(scale=2.0),
         distribution=distributions.StandardNormal(shape),
     )
     inputs = torch.randn(batch_size, *shape)
     # Context is per-input: its batch dimension must match ``inputs``. A
     # mismatched size (previously 20 vs 10) presumably only passed because
     # this transform does not read its context.
     maybe_context = torch.randn(batch_size, *context_shape)
     for context in [None, maybe_context]:
         with self.subTest(context=context):
             noise = flow.transform_to_noise(inputs, context=context)
             self.assertIsInstance(noise, torch.Tensor)
             self.assertEqual(noise.shape, torch.Size([batch_size] + shape))
Esempio n. 3
0
 def test_sample_and_log_prob(self):
     """Check ``sample_and_log_prob`` against a separate ``log_prob`` call.

     With no context, the samples are expected to have shape
     ``[num_samples] + input_shape`` and each log-probability tensor shape
     ``[num_samples]``. The log-probabilities returned alongside the samples
     must equal those recomputed by ``log_prob`` on the same samples.
     """
     num_samples = 10
     input_shape = [2, 3, 4]
     flow = base.Flow(
         transform=transforms.AffineScalarTransform(scale=2.0),
         distribution=distributions.StandardNormal(input_shape),
     )

     samples, joint_log_prob = flow.sample_and_log_prob(num_samples)
     recomputed_log_prob = flow.log_prob(samples)

     for tensor in (samples, joint_log_prob, recomputed_log_prob):
         self.assertIsInstance(tensor, torch.Tensor)

     self.assertEqual(samples.shape, torch.Size([num_samples, *input_shape]))
     per_sample_shape = torch.Size([num_samples])
     self.assertEqual(joint_log_prob.shape, per_sample_shape)
     self.assertEqual(recomputed_log_prob.shape, per_sample_shape)
     # NOTE(review): this compares whole tensors with assertEqual; presumably
     # the enclosing TestCase overrides assertEqual for tensors — confirm.
     self.assertEqual(joint_log_prob, recomputed_log_prob)
Esempio n. 4
0
 def test_sample_and_log_prob_with_context(self):
     """Check the output shapes of ``sample_and_log_prob`` with a context.

     Given a context of ``context_size`` rows, the test expects
     ``num_samples`` samples per context row. The samples therefore have shape
     ``[context_size, num_samples] + input_shape``, and the log-probabilities
     have shape ``[context_size, num_samples]``.
     """
     num_samples = 10
     context_size = 20
     input_shape = [2, 3, 4]
     context_shape = [5, 6]
     flow = base.Flow(
         transform=transforms.AffineScalarTransform(scale=2.0),
         distribution=distributions.StandardNormal(input_shape),
     )
     context = torch.randn(context_size, *context_shape)

     samples, log_prob = flow.sample_and_log_prob(num_samples, context=context)

     leading_dims = [context_size, num_samples]
     self.assertIsInstance(samples, torch.Tensor)
     self.assertIsInstance(log_prob, torch.Tensor)
     self.assertEqual(samples.shape, torch.Size(leading_dims + input_shape))
     self.assertEqual(log_prob.shape, torch.Size(leading_dims))
Esempio n. 5
0
 def test_sample(self):
     """Check the shape of ``Flow.sample`` with and without a context.

     Without a context, the test expects ``[num_samples] + input_shape``.
     With a context of ``context_size`` rows, it expects an extra leading
     dimension: ``[context_size, num_samples] + input_shape``.
     """
     num_samples = 10
     context_size = 20
     input_shape = [2, 3, 4]
     context_shape = [5, 6]
     flow = base.Flow(
         transform=transforms.AffineScalarTransform(scale=2.0),
         distribution=distributions.StandardNormal(input_shape),
     )
     context_tensor = torch.randn(context_size, *context_shape)

     for ctx in (None, context_tensor):
         with self.subTest(context=ctx):
             samples = flow.sample(num_samples, context=ctx)
             self.assertIsInstance(samples, torch.Tensor)
             if ctx is None:
                 leading_dims = [num_samples]
             else:
                 leading_dims = [context_size, num_samples]
             self.assertEqual(samples.shape,
                              torch.Size(leading_dims + input_shape))
Esempio n. 6
0
def _build_squeeze_level(c, h, w, level_hidden_channels, steps_per_level):
    """Build one level: squeeze, ``steps_per_level`` steps, then a 1x1 conv.

    Args:
        c, h, w: Shape fed into the level's squeeze transform.
        level_hidden_channels: Passed to ``create_transform_step`` for every
            step in this level.
        steps_per_level: Number of ``create_transform_step`` blocks.

    Returns:
        ``(transform_level, (c, h, w))``, where the shape is the one reported
        by the squeeze transform. The steps and the final convolution are
        built for that shape.
    """
    squeeze_transform = transforms.SqueezeTransform()
    c, h, w = squeeze_transform.get_output_shape(c, h, w)

    transform_level = transforms.CompositeTransform(
        [squeeze_transform]
        + [
            create_transform_step(c, level_hidden_channels)
            for _ in range(steps_per_level)
        ]
        # End each level with a linear transformation.
        + [transforms.OneByOneConvolution(c)]
    )
    return transform_level, (c, h, w)


def create_transform(c, h, w, levels, hidden_channels, steps_per_level, alpha,
                     num_bits, preprocessing, multi_scale):
    """Build a preprocessing transform followed by a multi-level image flow.

    Each level is a squeeze transform, ``steps_per_level`` steps from
    ``create_transform_step``, and a ``OneByOneConvolution``.

    Args:
        c, h, w: Input shape passed to the first squeeze transform.
        levels: Number of levels to build.
        hidden_channels: Either a single value used for every level, or a
            list with at least ``levels`` entries (one per level; extra
            entries are ignored).
        steps_per_level: Number of ``create_transform_step`` blocks per level.
        alpha: Offset used by the ``'realnvp'`` / ``'realnvp_2alpha'``
            preprocessing.
        num_bits: Inputs are expected in ``[0, 2 ** num_bits]``.
        preprocessing: One of ``'glow'``, ``'realnvp'`` or
            ``'realnvp_2alpha'``.
        multi_scale: If true, levels are collected in a
            ``MultiscaleCompositeTransform``. Otherwise they are chained in a
            ``CompositeTransform`` that ends with a reshape to
            ``(c * h * w,)``.

    Returns:
        ``CompositeTransform([preprocess_transform, mct])``.

    Raises:
        ValueError: If ``hidden_channels`` is a list shorter than ``levels``.
        RuntimeError: If ``preprocessing`` is not a known type.
    """
    if not isinstance(hidden_channels, list):
        hidden_channels = [hidden_channels] * levels
    elif len(hidden_channels) < levels:
        # zip() would otherwise silently build fewer levels than requested.
        raise ValueError(
            'hidden_channels has {} entries but levels={}'.format(
                len(hidden_channels), levels))

    if multi_scale:
        mct = transforms.MultiscaleCompositeTransform(num_transforms=levels)
        for _, level_hidden_channels in zip(range(levels), hidden_channels):
            transform_level, (c, h, w) = _build_squeeze_level(
                c, h, w, level_hidden_channels, steps_per_level)
            new_shape = mct.add_transform(transform_level, (c, h, w))
            if new_shape:  # If not last layer
                c, h, w = new_shape
    else:
        all_transforms = []
        for _, level_hidden_channels in zip(range(levels), hidden_channels):
            transform_level, (c, h, w) = _build_squeeze_level(
                c, h, w, level_hidden_channels, steps_per_level)
            all_transforms.append(transform_level)

        all_transforms.append(
            transforms.ReshapeTransform(input_shape=(c, h, w),
                                        output_shape=(c * h * w, )))
        mct = transforms.CompositeTransform(all_transforms)

    # Inputs to the model in [0, 2 ** num_bits]
    # The range comments below assume AffineScalarTransform computes
    # x * scale + shift — confirm against its implementation.
    if preprocessing == 'glow':
        # Map to [-0.5,0.5]
        preprocess_transform = transforms.AffineScalarTransform(
            scale=(1. / 2**num_bits), shift=-0.5)
    elif preprocessing == 'realnvp':
        preprocess_transform = transforms.CompositeTransform([
            # Map to [0,1]
            transforms.AffineScalarTransform(scale=(1. / 2**num_bits)),
            # Map into unconstrained space as done in RealNVP; this maps
            # [0, 1] to [alpha, 1] (asymmetric, unlike 'realnvp_2alpha').
            transforms.AffineScalarTransform(shift=alpha, scale=(1 - alpha)),
            transforms.Logit()
        ])

    elif preprocessing == 'realnvp_2alpha':
        preprocess_transform = transforms.CompositeTransform([
            transforms.AffineScalarTransform(scale=(1. / 2**num_bits)),
            # Symmetric variant: maps [0, 1] to [alpha, 1 - alpha].
            transforms.AffineScalarTransform(shift=alpha,
                                             scale=(1 - 2. * alpha)),
            transforms.Logit()
        ])
    else:
        raise RuntimeError(
            'Unknown preprocessing type: {}'.format(preprocessing))

    return transforms.CompositeTransform([preprocess_transform, mct])