Beispiel #1
0
    def test_bijection_is_well_behaved(self):
        """Check AffineAutoregressiveBijection2d under both supported orders.

        The MaskedConv2d-based network is paired with 'raster_cwh' and the
        SpatialMaskedConv2d-based network with 'raster_wh'. Each network maps
        3 channels to 3 * 2 channels, which ElementwiseParams2d(2) splits into
        2 parameters per channel.
        """
        batch_size = 10
        shape = (3, 8, 8)
        x = torch.randn(batch_size, *shape)
        # Construct both networks before the loop, in the same order as
        # before, so weight initialisation draws from the RNG identically.
        spatial_net = nn.Sequential(
            SpatialMaskedConv2d(3, 3 * 2, kernel_size=3, padding=1,
                                mask_type='A'),
            ElementwiseParams2d(2),
        )
        channel_net = nn.Sequential(
            MaskedConv2d(3, 3 * 2, kernel_size=3, padding=1, mask_type='A'),
            ElementwiseParams2d(2),
        )
        nets_by_order = {
            'raster_cwh': channel_net,
            'raster_wh': spatial_net,
        }

        self.eps = 1e-6
        for order, autoreg_net in nets_by_order.items():
            with self.subTest(autoregressive_order=order):
                bijection = AffineAutoregressiveBijection2d(
                    autoreg_net, autoregressive_order=order)
                self.assert_bijection_is_well_behaved(
                    bijection, x, z_shape=(batch_size, *shape))
Beispiel #2
0
    def test_layer_is_well_behaved(self):
        """Run the generic layer checks on SpatialMaskedConv2d.

        Both mask types are exercised in turn ('A' first, then 'B') on the
        same random input batch, each with a freshly constructed layer.
        """
        batch_size = 10
        shape = (2, 8, 8)
        x = torch.randn(batch_size, *shape)

        for mask_type in ('A', 'B'):
            # 2 -> 2 channels, 3x3 kernel with padding=1.
            layer = SpatialMaskedConv2d(2, 2, kernel_size=3, padding=1,
                                        mask_type=mask_type)
            self.assert_layer_is_well_behaved(layer, x)
Beispiel #3
0
    def __init__(self, h, kernel_size=3):
        """Create the three convolutions of the residual block.

        Args:
            h: Hidden width. conv1 is a 1x1 conv from 2*h to h channels,
                conv2 a SpatialMaskedConv2d (mask 'B') from h to h channels,
                and conv3 a 1x1 conv from h back to 2*h channels.
                (Presumably applied conv1 -> conv2 -> conv3 in forward, which
                is not shown here.)
            kernel_size: Kernel size of the masked convolution conv2; its
                padding is kernel_size // 2.
        """
        super(SpatialMaskedResidualBlock2d, self).__init__()

        # NOTE(review): kernel_size // 2 keeps the spatial size unchanged only
        # for odd kernel sizes (assuming stride 1); confirm callers never pass
        # an even kernel_size.
        masked_padding = kernel_size // 2
        # Submodules are registered in the original order (conv1, conv2,
        # conv3) so initialisation and state_dict layout are unchanged.
        self.conv1 = nn.Conv2d(2 * h, h, kernel_size=1)
        self.conv2 = SpatialMaskedConv2d(
            h, h, kernel_size=kernel_size, padding=masked_padding,
            mask_type='B')
        self.conv3 = nn.Conv2d(h, 2 * h, kernel_size=1)
    def test_bijection_is_well_behaved(self):
        """Check LogisticMixtureAutoregressiveBijection2d under both orders.

        The MaskedConv2d-based network is paired with 'raster_cwh' and the
        SpatialMaskedConv2d-based network with 'raster_wh'.
        """
        num_mix = 8
        batch_size = 10
        shape = (3, 4, 4)
        # 3 * num_mix parameters per channel (presumably three quantities per
        # mixture component -- confirm against the bijection).
        elementwise_params = 3 * num_mix
        x = torch.randn(batch_size, *shape)
        out_channels = 3 * elementwise_params
        net_spatial = nn.Sequential(
            SpatialMaskedConv2d(3, out_channels, kernel_size=3, padding=1,
                                mask_type='A'),
            ElementwiseParams2d(elementwise_params),
        )
        net = nn.Sequential(
            MaskedConv2d(3, out_channels, kernel_size=3, padding=1,
                         mask_type='A'),
            ElementwiseParams2d(elementwise_params),
        )

        # Looser tolerance than the affine/additive tests.
        self.eps = 5e-5
        cases = ((net, 'raster_cwh'), (net_spatial, 'raster_wh'))
        for autoreg_net, order in cases:
            bijection = LogisticMixtureAutoregressiveBijection2d(
                autoreg_net, num_mixtures=num_mix, autoregressive_order=order)
            self.assert_bijection_is_well_behaved(
                bijection, x, z_shape=(batch_size, *shape))
 def __init__(self, channels, context_channels, params,
              autoregressive_order):
     """Build the masked input conv, the context projection and the output.

     Args:
         channels: Number of channels of the autoregressive input.
         context_channels: Number of channels of the conditioning context.
         params: Number of elementwise parameters per input channel; both the
             masked conv and the context conv output channels * params
             channels.
         autoregressive_order: 'raster_cwh' selects a MaskedConv2d,
             'raster_wh' selects a SpatialMaskedConv2d (both mask type 'A',
             3x3 kernel, padding 1).

     Raises:
         ValueError: If autoregressive_order is not 'raster_cwh' or
             'raster_wh'.
     """
     super(CondNet, self).__init__()
     if autoregressive_order == 'raster_cwh':
         self.conv = MaskedConv2d(channels,
                                  channels * params,
                                  kernel_size=3,
                                  padding=1,
                                  mask_type='A')
     elif autoregressive_order == 'raster_wh':
         self.conv = SpatialMaskedConv2d(channels,
                                         channels * params,
                                         kernel_size=3,
                                         padding=1,
                                         mask_type='A')
     else:
         # Without this, self.conv would simply never be assigned and the
         # mistake would presumably surface later as an AttributeError far
         # from its cause; fail fast with a clear message instead.
         raise ValueError(
             "autoregressive_order must be 'raster_cwh' or 'raster_wh', "
             "got {!r}".format(autoregressive_order))
     # 1x1 projection of the context to the same channels * params width.
     self.context = nn.Conv2d(context_channels,
                              channels * params,
                              kernel_size=1)
     # Presumably regroups channels * params channels into `params`
     # parameters per input channel -- see ElementwiseParams2d.
     self.out = ElementwiseParams2d(params)
Beispiel #6
0
    def test_autoregressive_type_B(self):
        """Check the receptive field of a type-'B' SpatialMaskedConv2d.

        Perturbing the input at spatial position (4, 2) must leave every
        earlier position unchanged and must change every element at (4, 2)
        itself. Earlier positions are those in indices 0..3 of dimension 2,
        plus indices 0..1 of dimension 3 at index 4.
        """
        batch_size = 10
        shape = (2, 8, 8)
        row, col = 4, 2
        x = torch.randn(batch_size, *shape)
        perturbed = copy.deepcopy(x)
        # Shift the second of the two channels at position (row, col).
        perturbed[:, 1, row, col] += 100.0

        layer = SpatialMaskedConv2d(2, 2, kernel_size=3, padding=1,
                                    mask_type='B')
        out = layer(x)
        out_perturbed = layer(perturbed)

        # Everything before (row, col) must be bit-for-bit unchanged.
        self.assertEqual(out[:, :, :row], out_perturbed[:, :, :row])
        self.assertEqual(out[:, :, row, :col], out_perturbed[:, :, row, :col])

        # A type-'B' mask includes the centre position, so every element at
        # (row, col) must differ.
        differs = out[:, :, row, col] != out_perturbed[:, :, row, col]
        self.assertTrue(differs.view(-1).all())
Beispiel #7
0
    def test_bijection_is_well_behaved(self):
        """Check AdditiveAutoregressiveBijection2d under both orders.

        The MaskedConv2d network is paired with 'raster_cwh' and the
        SpatialMaskedConv2d network with 'raster_wh'.
        """
        batch_size = 10
        shape = (3, 8, 8)
        x = torch.randn(batch_size, *shape)
        # 3 -> 3 channels and no ElementwiseParams2d: presumably a single
        # shift per input channel for the additive bijection.
        spatial_net = SpatialMaskedConv2d(3, 3, kernel_size=3, padding=1,
                                          mask_type='A')
        channel_net = MaskedConv2d(3, 3, kernel_size=3, padding=1,
                                   mask_type='A')

        self.eps = 1e-6
        cases = [(channel_net, 'raster_cwh'), (spatial_net, 'raster_wh')]
        for autoreg_net, order in cases:
            bijection = AdditiveAutoregressiveBijection2d(
                autoreg_net, autoregressive_order=order)
            self.assert_bijection_is_well_behaved(
                bijection, x, z_shape=(batch_size, *shape))