def test_surjection_is_well_behaved(self):
    """Exercise VariationalDequantization with a conv-conditioned encoder.

    For each bit depth in (2, 5, 8) a random integer batch with values in
    ``[0, 2**num_bits)`` is drawn, a fresh surjection is built around a
    conditional encoder flow, and both are handed to
    ``assert_surjection_is_well_behaved``.
    """
    batch_size = 10
    shape = [8, 4, 4]
    channels = shape[0]
    batch_shape = (batch_size, *shape)

    for num_bits in (2, 5, 8):
        with self.subTest(num_bits=num_bits):
            x = torch.randint(0, 2**num_bits, batch_shape)

            base_dist = DiagonalNormal(shape)
            # Conditioning network: map integers from [0, 2**num_bits - 1]
            # onto [-1, 1], then a 3x3 conv emitting 2 * channels outputs.
            rescale = LambdaLayer(lambda t: 2 * t.float() / (2**num_bits - 1) - 1)
            conv = nn.Conv2d(channels, 2 * channels, kernel_size=3, padding=1)
            context_net = nn.Sequential(rescale, conv)

            encoder = ConditionalInverseFlow(
                base_dist=base_dist,
                transforms=[ConditionalAffineBijection(context_net), Sigmoid()],
            )
            surjection = VariationalDequantization(encoder, num_bits=num_bits)

            self.assert_surjection_is_well_behaved(
                surjection,
                x,
                z_shape=batch_shape,
                z_dtype=torch.float,
            )
def __init__(self, data_shape, augment_size, num_steps, mid_channels, num_context, num_blocks, dropout,
             num_mixtures, checkerboard=True, tuple_flip=True, coupling_network="transformer"):
    """Assemble the context network and the conditional transforms of the augment flow.

    Args:
        data_shape: Indexable of length >= 3; entry 0 is used as the channel
            count of the context input, entries 1 and 2 are halved to give the
            spatial size of the base distribution (assumed (C, H, W)).
        augment_size: The base distribution has ``augment_size * 4`` channels.
        num_steps: Number of ActNorm / Conv1x1 / coupling steps.
        mid_channels: Hidden width of the context and coupling networks.
        num_context: Number of channels emitted by the context network.
        num_blocks: Forwarded to the coupling layers.
        dropout: Forwarded to the coupling layers.
        num_mixtures: Forwarded to ``ConditionalMixtureCoupling``.
        checkerboard: If true, a ``Checkerboard`` layer is prepended to the
            context network (doubling its input channels) and the flag is
            forwarded to the coupling layers.
        tuple_flip: If true, ``flip`` is enabled on even-numbered steps;
            otherwise it is always ``False``.
        coupling_network: ``"conv"``, ``"densenet"`` or ``"transformer"``.

    Raises:
        ValueError: If a step is built with an unknown ``coupling_network``.
    """
    # --- context network -------------------------------------------------
    context_in = data_shape[0] * 2 if checkerboard else data_shape[0]
    context_layers = [Checkerboard(concat_dim=1)] if checkerboard else []
    context_layers.append(Conv2d(context_in, mid_channels // 2, kernel_size=3, stride=1))
    # Strided conv halves the spatial resolution to match sample_shape below.
    context_layers.append(nn.Conv2d(mid_channels // 2, mid_channels, kernel_size=2, stride=2, padding=0))
    context_layers.append(GatedConvNet(channels=mid_channels, num_blocks=2, dropout=0.0))
    context_layers.append(Conv2dZeros(in_channels=mid_channels, out_channels=num_context))
    context_net = nn.Sequential(*context_layers)

    # --- layer transformations of the augment flow -------------------------
    sample_shape = (augment_size * 4, data_shape[1] // 2, data_shape[2] // 2)
    flow_channels = sample_shape[0]
    transforms = []

    for step in range(num_steps):
        flip = bool(tuple_flip) and step % 2 == 0

        transforms.append(ActNormBijection2d(flow_channels))
        transforms.append(Conv1x1(flow_channels))

        if coupling_network == "transformer":
            coupling = ConditionalMixtureCoupling(in_channels=flow_channels,
                                                  num_context=num_context,
                                                  mid_channels=mid_channels,
                                                  num_mixtures=num_mixtures,
                                                  num_blocks=num_blocks,
                                                  dropout=dropout,
                                                  use_attn=False,
                                                  checkerboard=checkerboard,
                                                  flip=flip)
        elif coupling_network in ("conv", "densenet"):
            # just included for debugging
            coupling = ConditionalCoupling(in_channels=flow_channels,
                                           num_context=num_context,
                                           num_blocks=num_blocks,
                                           mid_channels=mid_channels,
                                           depth=1,
                                           dropout=dropout,
                                           gated_conv=False,
                                           coupling_network=coupling_network,
                                           checkerboard=checkerboard,
                                           flip=flip)
        else:
            raise ValueError(f"Unknown network type {coupling_network}")
        transforms.append(coupling)

    # Final channel shuffle, unsqueeze and sigmoid.
    transforms.append(Conv1x1(flow_channels))
    transforms.append(Unsqueeze2d())
    transforms.append(Sigmoid())

    base_dist = StandardNormal(sample_shape)
    super(AugmentFlow, self).__init__(base_dist=base_dist,
                                      transforms=transforms,
                                      context_init=context_net)
def __init__(self, data_shape, num_bits, num_steps, num_context, num_blocks, mid_channels, depth, growth,
             dropout, gated_conv):
    """Build the dense-block context network and coupling transforms of the dequantization flow.

    Args:
        data_shape: Indexable of length >= 3; entry 0 is the context-network
            input channel count, entries 1 and 2 are halved for the base
            distribution's spatial size (assumed (C, H, W)).
        num_bits: Bit depth; inputs are rescaled by ``2**num_bits - 1``.
        num_steps: Number of Conv1x1 + ConditionalCoupling steps.
        num_context: Number of channels emitted by the context network.
        num_blocks, mid_channels, depth, growth, dropout, gated_conv:
            Forwarded to ``ConditionalCoupling``. ``mid_channels``,
            ``dropout`` and ``gated_conv`` also configure the context
            network, whose dense blocks use a fixed ``depth=4, growth=16``.
    """
    # Context network: rescale integers from [0, 2**num_bits - 1] to [-1, 1],
    # then dense block -> stride-2 conv -> dense block.
    rescale = LambdaLayer(lambda t: 2 * t.float() / (2**num_bits - 1) - 1)
    dense_in = DenseBlock(in_channels=data_shape[0],
                          out_channels=mid_channels,
                          depth=4,
                          growth=16,
                          dropout=dropout,
                          gated_conv=gated_conv,
                          zero_init=False)
    downsample = nn.Conv2d(mid_channels, mid_channels, kernel_size=2, stride=2, padding=0)
    dense_out = DenseBlock(in_channels=mid_channels,
                           out_channels=num_context,
                           depth=4,
                           growth=16,
                           dropout=dropout,
                           gated_conv=gated_conv,
                           zero_init=False)
    context_net = nn.Sequential(rescale, dense_in, downsample, dense_out)

    sample_shape = (data_shape[0] * 4, data_shape[1] // 2, data_shape[2] // 2)
    flow_channels = sample_shape[0]

    transforms = []
    for _ in range(num_steps):
        transforms.append(Conv1x1(flow_channels))
        transforms.append(ConditionalCoupling(in_channels=flow_channels,
                                              num_context=num_context,
                                              num_blocks=num_blocks,
                                              mid_channels=mid_channels,
                                              depth=depth,
                                              growth=growth,
                                              dropout=dropout,
                                              gated_conv=gated_conv))

    # Final channel shuffle, unsqueeze and sigmoid.
    transforms += [Conv1x1(flow_channels), Unsqueeze2d(), Sigmoid()]

    base_dist = ConvNormal2d(sample_shape)
    super(DequantizationFlow, self).__init__(base_dist=base_dist,
                                             transforms=transforms,
                                             context_init=context_net)
def test_range(self):
    """Check that inverting an 8-bit VariationalDequantization stays within [0, 255]."""
    batch_size = 10
    shape = [8]
    features = shape[0]
    z = torch.randn(batch_size, *shape)

    base_dist = DiagonalNormal(shape)
    # Conditioning network: map [0, 255] onto [-1, 1], then a linear layer
    # emitting 2 * features outputs.
    context_net = nn.Sequential(
        LambdaLayer(lambda t: 2 * t.float() / 255 - 1),
        nn.Linear(features, 2 * features),
    )
    encoder = ConditionalInverseFlow(
        base_dist=base_dist,
        transforms=[ConditionalAffineBijection(context_net), Sigmoid()],
    )
    surjection = VariationalDequantization(encoder, num_bits=8)

    x = surjection.inverse(z)

    self.assertTrue(x.min() >= 0)
    self.assertTrue(x.max() <= 255)
def test_pairs(self):
    """Check that paired bijections invert each other.

    For every (bijection, inverse) pair the input is pushed through the
    bijection and then through its claimed inverse; the round trip must
    reproduce the input and the two log-det-Jacobians must be negatives of
    each other.
    """
    batch_size = 10
    shape = [2, 3, 4]
    x_normal = torch.randn(batch_size, *shape)

    # Each entry: (bijection, claimed inverse, input batch, eps).
    bijections = [
        (Sigmoid(), Logit(), x_normal, 1e-5),
        (Softplus(), SoftplusInverse(), x_normal, 1e-5),
    ]

    for bijection, bijection_inv, x, eps in bijections:
        with self.subTest(bijection=bijection):
            # Presumably read by this test case's assertEqual as the
            # comparison tolerance -- confirm against the base class.
            self.eps = eps
            z, ldj = bijection(x)
            xr, ldjr = bijection_inv(z)
            self.assertEqual(x, xr)
            self.assertEqual(ldj, -ldjr)
def test_bijection_is_well_behaved(self):
    """Run the generic bijection checks over every elementwise nonlinearity.

    Each case pairs a bijection with an input batch and an ``eps`` value that
    is stored on the test case before ``assert_bijection_is_well_behaved`` is
    called.
    """
    batch_size = 10
    shape = [2, 3, 4]
    expected_z_shape = (batch_size, *shape)

    x_normal = torch.randn(batch_size, *shape)
    x_uniform = torch.rand(batch_size, *shape)

    # Each entry: (bijection, input batch, eps).
    cases = [
        (LeakyReLU(), x_normal, 1e-6),
        (SneakyReLU(), x_normal, 1e-6),
        (Tanh(), x_normal, 1e-4),
        (Sigmoid(), x_normal, 1e-5),
        # Logit is fed samples drawn from [0, 1).
        (Logit(), x_uniform, 1e-6),
        (Softplus(), x_normal, 1e-6),
        # SoftplusInverse is fed softplus outputs.
        (SoftplusInverse(), F.softplus(x_normal), 1e-6),
    ]

    for bijection, inputs, tolerance in cases:
        with self.subTest(bijection=bijection):
            self.eps = tolerance
            self.assert_bijection_is_well_behaved(bijection, inputs, z_shape=expected_z_shape)
# Flow transforms = [] for _ in range(args.num_flows): net = nn.Sequential(MLP(I, P*O, hidden_units=args.hidden_units, activation=args.activation), ElementwiseParams(P)) if args.affine: transforms.append(AffineCouplingBijection(net, scale_fn=scale_fn(args.scale_fn), num_condition=I)) else: transforms.append(AdditiveCouplingBijection(net, num_condition=I)) if args.actnorm: transforms.append(ActNormBijection(D)) if args.permutation == 'reverse': transforms.append(Reverse(D)) elif args.permutation == 'shuffle': transforms.append(Shuffle(D)) transforms.pop() if args.num_bits is not None: transforms.append(Sigmoid()) transforms.append(VariationalQuantization(decoder, num_bits=args.num_bits)) pi = Flow(base_dist=target, transforms=transforms).to(args.device) p = StandardNormal(shape).to(args.device) ####################### ## Specify optimizer ## ####################### if args.optimizer == 'adam': optimizer = Adam(pi.parameters(), lr=args.lr) elif args.optimizer == 'adamax':
def __init__(self, data_shape, num_bits, num_steps, coupling_network, num_context, num_blocks, mid_channels,
             depth, growth=None, dropout=None, gated_conv=None, num_mixtures=None):
    """Build the context network and conditional transforms of the dequantization flow.

    The context network architecture follows ``coupling_network``; both are
    selected by the same string.

    Args:
        data_shape: Indexable of length >= 3; entry 0 is the context-network
            input channel count, entries 1 and 2 are halved for the base
            distribution's spatial size (assumed (C, H, W)).
        num_bits: Bit depth; inputs are rescaled by ``2**num_bits - 1``.
        num_steps: Number of Conv1x1 + coupling steps.
        coupling_network: ``"conv"``, ``"densenet"`` or ``"transformer"``.
        num_context: Number of channels emitted by the context network.
        num_blocks: Number of attention-style blocks in the transformer
            context network; also forwarded to the coupling layers.
        mid_channels: Hidden width of the context and coupling networks.
        depth, growth, dropout, gated_conv: Forwarded to ``DenseBlock`` (for
            the densenet context network) and to ``ConditionalCoupling``;
            ``dropout`` is also forwarded to ``ConditionalMixtureCoupling``.
        num_mixtures: Forwarded to ``ConditionalMixtureCoupling``.

    Raises:
        ValueError: If ``coupling_network`` is not one of the known types.
    """

    def rescale(x):
        # Map integers from [0, 2**num_bits - 1] onto [-1, 1].
        return 2 * x.float() / (2**num_bits - 1) - 1

    context_network_type = coupling_network

    if context_network_type == "conv":
        context_net = nn.Sequential(
            LambdaLayer(rescale),
            Conv2d(data_shape[0], mid_channels // 2, kernel_size=3, stride=1),
            nn.Conv2d(mid_channels // 2, mid_channels, kernel_size=2, stride=2, padding=0),
            GatedConvNet(channels=mid_channels, num_blocks=2, dropout=0.0),
            Conv2dZeros(in_channels=mid_channels, out_channels=num_context),
        )
    elif context_network_type == "densenet":
        context_net = nn.Sequential(
            LambdaLayer(rescale),
            DenseBlock(in_channels=data_shape[0],
                       out_channels=mid_channels,
                       depth=depth,
                       growth=growth,
                       dropout=dropout,
                       gated_conv=gated_conv,
                       zero_init=False),
            nn.Conv2d(mid_channels, mid_channels, kernel_size=2, stride=2, padding=0),
            DenseBlock(in_channels=mid_channels,
                       out_channels=num_context,
                       depth=depth,
                       growth=growth,
                       dropout=dropout,
                       gated_conv=gated_conv,
                       zero_init=False),
        )
    elif context_network_type == "transformer":
        # Stem, then num_blocks attention-free ConvAttnBlocks, then a
        # projection down to num_context channels.
        context_layers = [
            LambdaLayer(rescale),
            Conv2d(in_channels=data_shape[0], out_channels=mid_channels, kernel_size=3, padding=1),
            nn.Conv2d(mid_channels, mid_channels, kernel_size=2, stride=2, padding=0),
            *(ConvAttnBlock(channels=mid_channels, dropout=0.0, use_attn=False, context_channels=None)
              for _ in range(num_blocks)),
            Conv2d(in_channels=mid_channels, out_channels=num_context, kernel_size=3, padding=1),
        ]
        context_net = nn.Sequential(*context_layers)
    else:
        raise ValueError(
            f"Unknown dequantization context_network_type type: {context_network_type}"
        )

    # --- layer transformations of the dequantization flow ------------------
    sample_shape = (data_shape[0] * 4, data_shape[1] // 2, data_shape[2] // 2)
    flow_channels = sample_shape[0]
    transforms = []

    for _ in range(num_steps):
        # No ActNorm layer is added in this flow.
        transforms.append(Conv1x1(flow_channels))

        if coupling_network == "transformer":
            coupling = ConditionalMixtureCoupling(in_channels=flow_channels,
                                                  num_context=num_context,
                                                  mid_channels=mid_channels,
                                                  num_mixtures=num_mixtures,
                                                  num_blocks=num_blocks,
                                                  dropout=dropout,
                                                  use_attn=False)
        elif coupling_network in ("conv", "densenet"):
            coupling = ConditionalCoupling(in_channels=flow_channels,
                                           num_context=num_context,
                                           num_blocks=num_blocks,
                                           mid_channels=mid_channels,
                                           depth=depth,
                                           growth=growth,
                                           dropout=dropout,
                                           gated_conv=gated_conv,
                                           coupling_network=coupling_network)
        else:
            raise ValueError(
                f"Unknown dequantization coupling network type: {coupling_network}"
            )
        transforms.append(coupling)

    # Final channel shuffle, unsqueeze and sigmoid.
    transforms.append(Conv1x1(flow_channels))
    transforms.append(Unsqueeze2d())
    transforms.append(Sigmoid())

    base_dist = ConvNormal2d(sample_shape)
    super(DequantizationFlow, self).__init__(base_dist=base_dist,
                                             transforms=transforms,
                                             context_init=context_net)