def test_layer_is_well_behaved(self):
    """Run the generic layer sanity checks on MaskedConv2d for both mask types."""
    batch_size = 10
    shape = (6, 8, 8)
    inputs = torch.randn(batch_size, *shape)
    # Same checks for the strict ('A') and non-strict ('B') masks.
    for mask_type in ['A', 'B']:
        layer = MaskedConv2d(6, 6, kernel_size=3, padding=1, mask_type=mask_type)
        self.assert_layer_is_well_behaved(layer, inputs)
def test_autoregressive_type_A(self):
    """Verify the strict (type 'A') autoregressive property of MaskedConv2d.

    The input has 6 channels (two RGB-like groups of 3). Perturbing the G
    channel of the second group at pixel (4, 2) must leave all earlier
    pixels and the R/G output channels at that pixel unchanged, while the
    B output channels there and every element of the next pixel change.
    """
    batch_size = 10
    shape = (6, 8, 8)
    x = torch.randn(batch_size, *shape)
    x_perturbed = copy.deepcopy(x)
    # Alter channel G of feature 2/2 in position (4,2)
    x_perturbed[:, 4, 4, 2] += 100.0
    module = MaskedConv2d(6, 6, kernel_size=3, padding=1, mask_type='A')
    y = module(x)
    y_perturbed = module(x_perturbed)
    # All pixels strictly before (4,2) in raster order are unaltered.
    self.assertEqual(y[:, :, :4], y_perturbed[:, :, :4])
    self.assertEqual(y[:, :, 4, :2], y_perturbed[:, :, 4, :2])
    # R channels (0, 3) then G channels (1, 4) at (4,2) are unaltered:
    # the type-'A' mask excludes the current channel and all later ones.
    for channel in (0, 3, 1, 4):
        self.assertEqual(y[:, channel, 4, 2], y_perturbed[:, channel, 4, 2])
    # B channels (2, 5) depend on G and must therefore be altered.
    for channel in (2, 5):
        self.assertFalse(
            (y[:, channel, 4, 2] == y_perturbed[:, channel, 4, 2]).view(-1).any())
    # The next pixel depends on all of (4,2), so every element changes.
    self.assertFalse((y[:, :, 4, 3] == y_perturbed[:, :, 4, 3]).view(-1).any())
def test_bijection_is_well_behaved(self):
    """Run the generic bijection checks on the affine coupling for both orders."""
    batch_size = 10
    shape = (3, 8, 8)
    x = torch.randn(batch_size, *shape)
    net_spatial = nn.Sequential(
        SpatialMaskedConv2d(3, 3 * 2, kernel_size=3, padding=1, mask_type='A'),
        ElementwiseParams2d(2))
    net = nn.Sequential(
        MaskedConv2d(3, 3 * 2, kernel_size=3, padding=1, mask_type='A'),
        ElementwiseParams2d(2))
    self.eps = 1e-6
    # Channel-wise masking pairs with 'raster_cwh'; spatial-only with 'raster_wh'.
    order_to_net = {'raster_cwh': net, 'raster_wh': net_spatial}
    for autoregressive_order, autoreg_net in order_to_net.items():
        with self.subTest(autoregressive_order=autoregressive_order):
            bijection = AffineAutoregressiveBijection2d(
                autoreg_net, autoregressive_order=autoregressive_order)
            self.assert_bijection_is_well_behaved(
                bijection, x, z_shape=(batch_size, *shape))
def __init__(self, in_channels, num_params, filters=128, num_blocks=15,
             output_filters=1024, kernel_size=3, kernel_size_in=7,
             init_transforms=lambda x: 2 * x - 1):
    """Assemble the PixelCNN as a sequence of masked layers.

    Args:
        in_channels: number of data channels in the input image.
        num_params: number of elementwise parameters emitted per channel.
        filters: hidden width of the residual blocks.
        num_blocks: number of MaskedResidualBlock2d layers.
        output_filters: width of the first 1x1 output head.
        kernel_size: kernel size inside the residual blocks.
        kernel_size_in: kernel size of the input convolution.
        init_transforms: elementwise transform applied to the input first
            (default rescales [0, 1] data to [-1, 1]).
    """
    layers = [LambdaLayer(init_transforms)]
    # Input convolution uses mask type 'A' so no output touches its own pixel.
    layers.append(MaskedConv2d(in_channels, 2 * filters,
                               kernel_size=kernel_size_in,
                               padding=kernel_size_in // 2,
                               mask_type='A', data_channels=in_channels))
    # Stack of masked residual blocks operating at width 2 * filters.
    for _ in range(num_blocks):
        layers.append(MaskedResidualBlock2d(filters,
                                            data_channels=in_channels,
                                            kernel_size=kernel_size))
    # Two 1x1 output heads (mask type 'B') producing the final parameters.
    layers.append(nn.ReLU(True))
    layers.append(MaskedConv2d(2 * filters, output_filters, kernel_size=1,
                               mask_type='B', data_channels=in_channels))
    layers.append(nn.ReLU(True))
    layers.append(MaskedConv2d(output_filters, num_params * in_channels,
                               kernel_size=1, mask_type='B',
                               data_channels=in_channels))
    layers.append(ElementwiseParams2d(num_params))
    super(PixelCNN, self).__init__(*layers)
def __init__(self, h, kernel_size=3, data_channels=3):
    """Masked residual block: 1x1 reduce -> kxk masked conv -> 1x1 expand.

    Args:
        h: hidden width; the block maps 2*h channels -> h -> h -> 2*h.
        kernel_size: kernel size of the middle convolution.
        data_channels: number of data channels for the channel grouping
            of the masks.
    """
    super(MaskedResidualBlock2d, self).__init__()
    # All three convolutions share the non-strict 'B' mask.
    shared = dict(mask_type='B', data_channels=data_channels)
    self.conv1 = MaskedConv2d(2 * h, h, kernel_size=1, **shared)
    self.conv2 = MaskedConv2d(h, h, kernel_size=kernel_size,
                              padding=kernel_size // 2, **shared)
    self.conv3 = MaskedConv2d(h, 2 * h, kernel_size=1, **shared)
def test_bijection_is_well_behaved(self):
    """Run the generic bijection checks on the logistic-mixture coupling for both orders."""
    num_mix = 8
    batch_size = 10
    shape = (3, 4, 4)
    elementwise_params = 3 * num_mix
    x = torch.randn(batch_size, *shape)
    net_spatial = nn.Sequential(
        SpatialMaskedConv2d(3, 3 * elementwise_params, kernel_size=3,
                            padding=1, mask_type='A'),
        ElementwiseParams2d(elementwise_params))
    net = nn.Sequential(
        MaskedConv2d(3, 3 * elementwise_params, kernel_size=3,
                     padding=1, mask_type='A'),
        ElementwiseParams2d(elementwise_params))
    self.eps = 5e-5
    # Channel-wise masking pairs with 'raster_cwh'; spatial-only with 'raster_wh'.
    for autoreg_net, order in ((net, 'raster_cwh'), (net_spatial, 'raster_wh')):
        bijection = LogisticMixtureAutoregressiveBijection2d(
            autoreg_net, num_mixtures=num_mix, autoregressive_order=order)
        self.assert_bijection_is_well_behaved(
            bijection, x, z_shape=(batch_size, *shape))
def __init__(self, channels, context_channels, params, autoregressive_order):
    """Conditional autoregressive parameter network.

    Args:
        channels: number of data channels.
        context_channels: number of channels of the conditioning context.
        params: number of elementwise parameters produced per channel.
        autoregressive_order: 'raster_cwh' (channel-wise masked conv) or
            'raster_wh' (spatial-only masked conv).

    Raises:
        ValueError: if autoregressive_order is not one of the supported
            values. (Previously an unknown value left ``self.conv``
            undefined, deferring the failure to an AttributeError in
            ``forward``.)
    """
    super(CondNet, self).__init__()
    if autoregressive_order == 'raster_cwh':
        self.conv = MaskedConv2d(channels, channels * params,
                                 kernel_size=3, padding=1, mask_type='A')
    elif autoregressive_order == 'raster_wh':
        self.conv = SpatialMaskedConv2d(channels, channels * params,
                                        kernel_size=3, padding=1, mask_type='A')
    else:
        # Fail fast instead of leaving self.conv undefined.
        raise ValueError(
            "autoregressive_order must be 'raster_cwh' or 'raster_wh', "
            "got {!r}".format(autoregressive_order))
    self.context = nn.Conv2d(context_channels, channels * params, kernel_size=1)
    self.out = ElementwiseParams2d(params)
def test_bijection_is_well_behaved(self):
    """Run the generic bijection checks on the additive coupling for both orders."""
    batch_size = 10
    shape = (3, 8, 8)
    x = torch.randn(batch_size, *shape)
    net_spatial = SpatialMaskedConv2d(3, 3, kernel_size=3, padding=1, mask_type='A')
    net = MaskedConv2d(3, 3, kernel_size=3, padding=1, mask_type='A')
    self.eps = 1e-6
    # Channel-wise masking pairs with 'raster_cwh'; spatial-only with 'raster_wh'.
    for autoreg_net, order in ((net, 'raster_cwh'), (net_spatial, 'raster_wh')):
        bijection = AdditiveAutoregressiveBijection2d(
            autoreg_net, autoregressive_order=order)
        self.assert_bijection_is_well_behaved(
            bijection, x, z_shape=(batch_size, *shape))