def test_bijection_is_well_behaved(self):
    """AffineAutoregressiveBijection2d is well behaved for every supported order.

    Each autoregressive order is paired with a conditioner whose masking
    matches it: ``'raster_cwh'`` uses ``MaskedConv2d`` and ``'raster_wh'``
    uses ``SpatialMaskedConv2d``. Each pairing runs in its own subTest.
    """
    batch_size = 10
    shape = (3, 8, 8)
    x = torch.randn(batch_size, *shape)

    # Both conditioners map 3 channels to 3 * 2 outputs, which
    # ElementwiseParams2d(2) regroups into 2 parameters per channel.
    # The spatial net is built first to keep the random-initialisation order.
    spatial_net = nn.Sequential(
        SpatialMaskedConv2d(3, 3 * 2, kernel_size=3, padding=1, mask_type='A'),
        ElementwiseParams2d(2))
    channel_net = nn.Sequential(
        MaskedConv2d(3, 3 * 2, kernel_size=3, padding=1, mask_type='A'),
        ElementwiseParams2d(2))
    self.eps = 1e-6

    net_for_order = {'raster_cwh': channel_net, 'raster_wh': spatial_net}
    for order, autoreg_net in net_for_order.items():
        with self.subTest(autoregressive_order=order):
            bijection = AffineAutoregressiveBijection2d(
                autoreg_net, autoregressive_order=order)
            self.assert_bijection_is_well_behaved(
                bijection, x, z_shape=(batch_size, *shape))
def __init__(self, in_channels, num_context, num_blocks, mid_channels, depth,
             growth, dropout, gated_conv, coupling_network):
    """Build a context-conditioned coupling layer with a DenseNet or ConvNet conditioner.

    Args:
        in_channels: channel count of the transformed input; must be even.
        num_context: extra context channels added to the conditioner's input.
        num_blocks, growth, dropout, gated_conv: forwarded to ``DenseNet``
            (used only when ``coupling_network == "densenet"``).
        mid_channels: hidden width forwarded to either backbone.
        depth: ``depth`` for ``DenseNet`` / ``num_layers`` for ``ConvNet``.
        coupling_network: ``"densenet"`` or ``"conv"``.

    Raises:
        ValueError: if ``coupling_network`` is not one of the supported names.
    """
    assert in_channels % 2 == 0
    # Conditioner input: half of the channels plus the context channels.
    # Output: in_channels values, regrouped below into 2 params per channel.
    cond_channels = in_channels // 2 + num_context

    if coupling_network == "densenet":
        backbone = DenseNet(in_channels=cond_channels,
                            out_channels=in_channels,
                            num_blocks=num_blocks,
                            mid_channels=mid_channels,
                            depth=depth,
                            growth=growth,
                            dropout=dropout,
                            gated_conv=gated_conv,
                            zero_init=True)
    elif coupling_network == "conv":
        backbone = ConvNet(in_channels=cond_channels,
                           out_channels=in_channels,
                           mid_channels=mid_channels,
                           num_layers=depth,
                           activation='relu')
    else:
        raise ValueError(f"Unknown coupling network {coupling_network}")

    net = nn.Sequential(backbone, ElementwiseParams2d(2, mode='sequential'))
    super(ConditionalCoupling, self).__init__(coupling_net=net,
                                              scale_fn=scale_fn("tanh_exp"))
def test_shape(self):
    """ElementwiseParams2d(k) moves k parameters per channel to a trailing axis.

    ``self.x`` is a fixture from the test case; the expected shapes below
    encode a (10, 6, 4, 4) input regrouped into (10, 6 // k, 4, 4, k).
    """
    cases = [
        (3, (10, 2, 4, 4, 3)),
        (2, (10, 3, 4, 4, 2)),
    ]
    for num_params, expected_shape in cases:
        layer = ElementwiseParams2d(num_params)
        out = layer(self.x)
        self.assertEqual(out.shape, expected_shape)
def __init__(self, x_size, y_size, coupling_network, mid_channels, depth,
             num_blocks=None, dropout=None, gated_conv=None,
             checkerboard=False, flip=False):
    """Build a coupling layer over ``y`` whose conditioner also sees a context ``x``.

    Args:
        x_size: ``(C, H, W)``-style size of the context; ``x_size[0]`` channels
            are added to the conditioner input.
        y_size: size of the transformed tensor (called "High-resolution" in
            the assertion message).
        coupling_network: ``"densenet"`` or ``"conv"``.
        mid_channels: hidden width; also used as the DenseNet growth.
        depth: DenseNet ``depth`` / ConvNet ``num_layers``.
        num_blocks, dropout, gated_conv: forwarded to ``DenseNet`` only.
        checkerboard: if True, split along axis 3 instead of the channel axis.
        flip: forwarded to the parent coupling class.

    Raises:
        ValueError: if ``coupling_network`` is not recognised.
    """
    if checkerboard:
        # Split along axis 3: the conditioner sees all y channels plus the
        # context and emits 2 values per y channel.
        in_channels = y_size[0] + x_size[0]
        out_channels = y_size[0] * 2
        split_dim = 3
        assert x_size[1] == y_size[1] and x_size[2] == y_size[2] // 2
    else:
        # Split along the channel axis: the conditioner sees half of y's
        # channels plus the context.
        in_channels = y_size[0] // 2 + x_size[0]
        out_channels = y_size[0]
        split_dim = 1
        assert x_size[1] == y_size[1] and x_size[2] == y_size[2]
        assert y_size[0] % 2 == 0, f"High-resolution has shape {y_size} with channels not evenly divisible"

    if coupling_network == "densenet":
        backbone = DenseNet(in_channels=in_channels,
                            out_channels=out_channels,
                            num_blocks=num_blocks,
                            mid_channels=mid_channels,
                            depth=depth,
                            growth=mid_channels,
                            dropout=dropout,
                            gated_conv=gated_conv,
                            zero_init=True)
    elif coupling_network == "conv":
        backbone = ConvNet(in_channels=in_channels,
                           out_channels=out_channels,
                           mid_channels=mid_channels,
                           num_layers=depth,
                           weight_norm=True,
                           activation='relu')
    else:
        raise ValueError(f"Unknown coupling network {coupling_network}")

    coupling_net = nn.Sequential(backbone, ElementwiseParams2d(2, mode='sequential'))
    super(SRCoupling, self).__init__(coupling_net=coupling_net,
                                     scale_fn=scale_fn("tanh_exp"),
                                     split_dim=split_dim,
                                     flip=flip)
def __init__(self, in_channels, num_context, num_blocks, mid_channels, depth,
             dropout, gated_conv, coupling_network, checkerboard=False, flip=False):
    """Build a context-conditioned coupling layer that supports channel or checkerboard splits.

    Args:
        in_channels: channels of the transformed input; must be even unless
            ``checkerboard`` is set.
        num_context: context channels added to the conditioner's input.
        num_blocks, dropout, gated_conv: forwarded to ``DenseNet`` only.
        mid_channels: hidden width (also the DenseNet growth).
        depth: DenseNet ``depth`` / ConvNet ``num_layers``.
        coupling_network: ``"densenet"`` or ``"conv"``.
        checkerboard: if True, split along axis 3 instead of axis 1.
        flip: forwarded to the parent coupling class.

    Raises:
        ValueError: if ``coupling_network`` is not recognised.
    """
    if not checkerboard:
        # Channel split: condition on half the channels.
        num_in, num_out, split_dim = in_channels // 2 + num_context, in_channels, 1
    else:
        # Axis-3 split: condition on every channel, emit 2 values per channel.
        num_in, num_out, split_dim = in_channels + num_context, in_channels * 2, 3
    assert in_channels % 2 == 0 or split_dim != 1, f"in_channels = {in_channels} not evenly divisible"

    if coupling_network == "densenet":
        backbone = DenseNet(in_channels=num_in,
                            out_channels=num_out,
                            num_blocks=num_blocks,
                            mid_channels=mid_channels,
                            depth=depth,
                            growth=mid_channels,
                            dropout=dropout,
                            gated_conv=gated_conv,
                            zero_init=True)
    elif coupling_network == "conv":
        backbone = ConvNet(in_channels=num_in,
                           out_channels=num_out,
                           mid_channels=mid_channels,
                           num_layers=depth,
                           activation='relu')
    else:
        raise ValueError(f"Unknown coupling network {coupling_network}")

    net = nn.Sequential(backbone, ElementwiseParams2d(2, mode='sequential'))
    super(ConditionalCoupling, self).__init__(coupling_net=net,
                                              scale_fn=scale_fn("tanh_exp"),
                                              split_dim=split_dim,
                                              flip=flip)
def test_bijection_is_well_behaved(self):
    """CensoredLogisticMixtureCouplingBijection is well behaved on flat and image inputs.

    Covers the default split (first 3 of 6 features pass through) and
    ``num_condition=1`` (first feature passes through), and checks that the
    conditioning slice is left unchanged by ``forward``.
    """
    num_bins = 16
    num_mix = 8
    batch_size = 10
    # The conditioner emits 3 * num_mix values per transformed element.
    elementwise_params = 3 * num_mix
    self.eps = 1e-6

    for shape in [(6, ), (6, 4, 4)]:
        for num_condition in [None, 1]:
            with self.subTest(shape=shape, num_condition=num_condition):
                x = torch.rand(batch_size, *shape)
                # (conditioning features, transformed features) out of 6.
                n_in, n_out = (3, 3) if num_condition is None else (1, 5)
                if len(shape) == 1:
                    net = nn.Sequential(
                        nn.Linear(n_in, n_out * elementwise_params),
                        ElementwiseParams(elementwise_params))
                else:
                    net = nn.Sequential(
                        nn.Conv2d(n_in, n_out * elementwise_params, kernel_size=3, padding=1),
                        ElementwiseParams2d(elementwise_params))

                bijection = CensoredLogisticMixtureCouplingBijection(
                    net, num_mixtures=num_mix, num_bins=num_bins,
                    num_condition=num_condition)
                self.assert_bijection_is_well_behaved(
                    bijection, x, z_shape=(batch_size, *shape))

                z, _ = bijection.forward(x)
                self.assertEqual(x[:, :n_in], z[:, :n_in])
def test_bijection_is_well_behaved(self):
    """ConditionalAffineCouplingBijection is well behaved for every scale function.

    The context has the same shape as ``x`` (6 features), so each conditioner
    takes its conditioning slice plus 6 context features. The conditioning
    slice must pass through ``forward`` unchanged.
    """
    batch_size = 10
    self.eps = 5e-6
    context_features = 6

    for scale_str in ['exp', 'softplus', 'sigmoid', 'tanh_exp']:
        for shape in [(6, ), (6, 8, 8)]:
            for num_condition in [None, 1]:
                with self.subTest(shape=shape, num_condition=num_condition,
                                  scale_str=scale_str):
                    x = torch.randn(batch_size, *shape)
                    context = torch.randn(batch_size, *shape)
                    # (conditioning features, transformed features) out of 6.
                    n_in, n_out = (3, 3) if num_condition is None else (1, 5)
                    if len(shape) == 1:
                        net = nn.Sequential(
                            nn.Linear(n_in + context_features, n_out * 2),
                            ElementwiseParams(2))
                    else:
                        net = nn.Sequential(
                            nn.Conv2d(n_in + context_features, n_out * 2,
                                      kernel_size=3, padding=1),
                            ElementwiseParams2d(2))

                    bijection = ConditionalAffineCouplingBijection(
                        net, num_condition=num_condition,
                        scale_fn=scale_fn(scale_str))
                    self.assert_bijection_is_well_behaved(
                        bijection, x, context, z_shape=(batch_size, *shape))

                    z, _ = bijection.forward(x, context)
                    self.assertEqual(x[:, :n_in], z[:, :n_in])
def __init__(self, x_size, y_size, mid_channels, num_blocks, num_mixtures,
             dropout, checkerboard=False, flip=False):
    """Build a mixture coupling over ``y`` with a context-aware TransformerNet conditioner.

    Args:
        x_size: size of the context; ``x_size[0]`` is passed as the
            TransformerNet ``context_channels``.
        y_size: size of the transformed tensor (called "High-resolution" in
            the assertion message).
        mid_channels, num_blocks, dropout: forwarded to ``TransformerNet``.
        num_mixtures: number of mixture components; the conditioner output
            is regrouped into ``2 + 3 * num_mixtures`` parameters per element.
        checkerboard: if True, split along axis 3 instead of axis 1.
        flip: forwarded to the parent coupling class.
    """
    if checkerboard:
        in_channels = y_size[0]
        split_dim = 3
        assert x_size[1] == y_size[1] and x_size[2] == y_size[2] // 2
    else:
        # Channel split: the conditioner reads half of y's channels.
        in_channels = y_size[0] // 2
        split_dim = 1
        assert x_size[1] == y_size[1] and x_size[2] == y_size[2]
        assert y_size[0] % 2 == 0, f"High-resolution has shape {y_size} with channels not evenly divisible"

    params_per_element = 2 + num_mixtures * 3
    transformer = TransformerNet(in_channels=in_channels,
                                 context_channels=x_size[0],
                                 mid_channels=mid_channels,
                                 num_blocks=num_blocks,
                                 num_mixtures=num_mixtures,
                                 dropout=dropout)
    coupling_net = nn.Sequential(
        transformer, ElementwiseParams2d(params_per_element, mode='sequential'))
    super(SRMixtureCoupling, self).__init__(coupling_net=coupling_net,
                                            num_mixtures=num_mixtures,
                                            scale_fn=scale_fn("tanh_exp"),
                                            split_dim=split_dim,
                                            flip=flip)
def __init__(self, in_channels, mid_channels, num_mixtures, num_blocks,
             dropout, checkerboard=False, flip=False):
    """Build a mixture coupling layer with a TransformerNet conditioner.

    Args:
        in_channels: channels of the transformed input.
        mid_channels, num_blocks, dropout: forwarded to ``TransformerNet``.
        num_mixtures: number of mixture components; the conditioner output
            is regrouped into ``2 + 3 * num_mixtures`` parameters per element.
        checkerboard: if True, the conditioner reads all channels and the
            split is along axis 3; otherwise it reads half the channels and
            the split is along axis 1.
        flip: forwarded to the parent coupling class.
    """
    num_in = in_channels if checkerboard else in_channels // 2
    split_dim = 3 if checkerboard else 1

    transformer = TransformerNet(in_channels=num_in,
                                 mid_channels=mid_channels,
                                 num_blocks=num_blocks,
                                 num_mixtures=num_mixtures,
                                 dropout=dropout)
    net = nn.Sequential(
        transformer, ElementwiseParams2d(2 + num_mixtures * 3, mode='sequential'))
    super(MixtureCoupling, self).__init__(coupling_net=net,
                                          num_mixtures=num_mixtures,
                                          scale_fn=scale_fn("tanh_exp"),
                                          split_dim=split_dim,
                                          flip=flip)
def net(in_channels):
    """Build a small TransformerNet conditioner for a ``k``-component mixture coupling.

    ``k`` is a free variable resolved from the enclosing or module scope.
    The network reads ``in_channels // 2`` channels, and its output is
    regrouped into ``2 + 3 * k`` parameters per element.
    """
    transformer = TransformerNet(in_channels // 2,
                                 mid_channels=16,
                                 num_blocks=2,
                                 num_mixtures=k,
                                 dropout=0.2)
    return nn.Sequential(transformer, ElementwiseParams2d(2 + k * 3))
def net(channels):
    """Build a one-layer ConvNet conditioner followed by ElementwiseParams2d(2).

    The conditioner reads ``channels // 2`` channels and emits 2 values for
    each of the remaining ``channels - channels // 2`` channels, so odd
    channel counts are handled.
    """
    num_cond = channels // 2
    num_transformed = channels - num_cond
    conv = ConvNet(in_channels=num_cond,
                   out_channels=num_transformed * 2,
                   mid_channels=64,
                   num_layers=1,
                   activation='relu')
    return nn.Sequential(conv, ElementwiseParams2d(2))
def net(channels):
    """Build a one-layer ConvNet conditioner followed by ElementwiseParams2d(2).

    ``ConvNet`` is called with the ``in_channels``/``out_channels``/
    ``mid_channels`` keywords that every other ConvNet construction in this
    codebase uses. The previous ``input_size``/``output_size``/``hidden_units``
    keywords do not match that signature.

    Assumes the surrounding coupling conditions on the first
    ``channels // 2`` channels and transforms the remaining
    ``channels - channels // 2``. Each transformed channel needs 2 values,
    so the output width is ``2 * (channels - channels // 2)``. That equals
    ``channels`` for even inputs and stays correct for odd ones.
    """
    num_cond = channels // 2
    num_transformed = channels - num_cond
    conv = ConvNet(in_channels=num_cond,
                   out_channels=num_transformed * 2,
                   mid_channels=64,
                   num_layers=1,
                   activation='relu')
    return nn.Sequential(conv, ElementwiseParams2d(2))
def test_mode(self):
    """Check how ElementwiseParams2d(2) groups the channels of ``self.x`` in each mode.

    The assertions index channels 0..5 of the fixture ``self.x``, giving 3
    output channels with 2 parameters each:

    * ``'interleaved'``: output channel i pairs input channels (i, i + 3).
    * ``'sequential'``:  output channel i pairs input channels (2i, 2i + 1).
    """
    num_out = 3

    interleaved = ElementwiseParams2d(2, mode='interleaved')(self.x)
    for i in range(num_out):
        expected = torch.stack([self.x[:, i], self.x[:, i + num_out]], dim=-1)
        self.assertEqual(interleaved[:, i], expected)

    sequential = ElementwiseParams2d(2, mode='sequential')(self.x)
    for i in range(num_out):
        expected = torch.stack([self.x[:, 2 * i], self.x[:, 2 * i + 1]], dim=-1)
        self.assertEqual(sequential[:, i], expected)
def test_bijection_is_well_behaved(self):
    """QuadraticSplineCouplingBijection is well behaved on flat and image inputs.

    Covers the default split (first 3 of 6 features pass through) and
    ``num_condition=1``, and checks that the conditioning slice is left
    unchanged by ``forward``.
    """
    num_bins = 16
    batch_size = 10
    # Parameters the conditioner emits per transformed element.
    num_params = 2 * num_bins + 1
    self.eps = 5e-3

    for shape in [(6, ), (6, 8, 8)]:
        for num_condition in [None, 1]:
            with self.subTest(shape=shape, num_condition=num_condition):
                x = torch.rand(batch_size, *shape)
                # (conditioning features, transformed features) out of 6.
                n_in, n_out = (3, 3) if num_condition is None else (1, 5)
                if len(shape) == 1:
                    net = nn.Sequential(nn.Linear(n_in, n_out * num_params),
                                        ElementwiseParams(num_params))
                else:
                    net = nn.Sequential(
                        nn.Conv2d(n_in, n_out * num_params, kernel_size=3, padding=1),
                        ElementwiseParams2d(num_params))

                bijection = QuadraticSplineCouplingBijection(
                    net, num_bins=num_bins, num_condition=num_condition)
                self.assert_bijection_is_well_behaved(
                    bijection, x, z_shape=(batch_size, *shape))

                z, _ = bijection.forward(x)
                self.assertEqual(x[:, :n_in], z[:, :n_in])
def net(channels):
    """Build a single-block gated DenseNet conditioner followed by ElementwiseParams2d(2).

    Reads ``channels // 2`` channels and emits ``channels`` values. Those
    equal 2 per remaining channel only when ``channels`` is even.
    """
    densenet = DenseNet(in_channels=channels // 2,
                        out_channels=channels,
                        num_blocks=1,
                        mid_channels=64,
                        depth=8,
                        growth=16,
                        dropout=0.0,
                        gated_conv=True,
                        zero_init=True)
    return nn.Sequential(densenet, ElementwiseParams2d(2))
def __init__(self, in_channels, mid_channels, num_mixtures, num_blocks, dropout):
    """Build a channel-split mixture coupling layer with a TransformerNet conditioner.

    The conditioner reads ``in_channels // 2`` channels, and its output is
    regrouped into ``2 + 3 * num_mixtures`` parameters per element.
    """
    params_per_element = 2 + num_mixtures * 3
    transformer = TransformerNet(in_channels // 2,
                                 mid_channels=mid_channels,
                                 num_blocks=num_blocks,
                                 num_mixtures=num_mixtures,
                                 dropout=dropout)
    net = nn.Sequential(
        transformer, ElementwiseParams2d(params_per_element, mode='sequential'))
    super(MixtureCoupling, self).__init__(coupling_net=net,
                                          num_mixtures=num_mixtures,
                                          scale_fn=scale_fn("tanh_exp"))
def net(channels):
    """Build the coupling conditioner: ``Net`` followed by ElementwiseParams2d(2).

    ``Net`` reads ``channels // 2`` channels and emits ``channels`` values.
    """
    backbone = Net(in_channels=channels // 2, out_channels=channels, mid_channels=64)
    return nn.Sequential(backbone, ElementwiseParams2d(2))
def __init__(self, in_channels, num_blocks, mid_channels, depth, growth,
             dropout, gated_conv):
    """Build a channel-split coupling layer with a zero-initialised DenseNet conditioner.

    Args:
        in_channels: channels of the transformed input; must be even.
        num_blocks, mid_channels, depth, growth, dropout, gated_conv:
            forwarded to ``DenseNet``.
    """
    assert in_channels % 2 == 0
    densenet = DenseNet(in_channels=in_channels // 2,
                        out_channels=in_channels,
                        num_blocks=num_blocks,
                        mid_channels=mid_channels,
                        depth=depth,
                        growth=growth,
                        dropout=dropout,
                        gated_conv=gated_conv,
                        zero_init=True)
    coupling_net = nn.Sequential(densenet, ElementwiseParams2d(2, mode='sequential'))
    super(Coupling, self).__init__(coupling_net=coupling_net)
def __init__(self, in_channels, num_params, filters=128, num_blocks=15,
             output_filters=1024, kernel_size=3, kernel_size_in=7,
             init_transforms=lambda x: 2 * x - 1):
    """Assemble a PixelCNN as a sequence of masked convolutions.

    Layer order: input transform -> type-'A' masked conv ->
    ``num_blocks`` masked residual blocks -> two ReLU + 1x1 type-'B'
    masked convs -> ElementwiseParams2d(num_params).

    Args:
        in_channels: data channels (also passed as ``data_channels`` to
            every masked layer).
        num_params: parameters produced per data channel and position.
        filters: the first masked conv outputs ``2 * filters`` channels.
        num_blocks: number of MaskedResidualBlock2d layers.
        output_filters: width of the hidden 1x1 masked conv.
        kernel_size: kernel size inside the residual blocks.
        kernel_size_in: kernel size of the first masked conv (padded to keep
            the spatial size for odd kernels).
        init_transforms: callable applied to the input first; by default it
            maps x to 2 * x - 1.
    """
    layers = [
        LambdaLayer(init_transforms),
        MaskedConv2d(in_channels, 2 * filters,
                     kernel_size=kernel_size_in,
                     padding=kernel_size_in // 2,
                     mask_type='A',
                     data_channels=in_channels),
    ]
    layers.extend(
        MaskedResidualBlock2d(filters, data_channels=in_channels, kernel_size=kernel_size)
        for _ in range(num_blocks))
    layers.append(nn.ReLU(True))
    layers.append(MaskedConv2d(2 * filters, output_filters, kernel_size=1,
                               mask_type='B', data_channels=in_channels))
    layers.append(nn.ReLU(True))
    layers.append(MaskedConv2d(output_filters, num_params * in_channels, kernel_size=1,
                               mask_type='B', data_channels=in_channels))
    layers.append(ElementwiseParams2d(num_params))
    super(PixelCNN, self).__init__(*layers)
def __init__(self, channels, context_channels, params, autoregressive_order):
    """Create the layers of a masked conditioner that also receives a context tensor.

    Args:
        channels: data channels; each masked conv emits ``channels * params``.
        context_channels: channels of the context, projected by a 1x1 conv to
            ``channels * params``.
        params: parameters per data channel, regrouped by ElementwiseParams2d.
        autoregressive_order: ``'raster_cwh'`` (MaskedConv2d) or
            ``'raster_wh'`` (SpatialMaskedConv2d).

    Raises:
        ValueError: if ``autoregressive_order`` is not one of the two
            supported values.
    """
    super(CondNet, self).__init__()
    if autoregressive_order == 'raster_cwh':
        self.conv = MaskedConv2d(channels, channels * params, kernel_size=3,
                                 padding=1, mask_type='A')
    elif autoregressive_order == 'raster_wh':
        self.conv = SpatialMaskedConv2d(channels, channels * params, kernel_size=3,
                                        padding=1, mask_type='A')
    else:
        # Fail fast: otherwise self.conv would never be assigned and the error
        # would only surface later, far from the bad argument.
        raise ValueError(
            f"Unknown autoregressive_order {autoregressive_order!r}; "
            "expected 'raster_cwh' or 'raster_wh'")
    self.context = nn.Conv2d(context_channels, channels * params, kernel_size=1)
    self.out = ElementwiseParams2d(params)
def __init__(self, in_channels, num_context, mid_channels, num_mixtures,
             num_blocks, dropout, use_attn=True):
    """Build a context-conditioned mixture coupling with a TransformerNet conditioner.

    The conditioner reads ``in_channels // 2`` channels plus ``num_context``
    context channels. Its output is regrouped into ``2 + 3 * num_mixtures``
    parameters per element.
    """
    params_per_element = 2 + num_mixtures * 3
    transformer = TransformerNet(in_channels // 2,
                                 context_channels=num_context,
                                 mid_channels=mid_channels,
                                 num_blocks=num_blocks,
                                 num_mixtures=num_mixtures,
                                 use_attn=use_attn,
                                 dropout=dropout)
    coupling_net = nn.Sequential(
        transformer, ElementwiseParams2d(params_per_element, mode='sequential'))
    super(ConditionalMixtureCoupling, self).__init__(coupling_net=coupling_net,
                                                     num_mixtures=num_mixtures,
                                                     scale_fn=scale_fn("tanh_exp"))
def test_layer_is_well_behaved(self):
    """Run the shared layer sanity checks on ElementwiseParams2d(3) with the ``self.x`` fixture."""
    self.assert_layer_is_well_behaved(ElementwiseParams2d(3), self.x)
def test_bijection_is_well_behaved(self):
    """CensoredLogisticMixtureAutoregressiveBijection2d is well behaved for both orders.

    ``'raster_cwh'`` is paired with a MaskedConv2d conditioner and
    ``'raster_wh'`` with a SpatialMaskedConv2d conditioner.
    """
    num_bins = 4
    num_mix = 8
    batch_size = 10
    shape = (3, 4, 4)
    # The conditioner emits 3 * num_mix values per element.
    elementwise_params = 3 * num_mix
    x = torch.rand(batch_size, *shape)

    # Built in this order to keep the random-initialisation order.
    spatial_net = nn.Sequential(
        SpatialMaskedConv2d(3, 3 * elementwise_params, kernel_size=3, padding=1, mask_type='A'),
        ElementwiseParams2d(elementwise_params))
    channel_net = nn.Sequential(
        MaskedConv2d(3, 3 * elementwise_params, kernel_size=3, padding=1, mask_type='A'),
        ElementwiseParams2d(elementwise_params))
    self.eps = 1e-6

    for autoreg_net, order in ((channel_net, 'raster_cwh'), (spatial_net, 'raster_wh')):
        bijection = CensoredLogisticMixtureAutoregressiveBijection2d(
            autoreg_net, num_mixtures=num_mix, num_bins=num_bins,
            autoregressive_order=order)
        self.assert_bijection_is_well_behaved(bijection, x, z_shape=(batch_size, *shape))