Example #1
    def test_layer_is_well_behaved(self):
        batch_size = 10
        shape = (6, 8, 8)
        x = torch.randn(batch_size, *shape)

        module = MaskedConv2d(6, 6, kernel_size=3, padding=1, mask_type='A')
        self.assert_layer_is_well_behaved(module, x)

        module = MaskedConv2d(6, 6, kernel_size=3, padding=1, mask_type='B')
        self.assert_layer_is_well_behaved(module, x)
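The two mask types are meant to be stacked: a type-'A' layer at the input keeps every output from seeing its own input element, and type-'B' layers after it may also look at their own position in the intermediate feature maps (Example #4 below follows exactly this pattern). A minimal sketch, assuming MaskedConv2d is imported as in the tests and nn is torch.nn:

import torch
import torch.nn as nn

# One type-'A' layer followed by a type-'B' layer, the usual PixelCNN stacking.
stack = nn.Sequential(
    MaskedConv2d(6, 6, kernel_size=3, padding=1, mask_type='A'),
    nn.ReLU(True),
    MaskedConv2d(6, 6, kernel_size=3, padding=1, mask_type='B'),
)
y = stack(torch.randn(10, 6, 8, 8))
assert y.shape == (10, 6, 8, 8)  # kernel_size=3 with padding=1 preserves H and W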
Example #2
    def test_autoregressive_type_A(self):
        batch_size = 10
        shape = (6, 8, 8)
        x = torch.randn(batch_size, *shape)
        x_altered = copy.deepcopy(x)
        # Alter channel G of the second feature (of 2) at spatial position (4, 2)
        x_altered[:, 4, 4, 2] += 100.0

        module = MaskedConv2d(6, 6, kernel_size=3, padding=1, mask_type='A')
        y = module(x)
        y_altered = module(x_altered)

        # Assert all pixels up to (4,2) are unaltered
        self.assertEqual(y[:, :, :4], y_altered[:, :, :4])
        self.assertEqual(y[:, :, 4, :2], y_altered[:, :, 4, :2])

        # Assert channel R is unaltered
        self.assertEqual(y[:, 0, 4, 2], y_altered[:, 0, 4, 2])
        self.assertEqual(y[:, 3, 4, 2], y_altered[:, 3, 4, 2])

        # Assert channel G is unaltered
        self.assertEqual(y[:, 1, 4, 2], y_altered[:, 1, 4, 2])
        self.assertEqual(y[:, 4, 4, 2], y_altered[:, 4, 4, 2])

        # Assert channel B is altered
        self.assertFalse((y[:, 2, 4, 2] == y_altered[:, 2, 4, 2]).view(-1).any())
        self.assertFalse((y[:, 5, 4, 2] == y_altered[:, 5, 4, 2]).view(-1).any())

        # Assert all elements in next pixel are altered
        self.assertFalse((y[:, :, 4, 3] == y_altered[:, :, 4, 3]).view(-1).any())
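The assertions follow from PixelCNN-style mask construction. The sketch below is an assumption inferred from the test itself (in particular the interleaved channel grouping: feature maps {0, 3} act as R, {1, 4} as G, {2, 5} as B), not the library's actual implementation. A type-'A' mask blanks the centre tap for the current and later data channels; type 'B' keeps the current one.

import torch

def build_mask(out_channels, in_channels, kernel_size, mask_type='A', data_channels=3):
    """Hypothetical PixelCNN-style mask, for illustration only."""
    k = kernel_size
    mask = torch.ones(out_channels, in_channels, k, k)
    yc, xc = k // 2, k // 2
    mask[:, :, yc, xc + 1:] = 0   # pixels to the right of the centre
    mask[:, :, yc + 1:, :] = 0    # rows below the centre
    for o in range(out_channels):
        for i in range(in_channels):
            o_ch, i_ch = o % data_channels, i % data_channels
            # Type 'A': the centre tap sees strictly earlier data channels only.
            # Type 'B': it may also see its own data channel.
            visible = i_ch < o_ch if mask_type == 'A' else i_ch <= o_ch
            if not visible:
                mask[o, i, yc, xc] = 0
    return mask

Under such a mask, bumping channel G at (4, 2) cannot reach the R or G outputs at that pixel, but it does reach B there and every channel at later raster positions, which is exactly what the test checks.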
Example #3
    def test_bijection_is_well_behaved(self):
        batch_size = 10
        shape = (3, 8, 8)
        x = torch.randn(batch_size, *shape)
        net_spatial = nn.Sequential(
            SpatialMaskedConv2d(3,
                                3 * 2,
                                kernel_size=3,
                                padding=1,
                                mask_type='A'), ElementwiseParams2d(2))
        net = nn.Sequential(
            MaskedConv2d(3, 3 * 2, kernel_size=3, padding=1, mask_type='A'),
            ElementwiseParams2d(2))

        self.eps = 1e-6
        for autoregressive_order in ['raster_cwh', 'raster_wh']:
            with self.subTest(autoregressive_order=autoregressive_order):
                if autoregressive_order == 'raster_cwh':
                    autoreg_net = net
                elif autoregressive_order == 'raster_wh':
                    autoreg_net = net_spatial
                bijection = AffineAutoregressiveBijection2d(
                    autoreg_net, autoregressive_order=autoregressive_order)
                self.assert_bijection_is_well_behaved(bijection,
                                                      x,
                                                      z_shape=(batch_size,
                                                               *shape))
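Both orders wrap the same elementwise affine map; the autoregressive order only controls which elements the shift and log-scale may depend on ('raster_cwh' with the channel-masked net, 'raster_wh' with the purely spatial one). A conceptual sketch with hypothetical helper names; the exact parameterization inside AffineAutoregressiveBijection2d may differ:

import torch

def affine_forward(x, shift, log_scale):
    z = x * torch.exp(log_scale) + shift   # elementwise affine transform
    ldj = log_scale.flatten(1).sum(-1)     # log |det J|, one value per sample
    return z, ldj

def affine_inverse(z, shift, log_scale):
    return (z - shift) * torch.exp(-log_scale)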
Example #4
    def __init__(self,
                 in_channels,
                 num_params,
                 filters=128,
                 num_blocks=15,
                 output_filters=1024,
                 kernel_size=3,
                 kernel_size_in=7,
                 init_transforms=lambda x: 2 * x - 1):

        layers = [LambdaLayer(init_transforms),
                  MaskedConv2d(in_channels, 2 * filters, kernel_size=kernel_size_in,
                               padding=kernel_size_in // 2, mask_type='A',
                               data_channels=in_channels)]
        layers += [MaskedResidualBlock2d(filters, data_channels=in_channels,
                                         kernel_size=kernel_size) for _ in range(num_blocks)]
        layers += [nn.ReLU(True),
                   MaskedConv2d(2 * filters, output_filters, kernel_size=1,
                                mask_type='B', data_channels=in_channels),
                   nn.ReLU(True),
                   MaskedConv2d(output_filters, num_params * in_channels, kernel_size=1,
                                mask_type='B', data_channels=in_channels),
                   ElementwiseParams2d(num_params)]

        super(PixelCNN, self).__init__(*layers)
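PixelCNN is built on nn.Sequential (the super().__init__(*layers) call above), so inference is a single forward call. A minimal usage sketch with illustrative values: num_params=256 is just an assumption (e.g. one logit per 8-bit intensity), and inputs are taken to lie in [0, 1] so the default init_transforms maps them to [-1, 1].

import torch

model = PixelCNN(in_channels=3, num_params=256)
x = torch.rand(4, 3, 32, 32)   # batch of 3-channel 32x32 images in [0, 1]
params = model(x)              # per-pixel, per-channel parameters arranged by ElementwiseParams2d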
Example #5
    def __init__(self, h, kernel_size=3, data_channels=3):
        super(MaskedResidualBlock2d, self).__init__()

        self.conv1 = MaskedConv2d(2 * h,
                                  h,
                                  kernel_size=1,
                                  mask_type='B',
                                  data_channels=data_channels)
        self.conv2 = MaskedConv2d(h,
                                  h,
                                  kernel_size=kernel_size,
                                  padding=kernel_size // 2,
                                  mask_type='B',
                                  data_channels=data_channels)
        self.conv3 = MaskedConv2d(h,
                                  2 * h,
                                  kernel_size=1,
                                  mask_type='B',
                                  data_channels=data_channels)
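The example only shows the constructor; a plausible forward pass for this block (an assumption, since the original is not shown here) is the classic PixelCNN bottleneck, 2h -> h -> h -> 2h with ReLUs and a residual connection:

    def forward(self, x):
        identity = x
        x = self.conv1(torch.relu(x))   # 1x1, 2h -> h
        x = self.conv2(torch.relu(x))   # masked kxk, h -> h
        x = self.conv3(torch.relu(x))   # 1x1, h -> 2h
        return identity + x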
Example #6
    def test_bijection_is_well_behaved(self):
        num_mix = 8
        batch_size = 10
        shape = (3, 4, 4)
        elementwise_params = 3 * num_mix
        x = torch.randn(batch_size, *shape)
        net_spatial = nn.Sequential(
            SpatialMaskedConv2d(3, 3 * elementwise_params, kernel_size=3, padding=1, mask_type='A'),
            ElementwiseParams2d(elementwise_params))
        net = nn.Sequential(
            MaskedConv2d(3, 3 * elementwise_params, kernel_size=3, padding=1, mask_type='A'),
            ElementwiseParams2d(elementwise_params))

        self.eps = 5e-5
        bijection = LogisticMixtureAutoregressiveBijection2d(net, num_mixtures=num_mix, autoregressive_order='raster_cwh')
        self.assert_bijection_is_well_behaved(bijection, x, z_shape=(batch_size, *shape))

        bijection = LogisticMixtureAutoregressiveBijection2d(net_spatial, num_mixtures=num_mix, autoregressive_order='raster_wh')
        self.assert_bijection_is_well_behaved(bijection, x, z_shape=(batch_size, *shape))
Example #7
    def __init__(self, channels, context_channels, params, autoregressive_order):
        super(CondNet, self).__init__()
        if autoregressive_order == 'raster_cwh':
            self.conv = MaskedConv2d(channels,
                                     channels * params,
                                     kernel_size=3,
                                     padding=1,
                                     mask_type='A')
        elif autoregressive_order == 'raster_wh':
            self.conv = SpatialMaskedConv2d(channels,
                                            channels * params,
                                            kernel_size=3,
                                            padding=1,
                                            mask_type='A')
        self.context = nn.Conv2d(context_channels,
                                 channels * params,
                                 kernel_size=1)
        self.out = ElementwiseParams2d(params)
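Only the constructor is shown here as well; a plausible forward (an assumption) adds the masked features to a 1x1 projection of the conditioning context and reshapes the sum into elementwise parameters:

    def forward(self, x, context):
        return self.out(self.conv(x) + self.context(context))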
Example #8
    def test_bijection_is_well_behaved(self):
        batch_size = 10
        shape = (3, 8, 8)
        x = torch.randn(batch_size, *shape)
        net_spatial = SpatialMaskedConv2d(3,
                                          3,
                                          kernel_size=3,
                                          padding=1,
                                          mask_type='A')
        net = MaskedConv2d(3, 3, kernel_size=3, padding=1, mask_type='A')

        self.eps = 1e-6
        bijection = AdditiveAutoregressiveBijection2d(
            net, autoregressive_order='raster_cwh')
        self.assert_bijection_is_well_behaved(bijection,
                                              x,
                                              z_shape=(batch_size, *shape))

        bijection = AdditiveAutoregressiveBijection2d(
            net_spatial, autoregressive_order='raster_wh')
        self.assert_bijection_is_well_behaved(bijection,
                                              x,
                                              z_shape=(batch_size, *shape))
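The additive bijection drops the scale, so the forward pass is volume preserving (zero log-determinant) and the work sits in the inverse, which must decode one element at a time in the autoregressive order. A conceptual, deliberately inefficient sketch (hypothetical code, not the library's implementation) for the channel-masked 'raster_cwh' case, where channels are decoded within each pixel as Example #2 demonstrates:

import torch

def additive_inverse(z, net):
    x = torch.zeros_like(z)
    _, C, H, W = z.shape
    for h in range(H):
        for w in range(W):
            for c in range(C):
                # The masked net only sees elements decoded so far, so exactly
                # one new element can be recovered per pass over the net.
                shift = net(x)
                x[:, c, h, w] = z[:, c, h, w] - shift[:, c, h, w]
    return x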