    def test_surjection_is_well_behaved(self):
        batch_size = 10
        shape = [8, 4, 4]
        num_bits_list = [2, 5, 8]

        for num_bits in num_bits_list:
            with self.subTest(num_bits=num_bits):
                x = torch.randint(0, 2**num_bits,
                                  (batch_size, ) + torch.Size(shape))
                encoder = ConditionalInverseFlow(
                    base_dist=DiagonalNormal(shape),
                    transforms=[
                        ConditionalAffineBijection(
                            nn.Sequential(
                                LambdaLayer(lambda x: 2 * x.float() /
                                            (2**num_bits - 1) - 1),
                                nn.Conv2d(shape[0],
                                          2 * shape[0],
                                          kernel_size=3,
                                          padding=1))),
                        Sigmoid()
                    ])
                surjection = VariationalDequantization(encoder,
                                                       num_bits=num_bits)
                self.assert_surjection_is_well_behaved(surjection,
                                                       x,
                                                       z_shape=(batch_size,
                                                                *shape),
                                                       z_dtype=torch.float)
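
# For contrast with the variational encoder tested above, uniform
# dequantization (the standard non-variational baseline) needs only plain
# torch. A minimal sketch; uniform_dequantize is an illustrative helper,
# not part of the library under test:
import torch

def uniform_dequantize(x, num_bits=8):
    """Map integers in [0, 2**num_bits - 1] to floats in [0, 1)."""
    return (x.float() + torch.rand_like(x, dtype=torch.float)) / 2**num_bits

x = torch.randint(0, 256, (10, 8, 4, 4))
u = uniform_dequantize(x)
assert u.min() >= 0 and u.max() < 1                 # noise stays inside the bin
assert torch.equal(u.mul(2**8).floor().long(), x)   # quantizing recovers x exactly
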
# Example #2
    def __init__(self, data_shape, augment_size, num_steps, mid_channels,
                 num_context, num_blocks, dropout, num_mixtures,
                 checkerboard=True, tuple_flip=True,
                 coupling_network="transformer"):

        context_in = data_shape[0] * 2 if checkerboard else data_shape[0]
        layers = []
        if checkerboard: layers.append(Checkerboard(concat_dim=1))
        layers += [Conv2d(context_in, mid_channels // 2, kernel_size=3, stride=1),
                   nn.Conv2d(mid_channels // 2, mid_channels, kernel_size=2, stride=2, padding=0),
                   GatedConvNet(channels=mid_channels, num_blocks=2, dropout=0.0),
                   Conv2dZeros(in_channels=mid_channels, out_channels=num_context)]
        context_net = nn.Sequential(*layers)
        
        # layer transformations of the augment flow
        transforms = []
        sample_shape = (augment_size * 4, data_shape[1] // 2, data_shape[2] // 2)
        for i in range(num_steps):
            flip = (i % 2 == 0) if tuple_flip else False
            transforms.append(ActNormBijection2d(sample_shape[0]))
            transforms.append(Conv1x1(sample_shape[0]))
            if coupling_network in ["conv", "densenet"]:
                # just included for debugging
                transforms.append(
                    ConditionalCoupling(in_channels=sample_shape[0],
                                        num_context=num_context,
                                        num_blocks=num_blocks,
                                        mid_channels=mid_channels,
                                        depth=1,
                                        dropout=dropout,
                                        gated_conv=False,
                                        coupling_network=coupling_network,
                                        checkerboard=checkerboard,
                                        flip=flip))
            
            elif coupling_network == "transformer":
                transforms.append(
                    ConditionalMixtureCoupling(in_channels=sample_shape[0],
                                               num_context=num_context,
                                               mid_channels=mid_channels,
                                               num_mixtures=num_mixtures,
                                               num_blocks=num_blocks,
                                               dropout=dropout,
                                               use_attn=False,
                                               checkerboard=checkerboard,
                                               flip=flip))
            else:
                raise ValueError(f"Unknown network type {coupling_network}")

        # Final shuffle of channels, squeeze and sigmoid
        transforms.extend([Conv1x1(sample_shape[0]),
                           Unsqueeze2d(),
                           Sigmoid()])
        super(AugmentFlow, self).__init__(base_dist=StandardNormal(sample_shape), # ConvNormal2d(sample_shape),
                                          transforms=transforms,
                                          context_init=context_net)
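
# Shape sanity check for the constructor above: the strided conv in
# context_net (kernel_size=2, stride=2, padding=0) halves H and W, which is
# why sample_shape is (augment_size * 4, H // 2, W // 2). Plain-torch sketch
# with illustrative sizes:
import torch
import torch.nn as nn

mid_channels = 64
h = torch.randn(1, mid_channels // 2, 16, 16)
down = nn.Conv2d(mid_channels // 2, mid_channels, kernel_size=2, stride=2, padding=0)
print(down(h).shape)  # torch.Size([1, 64, 8, 8]) -- spatial dims halved
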
    def __init__(self, data_shape, num_bits, num_steps, num_context,
                 num_blocks, mid_channels, depth, growth, dropout, gated_conv):

        context_net = nn.Sequential(LambdaLayer(lambda x: 2*x.float()/(2**num_bits-1)-1),
                                    DenseBlock(in_channels=data_shape[0],
                                               out_channels=mid_channels,
                                               depth=4,
                                               growth=16,
                                               dropout=dropout,
                                               gated_conv=gated_conv,
                                               zero_init=False),
                                    nn.Conv2d(mid_channels, mid_channels, kernel_size=2, stride=2, padding=0),
                                    DenseBlock(in_channels=mid_channels,
                                               out_channels=num_context,
                                               depth=4,
                                               growth=16,
                                               dropout=dropout,
                                               gated_conv=gated_conv,
                                               zero_init=False))

        transforms = []
        sample_shape = (data_shape[0] * 4, data_shape[1] // 2, data_shape[2] // 2)
        for i in range(num_steps):
            transforms.extend([
                Conv1x1(sample_shape[0]),
                ConditionalCoupling(in_channels=sample_shape[0],
                                    num_context=num_context,
                                    num_blocks=num_blocks,
                                    mid_channels=mid_channels,
                                    depth=depth,
                                    growth=growth,
                                    dropout=dropout,
                                    gated_conv=gated_conv)
            ])

        # Final shuffle of channels, squeeze and sigmoid
        transforms.extend([Conv1x1(sample_shape[0]),
                           Unsqueeze2d(),
                           Sigmoid()
                          ])

        super(DequantizationFlow, self).__init__(base_dist=ConvNormal2d(sample_shape),
                                                 transforms=transforms,
                                                 context_init=context_net)
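
# The LambdaLayer at the head of context_net rescales integer pixels from
# [0, 2**num_bits - 1] to [-1, 1] before the DenseBlocks see them. Quick
# plain-torch check of that mapping:
import torch

num_bits = 8
x = torch.tensor([0, 127, 255])
print(2 * x.float() / (2**num_bits - 1) - 1)  # tensor([-1.0000, -0.0039,  1.0000])
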
    def test_range(self):
        batch_size = 10
        shape = [8]

        z = torch.randn(batch_size, *shape)
        encoder = ConditionalInverseFlow(
            base_dist=DiagonalNormal(shape),
            transforms=[
                ConditionalAffineBijection(
                    nn.Sequential(
                        LambdaLayer(lambda x: 2 * x.float() / 255 - 1),
                        nn.Linear(shape[0], 2 * shape[0]))),
                Sigmoid()
            ])
        surjection = VariationalDequantization(encoder, num_bits=8)
        x = surjection.inverse(z)
        self.assertTrue(x.min() >= 0)
        self.assertTrue(x.max() <= 255)
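
# The encoder's final Sigmoid squashes arbitrary Gaussian inputs into the
# unit interval, which is what keeps the recovered x within [0, 255] above.
# Minimal plain-torch check (values stay strictly inside (0, 1) up to float
# rounding at the saturated tails):
import torch

z = 10 * torch.randn(10, 8)
u = torch.sigmoid(z)
assert u.min() >= 0 and u.max() <= 1
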
    def test_pairs(self):
        batch_size = 10
        shape = [2, 3, 4]
        x_normal = torch.randn(batch_size, *shape)
        x_uniform = torch.rand(batch_size, *shape)
        bijections = [
            (Sigmoid(), Logit(), x_normal, 1e-5),
            (Softplus(), SoftplusInverse(), x_normal, 1e-5),
        ]

        for bijection, bijection_inv, x, eps in bijections:
            with self.subTest(bijection=bijection):
                self.eps = eps
                z, ldj = bijection(x)
                xr, ldjr = bijection_inv(z)
                self.assertEqual(x, xr)
                self.assertEqual(ldj, -ldjr)
    def test_bijection_is_well_behaved(self):
        batch_size = 10
        shape = [2, 3, 4]
        x_normal = torch.randn(batch_size, *shape)
        x_uniform = torch.rand(batch_size, *shape)
        bijections = [
            (LeakyReLU(), x_normal, 1e-6),
            (SneakyReLU(), x_normal, 1e-6),
            (Tanh(), x_normal, 1e-4),
            (Sigmoid(), x_normal, 1e-5),
            (Logit(), x_uniform, 1e-6),
            (Softplus(), x_normal, 1e-6),
            (SoftplusInverse(), F.softplus(x_normal), 1e-6),
        ]

        for bijection, x, eps in bijections:
            with self.subTest(bijection=bijection):
                self.eps = eps
                self.assert_bijection_is_well_behaved(bijection, x, z_shape=(batch_size, *shape))
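
# What "well behaved" means concretely for an elementwise bijection: the
# log-det-Jacobian it reports must match the true Jacobian. For Sigmoid,
# ldj(x) = sum(log s(x) + log(1 - s(x))). Plain-torch check against
# autograd, independent of the library's assertion helper:
import torch

x = torch.randn(5)
s = torch.sigmoid(x)
ldj = (torch.log(s) + torch.log(1 - s)).sum()
jac = torch.autograd.functional.jacobian(torch.sigmoid, x)
assert torch.allclose(ldj, torch.logdet(jac), atol=1e-5)
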
# Example #7
# Flow
transforms = []
for _ in range(args.num_flows):
    net = nn.Sequential(MLP(I, P*O,
                            hidden_units=args.hidden_units,
                            activation=args.activation),
                        ElementwiseParams(P))
    if args.affine: transforms.append(AffineCouplingBijection(net, scale_fn=scale_fn(args.scale_fn), num_condition=I))
    else:           transforms.append(AdditiveCouplingBijection(net, num_condition=I))
    if args.actnorm: transforms.append(ActNormBijection(D))
    if args.permutation == 'reverse':   transforms.append(Reverse(D))
    elif args.permutation == 'shuffle': transforms.append(Shuffle(D))
transforms.pop()
if args.num_bits is not None:
    transforms.append(Sigmoid())
    transforms.append(VariationalQuantization(decoder, num_bits=args.num_bits))


pi = Flow(base_dist=target,
          transforms=transforms).to(args.device)

p = StandardNormal(shape).to(args.device)

#######################
## Specify optimizer ##
#######################

if args.optimizer == 'adam':
    optimizer = Adam(pi.parameters(), lr=args.lr)
elif args.optimizer == 'adamax':
    # completes the branch truncated in the original snippet; Adamax is
    # assumed to be torch.optim.Adamax, matching Adam above
    optimizer = Adamax(pi.parameters(), lr=args.lr)
# Example #8
    def __init__(self,
                 data_shape,
                 num_bits,
                 num_steps,
                 coupling_network,
                 num_context,
                 num_blocks,
                 mid_channels,
                 depth,
                 growth=None,
                 dropout=None,
                 gated_conv=None,
                 num_mixtures=None):

        #context_network_type = "conv"
        context_network_type = coupling_network

        if context_network_type == "densenet":
            context_net = nn.Sequential(
                LambdaLayer(lambda x: 2 * x.float() / (2**num_bits - 1) - 1),
                DenseBlock(in_channels=data_shape[0],
                           out_channels=mid_channels,
                           depth=depth,
                           growth=growth,
                           dropout=dropout,
                           gated_conv=gated_conv,
                           zero_init=False),
                nn.Conv2d(mid_channels,
                          mid_channels,
                          kernel_size=2,
                          stride=2,
                          padding=0),
                DenseBlock(in_channels=mid_channels,
                           out_channels=num_context,
                           depth=depth,
                           growth=growth,
                           dropout=dropout,
                           gated_conv=gated_conv,
                           zero_init=False))

        elif context_network_type == "transformer":
            layers = [
                LambdaLayer(lambda x: 2 * x.float() / (2**num_bits - 1) - 1),
                Conv2d(in_channels=data_shape[0],
                       out_channels=mid_channels,
                       kernel_size=3,
                       padding=1),
                nn.Conv2d(mid_channels,
                          mid_channels,
                          kernel_size=2,
                          stride=2,
                          padding=0)
            ]
            for i in range(num_blocks):
                layers.append(
                    ConvAttnBlock(channels=mid_channels,
                                  dropout=0.0,
                                  use_attn=False,
                                  context_channels=None))
            layers.append(
                Conv2d(in_channels=mid_channels,
                       out_channels=num_context,
                       kernel_size=3,
                       padding=1))
            context_net = nn.Sequential(*layers)

        elif context_network_type == "conv":
            context_net = nn.Sequential(
                LambdaLayer(lambda x: 2 * x.float() / (2**num_bits - 1) - 1),
                Conv2d(data_shape[0],
                       mid_channels // 2,
                       kernel_size=3,
                       stride=1),
                nn.Conv2d(mid_channels // 2,
                          mid_channels,
                          kernel_size=2,
                          stride=2,
                          padding=0),
                GatedConvNet(channels=mid_channels, num_blocks=2, dropout=0.0),
                Conv2dZeros(in_channels=mid_channels,
                            out_channels=num_context))
        else:
            raise ValueError(
                f"Unknown dequantization context network type: {context_network_type}"
            )

        # layer transformations of the dequantization flow
        transforms = []
        sample_shape = (data_shape[0] * 4, data_shape[1] // 2,
                        data_shape[2] // 2)
        for i in range(num_steps):
            #transforms.append(ActNormBijection2d(sample_shape[0]))
            transforms.append(Conv1x1(sample_shape[0]))

            if coupling_network in ["conv", "densenet"]:
                transforms.append(
                    ConditionalCoupling(in_channels=sample_shape[0],
                                        num_context=num_context,
                                        num_blocks=num_blocks,
                                        mid_channels=mid_channels,
                                        depth=depth,
                                        growth=growth,
                                        dropout=dropout,
                                        gated_conv=gated_conv,
                                        coupling_network=coupling_network))
            elif coupling_network == "transformer":
                transforms.append(
                    ConditionalMixtureCoupling(in_channels=sample_shape[0],
                                               num_context=num_context,
                                               mid_channels=mid_channels,
                                               num_mixtures=num_mixtures,
                                               num_blocks=num_blocks,
                                               dropout=dropout,
                                               use_attn=False))
            else:
                raise ValueError(
                    f"Unknown dequantization coupling network type: {coupling_network}"
                )

        # Final shuffle of channels, squeeze and sigmoid
        transforms.extend([Conv1x1(sample_shape[0]), Unsqueeze2d(), Sigmoid()])
        super(DequantizationFlow, self).__init__(base_dist=ConvNormal2d(sample_shape),
                                                 transforms=transforms,
                                                 context_init=context_net)
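
# Why sample_shape is (C * 4, H // 2, W // 2) in the flows above: the
# transforms run in squeezed (space-to-depth) coordinates, and the final
# Unsqueeze2d undoes the squeeze. A plain-torch sketch of that reshaping
# (channel ordering may differ from the library's own squeeze layer):
import torch

def squeeze2d(x):
    """Trade each 2x2 spatial block for 4x the channels (space-to-depth)."""
    b, c, h, w = x.shape
    x = x.view(b, c, h // 2, 2, w // 2, 2)
    x = x.permute(0, 1, 3, 5, 2, 4).contiguous()
    return x.view(b, c * 4, h // 2, w // 2)

x = torch.randn(1, 3, 32, 32)
print(squeeze2d(x).shape)  # torch.Size([1, 12, 16, 16])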