Example #1
0
    def test_surjection_is_well_behaved(self):
        """Run the shared surjection checks over a range of Augment setups.

        Every case pairs an ``Augment`` instance with the per-sample shape
        expected for ``z``: the input shape grown by the noise size along
        ``split_dim``.
        """
        num_samples = 10
        x_shape = [8, 4, 4]
        x = torch.randn(num_samples, *x_shape)

        # Unconditional noise: (noise shape, x_size, split_dim, expected z shape)
        unconditional = (
            ([1, 4, 4], 8, 1, (9, 4, 4)),
            ([4, 4, 4], 8, 1, (12, 4, 4)),
            ([7, 4, 4], 8, 1, (15, 4, 4)),
            ([8, 2, 4], 4, 2, (8, 6, 4)),
            ([8, 4, 2], 4, 3, (8, 4, 6)),
        )
        cases = [
            (Augment(StandardNormal(noise_shape), x_size=x_size, split_dim=dim),
             z_shape)
            for noise_shape, x_size, dim, z_shape in unconditional
        ]

        # Noise whose mean is predicted from x by a 1x1 convolution; the
        # number of output channels is the number of channels added to x.
        cases += [
            (Augment(ConditionalMeanNormal(nn.Conv2d(8, extra, kernel_size=1)),
                     x_size=8,
                     split_dim=1),
             (8 + extra, 4, 4))
            for extra in (1, 4, 7)
        ]

        for surjection, new_shape in cases:
            with self.subTest(surjection=surjection):
                expected_z_shape = (num_samples, *new_shape)
                self.assert_surjection_is_well_behaved(
                    surjection,
                    x,
                    z_shape=expected_z_shape,
                    z_dtype=x.dtype)
Example #2
0
    def __init__(self, num_flows, actnorm, affine, scale_fn_str, hidden_units,
                 activation, range_flow, augment_size, base_dist, cond_size):
        """Assemble a conditional coupling flow over 2-dimensional data.

        Args:
            num_flows: Number of coupling layers to stack.
            actnorm: If truthy, an ``ActNormBijection`` follows each coupling.
            affine: If truthy use affine couplings, otherwise additive ones.
            scale_fn_str: Name handed to ``scale_fn`` (affine couplings only).
            hidden_units: Hidden layer sizes forwarded to every ``MLP``.
            activation: Activation forwarded to every ``MLP``.
            range_flow: ``'logit'`` or ``'softplus'``; only consulted when
                ``augment_size`` is 0. Any other value adds no range flow.
            augment_size: Number of noise dimensions appended by ``Augment``.
                Must be even when positive; 0 (or less) selects the
                ``SimpleAbsSurjection`` path instead.
            base_dist: ``"uniform"`` selects ``StandardUniform``; every other
                value selects ``StandardNormal``.
            cond_size: Input size of the context ``MLP``.
        """
        D = 2  # Number of data dimensions
        A = D + augment_size  # Number of augmented data dimensions
        P = 2 if affine else 1  # Number of elementwise parameters

        # Context network: maps the conditioning input to D features. The
        # coupling nets below take A // 2 + D inputs, i.e. presumably half of
        # the flow state plus these D context features.
        context_init = MLP(input_size=cond_size,
                           output_size=D,
                           hidden_units=hidden_units,
                           activation=activation)

        # initialize flow with either augmentation or Abs surjection
        if augment_size > 0:
            assert augment_size % 2 == 0
            transforms = [Augment(StandardNormal((augment_size, )), x_size=D)]
        else:
            transforms = [SimpleAbsSurjection()]
            if range_flow == 'logit':
                transforms += [
                    ScaleBijection(scale=torch.tensor([[1 / 4, 1 / 4]])),
                    Logit()
                ]
            elif range_flow == 'softplus':
                transforms += [SoftplusInverse()]

        # apply coupling layer flows
        for _ in range(num_flows):
            net = nn.Sequential(
                MLP(A // 2 + D,
                    P * A // 2,
                    hidden_units=hidden_units,
                    activation=activation), ElementwiseParams(P))
            if affine:
                transforms.append(
                    ConditionalAffineCouplingBijection(
                        net, scale_fn=scale_fn(scale_fn_str)))
            else:
                transforms.append(ConditionalAdditiveCouplingBijection(net))
            # The flow state has A features at this point (A == D only when
            # no augmentation is used), so the normalisation must be sized A.
            if actnorm: transforms.append(ActNormBijection(A))
            transforms.append(Reverse(A))

        # Drop the trailing Reverse. Without any coupling layer there is no
        # Reverse to drop, and popping would remove the initial surjection.
        if num_flows > 0:
            transforms.pop()

        if base_dist == "uniform":
            base = StandardUniform((A, ))
        else:
            base = StandardNormal((A, ))

        super(SRFlow, self).__init__(base_dist=base,
                                     transforms=transforms,
                                     context_init=context_init)
Example #3
0
    def test_matching_torchdist(self):
        """StandardNormal.log_prob must equal torch's Normal(0, 1) summed per sample."""
        num_samples = 10
        event_shape = [3, 2, 2]
        x = torch.randn(num_samples, *event_shape)

        ours = StandardNormal(event_shape)
        reference = torch.distributions.Normal(loc=torch.zeros(event_shape),
                                               scale=torch.ones(event_shape))

        actual = ours.log_prob(x)
        # torch returns elementwise log-densities: flatten each sample and
        # add them up to obtain one joint log-probability per sample.
        elementwise = reference.log_prob(x)
        expected = elementwise.reshape(num_samples, -1).sum(dim=1)

        self.assertEqual(actual, expected)
Example #4
0
def reduction_layer(channels, items):
    """Return the transforms of one reduction stage as a flat list.

    The stage is three ``perm_norm_bi(channels)`` groups, a ``Squeeze2d(4)``
    and a ``Slice`` that keeps ``channels * 2`` entries and models the rest
    with a ``StandardNormal`` of shape ``(channels * 2, items)``.
    """
    stage = []
    for _ in range(3):
        stage.extend(perm_norm_bi(channels))

    stage.append(Squeeze2d(4))

    kept = channels * 2
    noise_dist = StandardNormal((kept, items))
    stage.append(Slice(noise_dist, num_keep=kept))
    return stage
Example #5
0
    def test_surjection_is_well_behaved(self):
        """Run the shared surjection checks over a range of Slice setups.

        Every case pairs a ``Slice`` instance with the per-sample shape that
        remains after keeping ``num_keep`` entries along ``dim``.
        """
        num_samples = 10
        x_shape = [8, 4, 4]
        x = torch.randn(num_samples, *x_shape)

        # Unconditional decoder: (noise shape, num_keep, dim, expected z shape)
        unconditional = (
            ([1, 4, 4], 7, 1, (7, 4, 4)),
            ([4, 4, 4], 4, 1, (4, 4, 4)),
            ([7, 4, 4], 1, 1, (1, 4, 4)),
            ([8, 2, 4], 2, 2, (8, 2, 4)),
            ([8, 4, 2], 2, 3, (8, 4, 2)),
        )
        cases = [
            (Slice(StandardNormal(noise_shape), num_keep=keep, dim=dim), z_shape)
            for noise_shape, keep, dim, z_shape in unconditional
        ]

        # Conditional decoder: a 1x1 convolution maps the kept channels to
        # the mean of the 8 - keep channels that were sliced off.
        cases += [
            (Slice(ConditionalMeanNormal(nn.Conv2d(keep, 8 - keep, kernel_size=1)),
                   num_keep=keep,
                   dim=1),
             (keep, 4, 4))
            for keep in (7, 4, 1)
        ]

        for surjection, new_shape in cases:
            with self.subTest(surjection=surjection):
                expected_z_shape = (num_samples, *new_shape)
                self.assert_surjection_is_well_behaved(surjection,
                                                       x,
                                                       z_shape=expected_z_shape,
                                                       z_dtype=x.dtype)
Example #6
0
    def test_distribution_is_well_behaved(self):
        """Run the shared distribution checks on a StandardNormal of shape (3, 2, 2)."""
        num_samples = 10
        event_shape = [3, 2, 2]
        x = torch.randn(num_samples, *event_shape)

        # Samples are expected to carry the batch axis in front of the event shape.
        expected_shape = (num_samples, *event_shape)
        self.assert_distribution_is_well_behaved(StandardNormal(event_shape),
                                                 x,
                                                 expected_shape=expected_shape)
Example #7
0
    def __init__(self, pretrained_model, latent_size):
        """Wrap a pretrained flow with two compressive VAE layers.

        The first 28 transforms of ``pretrained_model`` are reused, the
        transform at index 28 is replaced by a convolutional VAE followed by
        a reshape, the remaining pretrained transforms are appended, and a
        final MLP-based VAE maps the result to ``latent_size`` dimensions.

        Args:
            pretrained_model: Model exposing ``transforms`` and
                ``base_dist.shape`` (indexed at positions 0, 1 and 2).
            latent_size: Size of the final latent vector.
        """
        flow_shape = pretrained_model.base_dist.shape
        self.flow_shape = flow_shape
        self.latent_size = latent_size

        # Start from the first scale of the pretrained model.
        transforms = pretrained_model.transforms[:28]

        # A compression flow takes the place of the pretrained slice layer.
        mid_encoder_net = ConvEncoderNet(in_channels=48,
                                         out_channels=768,
                                         mid_channels=[64, 128, 256],
                                         max_pool=True,
                                         batch_norm=True)
        mid_encoder = ConditionalNormal(mid_encoder_net, split_dim=1)

        mid_decoder_net = ConvDecoderNet(
            in_channels=768,
            out_shape=(48 * 2, 8, 8),
            mid_channels=[256, 128, 64],
            batch_norm=True,
            in_lambda=lambda z: z.view(z.shape[0], z.shape[1], 1, 1))
        mid_decoder = ConditionalNormal(mid_decoder_net, split_dim=1)

        transforms.append(VAE(encoder=mid_encoder, decoder=mid_decoder))
        transforms.append(Reshape(input_shape=(768, ), output_shape=(12, 8, 8)))

        # Final, non-dimension-preserving VAE working on the flattened output
        # of the pretrained flow.
        flat_dim = flow_shape[0] * flow_shape[1] * flow_shape[2]

        final_encoder_net = MLP(flat_dim,
                                2 * latent_size,
                                hidden_units=[512, 256],
                                activation='relu',
                                in_lambda=lambda t: t.view(t.shape[0], flat_dim))
        final_encoder = ConditionalNormal(final_encoder_net)

        final_decoder_net = MLP(
            latent_size,
            2 * flat_dim,
            hidden_units=[256, 512],
            activation='relu',
            out_lambda=lambda t: t.view(t.shape[0], flow_shape[0] * 2,
                                        flow_shape[1], flow_shape[2]))
        final_decoder = ConditionalNormal(final_decoder_net, split_dim=1)

        # Remaining pretrained transforms (index 28 is skipped), then the
        # compressive VAE on top.
        transforms.extend(pretrained_model.transforms[29:])
        transforms.append(VAE(encoder=final_encoder, decoder=final_decoder))

        # Base distribution over the compressed latent vector.
        latent_base = StandardNormal((latent_size, ))

        super(CompressPretrained, self).__init__(base_dist=[None, latent_base],
                                                 transforms=transforms)
Example #8
0
    def __init__(self, data_shape, augment_size, num_steps,
                 mid_channels, num_context, num_blocks, dropout, num_mixtures, checkerboard=True, tuple_flip=True, coupling_network="transformer"):
        """Build the conditional flow that generates augmentation noise.

        Args:
            data_shape: Shape of the data being augmented, indexed at 0, 1
                and 2 (presumably channels, height, width).
            augment_size: Number of augmentation channels; the flow itself
                works on ``augment_size * 4`` channels at half resolution.
            num_steps: Number of ActNorm / Conv1x1 / coupling steps.
            mid_channels: Width of the context and coupling networks.
            num_context: Number of channels produced by the context network.
            num_blocks: Blocks per coupling network.
            dropout: Dropout rate of the coupling networks.
            num_mixtures: Mixture components ("transformer" couplings only).
            checkerboard: Whether to use the checkerboard layout.
            tuple_flip: If truthy, ``flip`` is on for even-numbered steps.
            coupling_network: "conv", "densenet" or "transformer".

        Raises:
            ValueError: If ``coupling_network`` is unknown and at least one
                step is built.
        """
        # Context network: encodes the data into num_context feature maps at
        # half the spatial resolution (the second convolution has stride 2).
        context_layers = [Checkerboard(concat_dim=1)] if checkerboard else []
        context_in = data_shape[0] * 2 if checkerboard else data_shape[0]
        context_layers.append(Conv2d(context_in, mid_channels // 2, kernel_size=3, stride=1))
        context_layers.append(nn.Conv2d(mid_channels // 2, mid_channels, kernel_size=2, stride=2, padding=0))
        context_layers.append(GatedConvNet(channels=mid_channels, num_blocks=2, dropout=0.0))
        context_layers.append(Conv2dZeros(in_channels=mid_channels, out_channels=num_context))
        context_net = nn.Sequential(*context_layers)

        # Layer transformations of the augment flow.
        sample_shape = (augment_size * 4, data_shape[1] // 2, data_shape[2] // 2)
        channels = sample_shape[0]
        transforms = []
        for step in range(num_steps):
            flip = bool(tuple_flip) and step % 2 == 0

            transforms.append(ActNormBijection2d(channels))
            transforms.append(Conv1x1(channels))

            if coupling_network == "transformer":
                coupling = ConditionalMixtureCoupling(in_channels=channels,
                                                      num_context=num_context,
                                                      mid_channels=mid_channels,
                                                      num_mixtures=num_mixtures,
                                                      num_blocks=num_blocks,
                                                      dropout=dropout,
                                                      use_attn=False,
                                                      checkerboard=checkerboard,
                                                      flip=flip)
            elif coupling_network in ("conv", "densenet"):
                # just included for debugging
                coupling = ConditionalCoupling(in_channels=channels,
                                               num_context=num_context,
                                               num_blocks=num_blocks,
                                               mid_channels=mid_channels,
                                               depth=1,
                                               dropout=dropout,
                                               gated_conv=False,
                                               coupling_network=coupling_network,
                                               checkerboard=checkerboard,
                                               flip=flip)
            else:
                raise ValueError(f"Unknown network type {coupling_network}")
            transforms.append(coupling)

        # Final shuffle of channels, unsqueeze and sigmoid.
        transforms.append(Conv1x1(channels))
        transforms.append(Unsqueeze2d())
        transforms.append(Sigmoid())

        base = StandardNormal(sample_shape)
        super(AugmentFlow, self).__init__(base_dist=base,
                                          transforms=transforms,
                                          context_init=context_net)
    def test_data_parallel(self):
        """A DataParallelDistribution wrapper must return tensors from all three entry points."""
        num_samples = 12
        event_shape = [2, 3, 4]
        x = torch.rand([num_samples] + event_shape)
        wrapped = DataParallelDistribution(StandardNormal(event_shape))

        log_prob = wrapped.log_prob(x)
        samples = wrapped.sample(num_samples)
        samples2, log_prob2 = wrapped.sample_with_log_prob(num_samples)

        # Only the result types are checked here, not the values.
        for result in (log_prob, samples, log_prob2, samples2):
            self.assertIsInstance(result, torch.Tensor)
Example #10
0
    def __init__(self, data_shape, num_bits,
                 base_distribution, num_scales, num_steps, actnorm,
                 vae_hidden_units,
                 coupling_network,
                 dequant, dequant_steps, dequant_context,
                 coupling_blocks, coupling_channels, coupling_dropout,
                 coupling_gated_conv=None, coupling_depth=None, coupling_mixtures=None):
        """Build a multi-scale flow in which every scale ends with a compressive VAE.

        Each scale squeezes the input, applies ``num_steps`` bijective steps
        (optional ActNorm, Conv1x1, coupling) and then a VAE that halves the
        number of dimensions.

        Args:
            data_shape: Input shape, indexed at 0, 1 and 2 (presumably
                channels, height, width).
            num_bits: Bit depth handed to the dequantization layers.
            base_distribution: Single-character code: ``"n"`` normal,
                ``"c"`` ConvNormal2d, ``"u"`` uniform.
            num_scales: Number of scales; treated as 0 when ``num_steps`` is 0.
            num_steps: Number of bijective steps per scale.
            actnorm: If truthy, add ActNorm layers.
            vae_hidden_units: Channel sizes of the VAE encoder; the decoder
                uses them in reverse order.
            coupling_network: ``"conv"`` selects ``Coupling``; any other value
                selects ``MixtureCoupling``.
            dequant: ``'uniform'`` or ``'flow'``. Any other value adds no
                dequantization layer.
            dequant_steps: Steps of the dequantization flow.
            dequant_context: Context channels of the dequantization flow.
            coupling_blocks: Blocks per coupling network.
            coupling_channels: Width of the coupling networks.
            coupling_dropout: Dropout of the coupling networks.
            coupling_gated_conv: Forwarded to the "conv" couplings and the
                dequantization flow.
            coupling_depth: Forwarded to the "conv" couplings and the
                dequantization flow.
            coupling_mixtures: Mixture components of ``MixtureCoupling``.

        Raises:
            ValueError: If no scale is built (``num_scales`` or ``num_steps``
                is 0), or if ``base_distribution`` is not one of the codes
                above.
        """
        assert len(base_distribution) == 1, "Only a single base distribution is supported"
        transforms = []
        current_shape = data_shape
        if num_steps == 0: num_scales = 0

        if dequant == 'uniform' or num_steps == 0 or num_scales == 0:
            # no bijective flows defaults to only using uniform dequantization
            transforms.append(UniformDequantization(num_bits=num_bits))
        elif dequant == 'flow':
            dequantize_flow = DequantizationFlow(data_shape=data_shape,
                                                 num_bits=num_bits,
                                                 num_steps=dequant_steps,
                                                 coupling_network=coupling_network,
                                                 num_context=dequant_context,
                                                 num_blocks=coupling_blocks,
                                                 mid_channels=coupling_channels,
                                                 depth=coupling_depth,
                                                 dropout=coupling_dropout,
                                                 gated_conv=coupling_gated_conv,
                                                 num_mixtures=coupling_mixtures)
            transforms.append(VariationalDequantization(encoder=dequantize_flow, num_bits=num_bits))

        # Change range from [0,1]^D to [-0.5, 0.5]^D
        transforms.append(ScalarAffineBijection(shift=-0.5))

        # Only bound once a scale has been built; checked after the loop so a
        # configuration without scales fails with a clear error.
        latent_size = None

        for scale in range(num_scales):

            # squeeze to exchange height and width for more channels
            transforms.append(Squeeze2d())
            current_shape = (current_shape[0] * 4,
                             current_shape[1] // 2,
                             current_shape[2] // 2)

            # Dimension preserving components
            for step in range(num_steps):
                if actnorm: transforms.append(ActNormBijection2d(current_shape[0]))
                transforms.append(Conv1x1(current_shape[0]))
                if coupling_network == "conv":
                    transforms.append(
                        Coupling(in_channels=current_shape[0],
                                 num_blocks=coupling_blocks,
                                 mid_channels=coupling_channels,
                                 depth=coupling_depth,
                                 dropout=coupling_dropout,
                                 gated_conv=coupling_gated_conv,
                                 coupling_network=coupling_network))
                else:
                    transforms.append(
                        MixtureCoupling(in_channels=current_shape[0],
                                        mid_channels=coupling_channels,
                                        num_mixtures=coupling_mixtures,
                                        num_blocks=coupling_blocks,
                                        dropout=coupling_dropout))

            # Non-dimension preserving flows: reduce the dimensionality of data by 2 (channel-wise)
            if actnorm: transforms.append(ActNormBijection2d(current_shape[0]))
            # The check is on the channel count, so the message reports it.
            assert current_shape[0] % 2 == 0, f"Number of channels {current_shape[0]} must be divisible by two"
            latent_size = (current_shape[0] * current_shape[1] * current_shape[2]) // 2

            encoder = ConditionalNormal(
                ConvEncoderNet(in_channels=current_shape[0],
                               out_channels=latent_size,
                               mid_channels=vae_hidden_units,
                               max_pool=True, batch_norm=True),
                split_dim=1)
            decoder = ConditionalNormal(
                ConvDecoderNet(in_channels=latent_size,
                               out_shape=(current_shape[0] * 2, current_shape[1], current_shape[2]),
                               mid_channels=list(reversed(vae_hidden_units)),
                               batch_norm=True,
                               in_lambda=lambda x: x.view(x.shape[0], x.shape[1], 1, 1)),
                split_dim=1)

            transforms.append(VAE(encoder=encoder, decoder=decoder))
            current_shape = (current_shape[0] // 2,
                             current_shape[1],
                             current_shape[2])

            if scale < num_scales - 1:
                # reshape latent sample to have height and width
                transforms.append(Reshape(input_shape=(latent_size,), output_shape=current_shape))

        # Without any scale there is no VAE and therefore no latent size to
        # build the base distribution from.
        if latent_size is None:
            raise ValueError(
                f"At least one scale with at least one step is required "
                f"(num_scales={num_scales}, num_steps={num_steps})")

        # Base distribution for dimension preserving portion of flow
        if base_distribution == "n":
            base_dist = StandardNormal((latent_size,))
        elif base_distribution == "c":
            base_dist = ConvNormal2d((latent_size,))
        elif base_distribution == "u":
            base_dist = StandardUniform((latent_size,))
        else:
            raise ValueError("Base distribution must be one of n=Normal, u=Uniform, or c=ConvNormal")

        # for reference save the shape output by the bijective flow
        self.latent_size = latent_size
        self.flow_shape = current_shape

        super(MultilevelCompressiveFlow, self).__init__(base_dist=[None, base_dist], transforms=transforms)
Example #11
0
# Dimensionality of the VAE latent code.
latent_size = 256

# Encoder q(z|x): flattens each input to 784 values, rescales it with
# 2 * x - 1 (maps [0, 1] to [-1, 1], assuming the dequantized input lies in
# [0, 1] - TODO confirm) and outputs 2 * latent_size values, presumably the
# mean and log-std of the latent Gaussian.
encoder = ConditionalNormal(MLP(784, 2*latent_size,
                                hidden_units=[512, 256],
                                activation=None,
                                in_lambda=lambda x: 2 * x.view(x.shape[0], 784).float() - 1))
# Decoder p(x|z): outputs 784 * 2 values reshaped to (batch, 2, 28, 28);
# split_dim=1 presumably separates mean and log-std along that axis.
decoder = ConditionalNormal(MLP(latent_size, 784 * 2,
                                hidden_units=[256, 512],
                                activation=None,
                                out_lambda=lambda x: x.view(x.shape[0], 2, 28, 28)), split_dim=1)
# Alternative Bernoulli decoder, kept for reference (currently disabled):
# decoder = ConditionalBernoulli(MLP(latent_size, 784,
#                                    hidden_units=[256, 512],
#                                    activation=None,
#                                    out_lambda=lambda x: x.view(x.shape[0], 1, 28, 28)))

# Full model: 8-bit uniform dequantization followed by the VAE layer, with a
# standard normal base distribution over the latent code.
model = Flow(base_dist=StandardNormal((latent_size,)),
             transforms=[
                 UniformDequantization(num_bits=8),
                 VAE(encoder=encoder, decoder=decoder)
             ]).to(device)

print(model)

###########
## Optim ##
###########

# Adam over all model parameters with a fixed learning rate of 1e-3.
optimizer = Adam(model.parameters(), lr=1e-3)

###########
## Train ##
Example #12
0
# Restore the arguments the model was trained with.
# NOTE(security): pickle.load can execute arbitrary code; only load argument
# files written by your own training runs.
with open(path_args, 'rb') as f:
    args = pickle.load(f)

####################
## Specify target ##
####################

target = get_target(args)
target_id = get_target_id(args)

###################
## Specify model ##
###################

pi = get_model(args, target=target).to(args.device)
p = StandardNormal((target.size, )).to(args.device)
model_id = get_model_id(args)

# map_location lets a checkpoint saved on another device (e.g. a GPU) be
# loaded on the device selected for this run (e.g. a CPU-only machine).
state_dict = torch.load(path_check, map_location=args.device)
pi.load_state_dict(state_dict)

##############
## Sampling ##
##############

print('Sampling...')
pi = pi.eval()
with torch.no_grad():
    # Draw base samples and push them through every transform in order; the
    # second value returned by each transform is discarded.
    z = p.sample(eval_args.num_samples)
    for t in pi.transforms:
        z, _ = t(z)
Example #13
0
    splits = list(range(args.cv_folds))
    runtimes = []
else:
    splits = [args.split]

for split in splits:

    print('Split {}:'.format(split))
    target.set_split(split=split)

    ###################
    ## Specify model ##
    ###################

    pi = get_model(args, target=target, num_bits=args.num_bits).to(args.device)
    p = StandardNormal((target.size, )).to(args.device)
    model_id = get_model_id(args)

    #######################
    ## Specify optimizer ##
    #######################

    if args.optimizer == 'adam':
        optimizer = Adam(pi.parameters(), lr=args.lr)
    elif args.optimizer == 'adamax':
        optimizer = Adamax(pi.parameters(), lr=args.lr)

    ##############
    ## Training ##
    ##############
            AffineCouplingBijection(net,
                                    split_dim=2,
                                    scale_fn=scale_fn(args.scale_fn)))
    else:
        transforms.append(AdditiveCouplingBijection(net, split_dim=2))
    if args.stochperm: transforms.append(StochasticPermutation(dim=2))
    else: transforms.append(Shuffle(L, dim=2))
    return transforms


# Stack args.num_flows groups of layers; each group optionally contains the
# dimension-wise layers, the length-wise layers and an ActNorm layer.
for _ in range(args.num_flows):
    if args.dimwise: transforms = dimwise(transforms)
    if args.lenwise: transforms = lenwise(transforms)
    if args.actnorm: transforms.append(ActNormBijection1d(2))

model = Flow(base_dist=StandardNormal((D, L)),
             transforms=transforms).to(args.device)
if not args.train:
    # Evaluation only: restore the trained weights. map_location lets a
    # checkpoint saved on another device (e.g. a GPU) be loaded on the device
    # selected for this run.
    state_dict = torch.load('models/{}.pt'.format(run_name),
                            map_location=args.device)
    model.load_state_dict(state_dict)

#######################
## Specify optimizer ##
#######################

# NOTE(review): any other value of args.optimizer leaves `optimizer`
# undefined here - confirm the argument parser restricts the choices.
if args.optimizer == 'adam':
    optimizer = Adam(model.parameters(), lr=args.lr)
elif args.optimizer == 'adamax':
    optimizer = Adamax(model.parameters(), lr=args.lr)

if args.warmup is not None:
Example #15
0
    def __init__(self, data_shape, cond_shape, num_bits, num_scales, num_steps,
                 actnorm, pooling, dequant, dequant_steps, dequant_context,
                 densenet_blocks, densenet_channels, densenet_depth,
                 densenet_growth, dropout, gated_conv, init_context):
        """Build a conditional multi-scale flow with a configurable pooling step.

        Args:
            data_shape: Input shape, indexed at 0, 1 and 2 (presumably
                channels, height, width).
            cond_shape: Shape of the conditioning input; only its first entry
                is used, as the number of context channels.
            num_bits: Bit depth handed to the dequantization layers.
            num_scales: Number of scales.
            num_steps: Number of flow steps per scale.
            actnorm: If truthy, add ActNorm layers.
            pooling: ``'none'``, ``'slice'`` or ``'max'``; selects how the
                resolution is reduced between scales.
            dequant: ``'uniform'`` or ``'flow'``. Any other value adds no
                dequantization layer.
            dequant_steps: Steps of the dequantization flow.
            dequant_context: Context channels of the dequantization flow.
            densenet_blocks: Blocks per coupling network.
            densenet_channels: Width of the coupling networks.
            densenet_depth: Depth of the coupling networks.
            densenet_growth: Not used by this constructor.
            dropout: Dropout of the coupling networks.
            gated_conv: Forwarded to the coupling networks.
            init_context: Not used by this constructor.

        Raises:
            ValueError: If ``pooling`` is not one of the values above and a
                pooling step has to be built (more than one scale).
        """
        layers = []

        # Dequantization of the discrete input.
        if dequant == 'uniform':
            layers.append(UniformDequantization(num_bits=num_bits))
        elif dequant == 'flow':
            dequant_encoder = DequantizationFlow(
                data_shape=data_shape,
                num_bits=num_bits,
                num_steps=dequant_steps,
                num_context=dequant_context,
                num_blocks=densenet_blocks,
                mid_channels=densenet_channels,
                depth=densenet_depth,
                dropout=dropout,
                gated_conv=gated_conv)
            layers.append(
                VariationalDequantization(encoder=dequant_encoder,
                                          num_bits=num_bits))

        # Shift the range from [0,1]^D to [-0.5, 0.5]^D.
        layers.append(ScalarAffineBijection(shift=-0.5))

        # Initial squeeze: four times the channels at half the resolution.
        layers.append(Squeeze2d())
        channels = data_shape[0] * 4
        height = data_shape[1] // 2
        width = data_shape[2] // 2

        last_scale = num_scales - 1
        for scale in range(num_scales):
            # Flow steps of this scale (an unconditional Conv1x1 is used; a
            # ConditionalConv1x1 variant is not enabled).
            for _ in range(num_steps):
                if actnorm:
                    layers.append(ActNormBijection2d(channels))
                layers.append(Conv1x1(num_channels=channels))
                layers.append(
                    ConditionalCoupling(in_channels=channels,
                                        num_context=cond_shape[0],
                                        num_blocks=densenet_blocks,
                                        mid_channels=densenet_channels,
                                        depth=densenet_depth,
                                        dropout=dropout,
                                        gated_conv=gated_conv))

            if scale == last_scale:
                # No pooling after the final scale, only an optional ActNorm.
                if actnorm:
                    layers.append(ActNormBijection2d(channels))
                continue

            # Pooling between scales; every variant halves height and width.
            if pooling == 'none':
                layers.append(Squeeze2d())
                channels = channels * 4
            elif pooling == 'slice':
                # Squeeze to 4x the channels, keep half of them and model the
                # other half with a standard normal.
                kept = channels * 2
                noise_shape = (kept, height // 2, width // 2)
                layers.append(Squeeze2d())
                layers.append(
                    Slice(StandardNormal(noise_shape), num_keep=kept, dim=1))
                channels = kept
            elif pooling == 'max':
                # The channel count is unchanged; the decoder covers three
                # times as many channels at the pooled resolution.
                noise_shape = (channels * 3, height // 2, width // 2)
                layers.append(
                    SimpleMaxPoolSurjection2d(
                        decoder=StandardHalfNormal(noise_shape)))
            else:
                raise ValueError(
                    "pooling argument must be either slice, max or none")
            height = height // 2
            width = width // 2

        # For reference, save the shape output by the bijective flow.
        final_shape = (channels, height, width)
        self.flow_shape = final_shape

        super(CondPoolFlow,
              self).__init__(base_dist=ConvNormal2d(final_shape),
                             transforms=layers)
Example #16
0
    def __init__(self,
                 data_shape,
                 cond_shape,
                 num_bits,
                 num_scales,
                 num_steps,
                 actnorm,
                 conditional_channels,
                 lowres_encoder_channels,
                 lowres_encoder_blocks,
                 lowres_encoder_depth,
                 lowres_upsampler_channels,
                 pooling,
                 compression_ratio,
                 coupling_network,
                 coupling_blocks,
                 coupling_channels,
                 coupling_dropout=0.0,
                 coupling_gated_conv=None,
                 coupling_depth=None,
                 coupling_mixtures=None,
                 dequant="flow",
                 dequant_steps=4,
                 dequant_context=32,
                 dequant_blocks=2,
                 augment_steps=4,
                 augment_context=32,
                 augment_blocks=2,
                 augment_size=None,
                 checkerboard_scales=[],
                 tuple_flip=True):

        if len(compression_ratio) == 1 and num_scales > 1:
            compression_ratio = [compression_ratio[0]] * (num_scales - 1)
        assert all([
            compression_ratio[s] >= 0.0 and compression_ratio[s] < 1.0
            for s in range(num_scales - 1)
        ])

        # initialize context. Only upsample context in ContextInit if latent shape doesn't change during the flow.
        context_init = ContextInit(num_bits=num_bits,
                                   in_channels=cond_shape[0],
                                   out_channels=lowres_encoder_channels,
                                   mid_channels=lowres_encoder_channels,
                                   num_blocks=lowres_encoder_blocks,
                                   depth=lowres_encoder_depth,
                                   dropout=coupling_dropout)

        transforms = []
        current_shape = data_shape
        if dequant == 'uniform':
            transforms.append(UniformDequantization(num_bits=num_bits))
        elif dequant == 'flow':
            dequantize_flow = DequantizationFlow(
                data_shape=data_shape,
                num_bits=num_bits,
                num_steps=dequant_steps,
                coupling_network=coupling_network,
                num_context=dequant_context,
                num_blocks=dequant_blocks,
                mid_channels=coupling_channels,
                depth=coupling_depth,
                dropout=0.0,
                gated_conv=False,
                num_mixtures=coupling_mixtures,
                checkerboard=True,
                tuple_flip=tuple_flip)
            transforms.append(
                VariationalDequantization(encoder=dequantize_flow,
                                          num_bits=num_bits))

        # Change range from [0,1]^D to [-0.5, 0.5]^D
        transforms.append(ScalarAffineBijection(shift=-0.5))

        # Initial squeezing
        if current_shape[1] >= 128 and current_shape[2] >= 128:
            # H x W -> 64 x 64
            transforms.append(Squeeze2d())
            current_shape = (current_shape[0] * 4, current_shape[1] // 2,
                             current_shape[2] // 2)

        if current_shape[1] >= 64 and current_shape[2] >= 64:
            # H x W -> 32 x 32
            transforms.append(Squeeze2d())
            current_shape = (current_shape[0] * 4, current_shape[1] // 2,
                             current_shape[2] // 2)

        if 0 not in checkerboard_scales or (current_shape[1] > 32
                                            and current_shape[2] > 32):
            # Only go to 16 x 16 if not doing checkerboard splits first
            transforms.append(Squeeze2d())
            current_shape = (current_shape[0] * 4, current_shape[1] // 2,
                             current_shape[2] // 2)

        # add in augmentation channels if desired
        if augment_size is not None and augment_size > 0:
            #transforms.append(Augment(StandardUniform((augment_size, current_shape[1], current_shape[2])), x_size=current_shape[0]))
            #transforms.append(Augment(StandardNormal((augment_size, current_shape[1], current_shape[2])), x_size=current_shape[0]))
            augment_flow = AugmentFlow(data_shape=current_shape,
                                       augment_size=augment_size,
                                       num_steps=augment_steps,
                                       coupling_network=coupling_network,
                                       mid_channels=coupling_channels,
                                       num_context=augment_context,
                                       num_mixtures=coupling_mixtures,
                                       num_blocks=augment_blocks,
                                       dropout=0.0,
                                       checkerboard=True,
                                       tuple_flip=tuple_flip)
            transforms.append(
                Augment(encoder=augment_flow, x_size=current_shape[0]))
            current_shape = (current_shape[0] + augment_size, current_shape[1],
                             current_shape[2])

        for scale in range(num_scales):

            # First and Third scales use checkerboard split pattern
            checkerboard = scale in checkerboard_scales
            context_out_channels = min(current_shape[0], coupling_channels)
            context_out_shape = (context_out_channels, current_shape[1],
                                 current_shape[2] //
                                 2) if checkerboard else (context_out_channels,
                                                          current_shape[1],
                                                          current_shape[2])

            # reshape the context to the current size for all flow steps at this scale
            context_upsampler_net = UpsamplerNet(
                in_channels=lowres_encoder_channels,
                out_shape=context_out_shape,
                mid_channels=lowres_upsampler_channels)
            transforms.append(
                ContextUpsampler(context_net=context_upsampler_net,
                                 direction='forward'))

            for step in range(num_steps):

                flip = (step % 2 == 0) if tuple_flip else False

                if len(conditional_channels) == 0:
                    if actnorm:
                        transforms.append(ActNormBijection2d(current_shape[0]))
                    transforms.append(Conv1x1(current_shape[0]))
                else:
                    if actnorm:
                        transforms.append(
                            ConditionalActNormBijection2d(
                                cond_shape=current_shape,
                                out_channels=current_shape[0],
                                mid_channels=conditional_channels))
                    transforms.append(
                        ConditionalConv1x1(cond_shape=current_shape,
                                           out_channels=current_shape[0],
                                           mid_channels=conditional_channels,
                                           slogdet_cpu=True))

                if coupling_network in ["conv", "densenet"]:
                    transforms.append(
                        SRCoupling(x_size=context_out_shape,
                                   y_size=current_shape,
                                   mid_channels=coupling_channels,
                                   depth=coupling_depth,
                                   num_blocks=coupling_blocks,
                                   dropout=coupling_dropout,
                                   gated_conv=coupling_gated_conv,
                                   coupling_network=coupling_network,
                                   checkerboard=checkerboard,
                                   flip=flip))

                elif coupling_network == "transformer":
                    transforms.append(
                        SRMixtureCoupling(x_size=context_out_shape,
                                          y_size=current_shape,
                                          mid_channels=coupling_channels,
                                          dropout=coupling_dropout,
                                          num_blocks=coupling_blocks,
                                          num_mixtures=coupling_mixtures,
                                          checkerboard=checkerboard,
                                          flip=flip))

            # Upsample context (for the previous flows, only if moving in the inverse direction)
            transforms.append(
                ContextUpsampler(context_net=context_upsampler_net,
                                 direction='inverse'))

            if scale < num_scales - 1:
                if pooling == 'none' or compression_ratio[scale] == 0.0:
                    # fully bijective flow with multi-scale architecture
                    transforms.append(Squeeze2d())
                    current_shape = (current_shape[0] * 4,
                                     current_shape[1] // 2,
                                     current_shape[2] // 2)
                elif pooling == 'slice':
                    # slice some of the dimensions (channel-wise) out from further flow steps
                    unsliced_channels = int(
                        max(
                            1, 4 * current_shape[0] *
                            (1.0 - compression_ratio[scale])))
                    sliced_channels = int(4 * current_shape[0] -
                                          unsliced_channels)
                    noise_shape = (sliced_channels, current_shape[1] // 2,
                                   current_shape[2] // 2)
                    transforms.append(Squeeze2d())
                    transforms.append(
                        Slice(StandardNormal(noise_shape),
                              num_keep=unsliced_channels,
                              dim=1))
                    current_shape = (unsliced_channels, current_shape[1] // 2,
                                     current_shape[2] // 2)
                elif pooling == 'max':
                    # max pooling to compress dimensions spatially, h//2 and w//2
                    noise_shape = (current_shape[0] * 3, current_shape[1] // 2,
                                   current_shape[2] // 2)
                    decoder = StandardHalfNormal(noise_shape)
                    transforms.append(
                        SimpleMaxPoolSurjection2d(decoder=decoder))
                    current_shape = (current_shape[0], current_shape[1] // 2,
                                     current_shape[2] // 2)
                elif pooling == "mvae":
                    # Compressive flow: reduce the dimensionality of data by 2 (channel-wise)
                    compressed_channels = max(
                        1,
                        int(current_shape[0] *
                            (1.0 - compression_ratio[scale])))
                    latent_size = compressed_channels * current_shape[
                        1] * current_shape[2]
                    vae_channels = [
                        current_shape[0] * 2, current_shape[0] * 4,
                        current_shape[0] * 8
                    ]
                    encoder = ConditionalNormal(ConvEncoderNet(
                        in_channels=current_shape[0],
                        out_channels=latent_size,
                        mid_channels=vae_channels,
                        max_pool=True,
                        batch_norm=True),
                                                split_dim=1)
                    decoder = ConditionalNormal(ConvDecoderNet(
                        in_channels=latent_size,
                        out_shape=(current_shape[0] * 2, current_shape[1],
                                   current_shape[2]),
                        mid_channels=list(reversed(vae_channels)),
                        batch_norm=True,
                        in_lambda=lambda x: x.view(x.shape[0], x.shape[1], 1, 1
                                                   )),
                                                split_dim=1)
                    transforms.append(VAE(encoder=encoder, decoder=decoder))
                    transforms.append(
                        Reshape(input_shape=(latent_size, ),
                                output_shape=(compressed_channels,
                                              current_shape[1],
                                              current_shape[2])))

                    # after reducing channels with mvae, squeeze to reshape latent space before another sequence of flows
                    transforms.append(Squeeze2d())
                    current_shape = (
                        compressed_channels * 4,  # current_shape[0] * 4
                        current_shape[1] // 2,
                        current_shape[2] // 2)

                else:
                    raise ValueError(
                        "pooling argument must be either mvae, slice, max, or none"
                    )

            else:
                if actnorm:
                    transforms.append(ActNormBijection2d(current_shape[0]))

        # for reference save the shape output by the bijective flow
        self.latent_size = current_shape[0] * current_shape[1] * current_shape[
            2]
        self.flow_shape = current_shape

        super(SRPoolFlow, self).__init__(base_dist=ConvNormal2d(current_shape),
                                         transforms=transforms,
                                         context_init=context_init)
Example #17
0
def get_model(pretrained_backbone=True, using_vae=True):
    """Build a ``VAE`` on top of a ``StandardNormal`` prior of shape (2, 1, 1).

    Args:
        pretrained_backbone: Accepted for interface compatibility; this
            function does not read it.
        using_vae: Forwarded unchanged to ``VAE`` as ``using_vae``.

    Returns:
        The constructed ``VAE`` instance.
    """
    return VAE(StandardNormal((2, 1, 1)), 2, using_vae=using_vae)
Example #18
0
    def __init__(self, num_flows, actnorm, affine, scale_fn_str, hidden_units,
                 activation, range_flow, augment_size, base_dist):
        """Build the transform list and base distribution, then init the flow.

        Args:
            num_flows: Number of coupling layers to stack (non-uniform base
                only).
            actnorm: If truthy, add an ``ActNormBijection`` after every
                coupling layer.
            affine: If truthy use ``AffineCouplingBijection`` (two elementwise
                parameters), otherwise ``AdditiveCouplingBijection`` (one).
            scale_fn_str: Passed to ``scale_fn`` to pick the affine scale
                function.
            hidden_units: Hidden layer sizes forwarded to every ``MLP``.
            activation: Activation forwarded to every ``MLP``.
            range_flow: ``'logit'`` or ``'softplus'``; selects the bijection(s)
                placed after ``SimpleAbsSurjection`` when no augmentation is
                used. Any other value adds nothing.
            augment_size: Number of augmentation dimensions; must be even
                when positive. Not read when ``base_dist == "uniform"``.
            base_dist: ``"uniform"`` selects the ``ElementAbsSurjection``
                model with a ``StandardUniform`` base; any other value selects
                the coupling flow with a ``StandardNormal`` base.
        """

        D = 2  # Number of data dimensions

        if base_dist == "uniform":
            classifier = MLP(D,
                             D // 2,
                             hidden_units=hidden_units,
                             activation=activation,
                             out_lambda=lambda x: x.view(-1))
            transforms = [
                ElementAbsSurjection(classifier=classifier),
                ShiftBijection(shift=torch.tensor([[0.0, 4.0]])),
                ScaleBijection(scale=torch.tensor([[1 / 4, 1 / 8]]))
            ]
            base = StandardUniform((D, ))

        else:
            A = D + augment_size  # Number of augmented data dimensions
            P = 2 if affine else 1  # Number of elementwise parameters

            # initialize flow with either augmentation or Abs surjection
            if augment_size > 0:
                assert augment_size % 2 == 0
                transforms = [
                    Augment(StandardNormal((augment_size, )), x_size=D)
                ]

            else:
                transforms = [SimpleAbsSurjection()]
                if range_flow == 'logit':
                    transforms += [
                        ScaleBijection(scale=torch.tensor([[1 / 4, 1 / 4]])),
                        Logit()
                    ]
                elif range_flow == 'softplus':
                    transforms += [SoftplusInverse()]

            # apply coupling layer flows
            for _ in range(num_flows):
                net = nn.Sequential(
                    MLP(A // 2,
                        P * A // 2,
                        hidden_units=hidden_units,
                        activation=activation), ElementwiseParams(P))

                if affine:
                    transforms.append(
                        AffineCouplingBijection(
                            net, scale_fn=scale_fn(scale_fn_str)))
                else:
                    transforms.append(AdditiveCouplingBijection(net))
                if actnorm:
                    # After augmentation the stream has A features, not D, so
                    # the normalisation must be sized for A (A == D when no
                    # augmentation is used).
                    transforms.append(ActNormBijection(A))
                transforms.append(Reverse(A))

            # Drop the trailing Reverse. Without the guard, num_flows == 0
            # would instead remove the augmentation / range transform.
            if num_flows > 0:
                transforms.pop()
            base = StandardNormal((A, ))

        super(UnconditionalFlow, self).__init__(base_dist=base,
                                                transforms=transforms)
Example #19
0
## Specify data ##
##################

train_loader, test_loader = get_data(args)

###################
## Specify model ##
###################

# The coupling networks below split the A-dimensional stream in half, so the
# number of augmentation dimensions has to be even.
assert args.augdim % 2 == 0

D = 2  # Number of data dimensions
A = D + args.augdim  # Number of augmented data dimensions
P = 2 if args.affine else 1  # Number of elementwise parameters

# Start by appending args.augdim StandardNormal dimensions to the D data dims.
transforms = [Augment(StandardNormal((args.augdim, )), x_size=D)]
for _ in range(args.num_flows):
    net = nn.Sequential(
        MLP(A // 2,
            P * A // 2,
            hidden_units=args.hidden_units,
            activation=args.activation), ElementwiseParams(P))
    if args.affine:
        transforms.append(
            AffineCouplingBijection(net, scale_fn=scale_fn(args.scale_fn)))
    else:
        transforms.append(AdditiveCouplingBijection(net))
    if args.actnorm:
        # The stream is A-dimensional after Augment, so size actnorm for A
        # rather than for the D raw data dimensions.
        transforms.append(ActNormBijection(A))
    transforms.append(Reverse(A))
# Remove the trailing Reverse; with zero flows there is none to remove and
# popping would discard the Augment layer instead.
if args.num_flows > 0:
    transforms.pop()
Example #20
0

def net(channels):
    """Return a coupling network mapping ``channels // 2`` to ``channels``.

    The network is a ``DenseNet`` (one block, depth 8, growth 16, 64 mid
    channels, gated convolutions, zero-initialised, no dropout) followed by
    ``ElementwiseParams2d(2)``.
    """
    densenet = DenseNet(in_channels=channels // 2,
                        out_channels=channels,
                        num_blocks=1,
                        mid_channels=64,
                        depth=8,
                        growth=16,
                        dropout=0.0,
                        gated_conv=True,
                        zero_init=True)
    to_elementwise = ElementwiseParams2d(2)
    return nn.Sequential(densenet, to_elementwise)


model = Flow(base_dist=StandardNormal((24, 8, 8)),
             transforms=[
                 UniformDequantization(num_bits=8),
                 Augment(StandardUniform((3, 32, 32)), x_size=3),
                 AffineCouplingBijection(net(6)),
                 ActNormBijection2d(6),
                 Conv1x1(6),
                 AffineCouplingBijection(net(6)),
                 ActNormBijection2d(6),
                 Conv1x1(6),
                 AffineCouplingBijection(net(6)),
                 ActNormBijection2d(6),
                 Conv1x1(6),
                 AffineCouplingBijection(net(6)),
                 ActNormBijection2d(6),
                 Conv1x1(6),
Example #21
0
    def __init__(self,
                 data_shape,
                 num_bits,
                 base_distributions,
                 num_scales,
                 num_steps,
                 actnorm,
                 vae_hidden_units,
                 latent_size,
                 vae_activation,
                 coupling_network,
                 dequant,
                 dequant_steps,
                 dequant_context,
                 coupling_blocks,
                 coupling_channels,
                 coupling_dropout,
                 coupling_growth=None,
                 coupling_gated_conv=None,
                 coupling_depth=None,
                 coupling_mixtures=None):
        """Assemble a bijective multi-scale flow followed by a VAE layer.

        Args:
            data_shape: Three-entry shape of the input; forwarded to
                ``DequantizationFlow`` and used to track the shape through
                every ``Squeeze2d`` (first entry x4, other two halved).
            num_bits: Forwarded to the dequantization layer.
            base_distributions: Sequence of one-letter codes (``"n"``,
                ``"c"`` or ``"u"``). The last entry selects the base over
                ``(latent_size,)``. When there is more than one entry the
                first selects the base over the flow output shape; otherwise
                that first base is ``None``.
            num_scales: Number of scales; forced to 0 when ``num_steps`` is 0.
            num_steps: Number of flow steps per scale.
            actnorm: If truthy, add ``ActNormBijection2d`` before each
                ``Conv1x1`` and once more after the last scale.
            vae_hidden_units: Hidden sizes of the encoder ``MLP``; the decoder
                uses them reversed.
            latent_size: Size of the VAE latent vector.
            vae_activation: Activation for both VAE ``MLP`` networks.
            coupling_network: ``"conv"`` or ``"densenet"`` selects
                ``Coupling``; any other value selects ``MixtureCoupling``.
            dequant: ``'uniform'`` or ``'flow'``. Uniform dequantization is
                also used whenever there are no flow steps or scales. Any
                other value adds no dequantization layer.
            dequant_steps: ``num_steps`` of the ``DequantizationFlow``.
            dequant_context: ``num_context`` of the ``DequantizationFlow``.
            coupling_blocks: ``num_blocks`` of the coupling layers.
            coupling_channels: ``mid_channels`` of the coupling layers.
            coupling_dropout: ``dropout`` of the coupling layers.
            coupling_growth: ``growth`` forwarded to ``Coupling``.
            coupling_gated_conv: ``gated_conv`` forwarded to ``Coupling``.
            coupling_depth: ``depth`` forwarded to ``Coupling``.
            coupling_mixtures: ``num_mixtures`` for ``MixtureCoupling``.

        Raises:
            ValueError: If a consulted entry of ``base_distributions`` is not
                one of ``"n"``, ``"c"`` or ``"u"``.
        """

        def make_base(code, shape):
            # Single place that maps a base-distribution code to an instance.
            if code == "n":
                return StandardNormal(shape)
            if code == "c":
                return ConvNormal2d(shape)
            if code == "u":
                return StandardUniform(shape)
            raise ValueError(
                "Base distribution must be one of n=Normal, u=Uniform, or c=ConvNormal"
            )

        transforms = []
        current_shape = data_shape
        if num_steps == 0: num_scales = 0

        if dequant == 'uniform' or num_steps == 0 or num_scales == 0:
            # no bijective flows defaults to only using uniform dequantization
            transforms.append(UniformDequantization(num_bits=num_bits))
        elif dequant == 'flow':
            dequantize_flow = DequantizationFlow(
                data_shape=data_shape,
                num_bits=num_bits,
                num_steps=dequant_steps,
                coupling_network=coupling_network,
                num_context=dequant_context,
                num_blocks=coupling_blocks,
                mid_channels=coupling_channels,
                depth=coupling_depth,
                growth=coupling_growth,
                dropout=coupling_dropout,
                gated_conv=coupling_gated_conv,
                num_mixtures=coupling_mixtures)
            transforms.append(
                VariationalDequantization(encoder=dequantize_flow,
                                          num_bits=num_bits))

        # Change range from [0,1]^D to [-0.5, 0.5]^D
        transforms.append(ScalarAffineBijection(shift=-0.5))

        # Initial squeeze
        transforms.append(Squeeze2d())
        current_shape = (current_shape[0] * 4, current_shape[1] // 2,
                         current_shape[2] // 2)

        # Dimension preserving flows
        for scale in range(num_scales):
            for _ in range(num_steps):
                if actnorm:
                    transforms.append(ActNormBijection2d(current_shape[0]))
                transforms.append(Conv1x1(current_shape[0]))
                if coupling_network in ["conv", "densenet"]:
                    transforms.append(
                        Coupling(in_channels=current_shape[0],
                                 num_blocks=coupling_blocks,
                                 mid_channels=coupling_channels,
                                 depth=coupling_depth,
                                 growth=coupling_growth,
                                 dropout=coupling_dropout,
                                 gated_conv=coupling_gated_conv,
                                 coupling_network=coupling_network))
                else:
                    transforms.append(
                        MixtureCoupling(in_channels=current_shape[0],
                                        mid_channels=coupling_channels,
                                        num_mixtures=coupling_mixtures,
                                        num_blocks=coupling_blocks,
                                        dropout=coupling_dropout))

            if scale < num_scales - 1:
                transforms.append(Squeeze2d())
                current_shape = (current_shape[0] * 4, current_shape[1] // 2,
                                 current_shape[2] // 2)
            else:
                if actnorm:
                    transforms.append(ActNormBijection2d(current_shape[0]))

        # Base distribution for dimension preserving portion of flow
        if len(base_distributions) > 1:
            base0 = make_base(base_distributions[0], current_shape)
        else:
            base0 = None

        # for reference save the shape output by the bijective flow
        self.flow_shape = current_shape

        # Non-dimension preserving flows
        flat_dim = current_shape[0] * current_shape[1] * current_shape[2]
        encoder = ConditionalNormal(
            MLP(flat_dim,
                2 * latent_size,
                hidden_units=vae_hidden_units,
                activation=vae_activation,
                in_lambda=lambda x: x.view(x.shape[0], flat_dim)))
        # The decoder output carries twice the channels so that
        # ConditionalNormal can split it along dim 1.
        decoder = ConditionalNormal(MLP(
            latent_size,
            2 * flat_dim,
            hidden_units=list(reversed(vae_hidden_units)),
            activation=vae_activation,
            out_lambda=lambda x: x.view(x.shape[0], current_shape[0] * 2,
                                        current_shape[1], current_shape[2])),
                                    split_dim=1)

        transforms.append(VAE(encoder=encoder, decoder=decoder))

        # Base distribution for non-dimension preserving portion of flow
        #self.latent_size = latent_size
        base1 = make_base(base_distributions[-1], (latent_size, ))

        super(VAECompressiveFlow, self).__init__(base_dist=[base0, base1],
                                                 transforms=transforms)
Example #22
0
    def __init__(self, data_shape, num_bits, num_scales, num_steps, actnorm,
                 pooling, dequant, dequant_steps, dequant_context,
                 densenet_blocks, densenet_channels, densenet_depth,
                 densenet_growth, dropout, gated_conv):
        """Assemble a multi-scale flow whose scales are joined by pooling.

        Args:
            data_shape: Three-entry shape of the input; forwarded to
                ``DequantizationFlow`` and used to track the shape through
                the squeeze / pooling stages.
            num_bits: Forwarded to the dequantization layer.
            num_scales: Number of scales.
            num_steps: Number of flow steps per scale.
            actnorm: If truthy, add ``ActNormBijection2d`` before each
                ``Conv1x1`` and once more after the last scale.
            pooling: ``'none'`` (``Squeeze2d`` followed by ``Slice``) or
                ``'max'`` (``SimpleMaxPoolSurjection2d``). Only consulted
                between scales, i.e. when ``num_scales > 1``.
            dequant: ``'uniform'`` or ``'flow'``; any other value adds no
                dequantization layer.
            dequant_steps: ``num_steps`` of the ``DequantizationFlow``.
            dequant_context: ``num_context`` of the ``DequantizationFlow``.
            densenet_blocks: ``num_blocks`` for the coupling / dequant nets.
            densenet_channels: ``mid_channels`` for the coupling / dequant
                nets.
            densenet_depth: ``depth`` for the coupling / dequant nets.
            densenet_growth: ``growth`` for the coupling / dequant nets.
            dropout: ``dropout`` for the coupling / dequant nets.
            gated_conv: ``gated_conv`` for the coupling / dequant nets.

        Raises:
            ValueError: If a pooling stage is required and ``pooling`` is
                neither ``'none'`` nor ``'max'``.
        """

        transforms = []
        current_shape = data_shape
        if dequant == 'uniform':
            transforms.append(UniformDequantization(num_bits=num_bits))
        elif dequant == 'flow':
            dequantize_flow = DequantizationFlow(
                data_shape=data_shape,
                num_bits=num_bits,
                num_steps=dequant_steps,
                num_context=dequant_context,
                num_blocks=densenet_blocks,
                mid_channels=densenet_channels,
                depth=densenet_depth,
                growth=densenet_growth,
                dropout=dropout,
                gated_conv=gated_conv)
            transforms.append(
                VariationalDequantization(encoder=dequantize_flow,
                                          num_bits=num_bits))

        # Change range from [0,1]^D to [-0.5, 0.5]^D
        transforms.append(ScalarAffineBijection(shift=-0.5))

        # Initial squeeze
        transforms.append(Squeeze2d())
        current_shape = (current_shape[0] * 4, current_shape[1] // 2,
                         current_shape[2] // 2)

        # Pooling flows
        for scale in range(num_scales):
            for _ in range(num_steps):
                if actnorm:
                    transforms.append(ActNormBijection2d(current_shape[0]))
                transforms.extend([
                    Conv1x1(current_shape[0]),
                    Coupling(in_channels=current_shape[0],
                             num_blocks=densenet_blocks,
                             mid_channels=densenet_channels,
                             depth=densenet_depth,
                             growth=densenet_growth,
                             dropout=dropout,
                             gated_conv=gated_conv)
                ])

            if scale < num_scales - 1:
                # Both pooling variants keep the channel count and halve the
                # spatial size; the remaining 3x channels are modelled as noise.
                noise_shape = (current_shape[0] * 3, current_shape[1] // 2,
                               current_shape[2] // 2)
                if pooling == 'none':
                    transforms.append(Squeeze2d())
                    transforms.append(
                        Slice(StandardNormal(noise_shape),
                              num_keep=current_shape[0],
                              dim=1))
                elif pooling == 'max':
                    decoder = StandardHalfNormal(noise_shape)
                    transforms.append(
                        SimpleMaxPoolSurjection2d(decoder=decoder))
                else:
                    # Without a pooling layer the shape bookkeeping below
                    # would no longer match the transforms, so fail early.
                    raise ValueError(
                        "pooling argument must be either max or none")
                current_shape = (current_shape[0], current_shape[1] // 2,
                                 current_shape[2] // 2)
            else:
                if actnorm:
                    transforms.append(ActNormBijection2d(current_shape[0]))

        super(PoolFlow, self).__init__(base_dist=ConvNormal2d(current_shape),
                                       transforms=transforms)
Example #23
0
# Restore the argument object stored at path_args.
# NOTE(review): pickle.load can execute arbitrary code while unpickling;
# only point path_args at files written by a trusted run.
with open(path_args, 'rb') as f:
    args = pickle.load(f)

####################
## Specify target ##
####################

# Both the target and its identifier are derived from the unpickled args.
target = get_target(args)
target_id = get_target_id(args)

###################
## Specify model ##
###################

# pi is the model built for this target; p is a StandardNormal of shape
# (target.size,). Both are moved to args.device.
pi = get_model(args, target=target).to(args.device)
p = StandardNormal((target.size, )).to(args.device)
model_id = get_model_id(args)

# Load the checkpoint at path_check into pi.
# NOTE(review): torch.load is called without map_location; presumably the
# checkpoint device is available on this machine -- confirm.
state_dict = torch.load(path_check)
pi.load_state_dict(state_dict)

##############
## Training ##
##############

print('Running MCMC...')
time_before = time.time()
samples, rate = metropolis_hastings(
    pi=pi,
    num_dims=target.size,
    num_chains=eval_args.num_chains,
Example #24
0
            ActNormBijection2d(channels),  \
            Conv1x1(channels)


def reduction_layer(channels, items):
    """Return the transforms of one reduction stage as a list.

    The stage is three ``perm_norm_bi(channels)`` groups, then
    ``Squeeze2d(4)``, then a ``Slice`` that keeps ``channels * 2`` entries and
    models the rest with a ``StandardNormal`` of shape
    ``(channels * 2, items)``.
    """
    stage = []
    for _ in range(3):
        stage.extend(perm_norm_bi(channels))
    stage.append(Squeeze2d(4))
    kept = channels * 2
    stage.append(Slice(StandardNormal((kept, items)), num_keep=kept))
    return stage


model = Flow(
    base_dist=StandardNormal((base_channels * (2**5), n_items // (4**4))),
    transforms=[
        UniformDequantization(num_bits=8),
        Augment(StandardUniform((base_channels * 1, n_items)),
                x_size=base_channels),
        *reduction_layer(base_channels * (2**1), n_items // (4**1)),
        *reduction_layer(base_channels * (2**2), n_items // (4**2)),
        *reduction_layer(base_channels * (2**3), n_items // (4**3)),
        *reduction_layer(base_channels * (2**4), n_items // (4**4)),
        # *reduction_layer(base_channels*(2**5), n_items//(4**4)),
        *perm_norm_bi(base_channels * (2**5))

        # AffineCouplingBijection(net(base_channels*2)), ActNormBijection2d(base_channels*2), Conv1x1(base_channels*2),
        # AffineCouplingBijection(net(base_channels*2)), ActNormBijection2d(base_channels*2), Conv1x1(base_channels*2),
        # AffineCouplingBijection(net(base_channels*2)), ActNormBijection2d(base_channels*2), Conv1x1(base_channels*2),
        # Squeeze2d(), Slice(StandardNormal((base_channels*2, n_items//4)), num_keep=base_channels*2),
Example #25
0
    def net(channels):
        """Return a coupling network mapping ``channels // 2`` to ``channels``.

        A single-block ``DenseNet`` (depth 1, growth 16, 64 mid channels,
        gated convolutions, zero-initialised, no dropout) followed by
        ``ElementwiseParams2d(2)``.
        """
        densenet = DenseNet(in_channels=channels // 2,
                            out_channels=channels,
                            num_blocks=1,
                            mid_channels=64,
                            depth=1,
                            growth=16,
                            dropout=0.0,
                            gated_conv=True,
                            zero_init=True)
        to_elementwise = ElementwiseParams2d(2)
        return nn.Sequential(densenet, to_elementwise)



# Disabled alternative that also gives the first stage a base distribution:
# model = NDPFlow(base_dist=[StandardNormal((16,7,7)), StandardNormal((latent_size,))],
model = NDPFlow(
    base_dist=[None, StandardNormal((latent_size,))],
    transforms=[
        UniformDequantization(num_bits=8),
        ActNormBijection2d(1),
        ScalarAffineBijection(scale=2.0),
        ScalarAffineBijection(shift=-0.5),
        Squeeze2d(),
        # Two actnorm / 1x1 conv / affine coupling steps on 4 channels.
        ActNormBijection2d(4),
        Conv1x1(4),
        AffineCouplingBijection(net(4)),
        ActNormBijection2d(4),
        Conv1x1(4),
        AffineCouplingBijection(net(4)),
        Squeeze2d(),
        # Two actnorm / 1x1 conv / affine coupling steps on 16 channels.
        ActNormBijection2d(16),
        Conv1x1(16),
        AffineCouplingBijection(net(16)),
        ActNormBijection2d(16),
        Conv1x1(16),
        AffineCouplingBijection(net(16)),
        # Final layer of the transform list.
        VAE(encoder=encoder, decoder=decoder),
    ],
).to(device)

print(model)
Example #26
0
                activation=args.activation))
        if args.affine:
            transforms.append(
                ConditionalAffineCouplingBijection(coupling_net=net,
                                                   context_net=context_net,
                                                   scale_fn=scale_fn(
                                                       args.scale_fn)))
        else:
            transforms.append(
                ConditionalAdditiveCouplingBijection(coupling_net=net,
                                                     context_net=context_net))
        if args.actnorm: transforms.append(ActNormBijection(D))
        if args.permutation == 'reverse': transforms.append(Reverse(D))
        elif args.permutation == 'shuffle': transforms.append(Shuffle(D))
    transforms.pop()
    decoder = ConditionalFlow(base_dist=StandardNormal((D, )),
                              transforms=transforms).to(args.device)

# Flow
transforms = []
for _ in range(args.num_flows):
    net = nn.Sequential(
        MLP(D // 2,
            P * D // 2,
            hidden_units=args.hidden_units,
            activation=args.activation), ElementwiseParams(P))
    if args.affine:
        transforms.append(
            AffineCouplingBijection(net, scale_fn=scale_fn(args.scale_fn)))
    else:
        transforms.append(AdditiveCouplingBijection(net))
Example #27
0
    for _ in range(args.num_flows):
        net = nn.Sequential(MLP(C+I, P*O,
                                hidden_units=args.hidden_units,
                                activation=args.activation),
                            ElementwiseParams(P))
        context_net = nn.Sequential(LambdaLayer(lambda x: 2*x.float()/(2**args.num_bits-1) - 1),
                                    MLP(D, C,
                                        hidden_units=args.hidden_units,
                                        activation=args.activation))
        if args.affine: transforms.append(ConditionalAffineCouplingBijection(coupling_net=net, context_net=context_net, scale_fn=scale_fn(args.scale_fn), num_condition=I))
        else:           transforms.append(ConditionalAdditiveCouplingBijection(coupling_net=net, context_net=context_net, num_condition=I))
        if args.actnorm: transforms.append(ActNormBijection(D))
        if args.permutation == 'reverse':   transforms.append(Reverse(D))
        elif args.permutation == 'shuffle': transforms.append(Shuffle(D))
    transforms.pop()
    decoder = ConditionalFlow(base_dist=StandardNormal((D,)), transforms=transforms).to(args.device)

# Flow
# Each step is a coupling layer, an optional actnorm and an optional
# permutation; the permutation after the final step is removed afterwards.
transforms = []
for _ in range(args.num_flows):
    net = nn.Sequential(MLP(I, P*O,
                            hidden_units=args.hidden_units,
                            activation=args.activation),
                        ElementwiseParams(P))
    if args.affine:
        transforms.append(AffineCouplingBijection(net, scale_fn=scale_fn(args.scale_fn), num_condition=I))
    else:
        transforms.append(AdditiveCouplingBijection(net, num_condition=I))
    if args.actnorm:
        transforms.append(ActNormBijection(D))
    if args.permutation == 'reverse':
        transforms.append(Reverse(D))
    elif args.permutation == 'shuffle':
        transforms.append(Shuffle(D))
# Only drop the last element when it really is a trailing permutation:
# with zero flows the list is empty (pop would raise IndexError), and with
# any other permutation setting the last element is a coupling/actnorm layer.
if args.num_flows > 0 and args.permutation in ('reverse', 'shuffle'):
    transforms.pop()
if args.num_bits is not None:
Example #28
0
    def __init__(self,
                 data_shape,
                 num_bits,
                 num_scales,
                 num_steps,
                 actnorm,
                 pooling,
                 compression_ratio,
                 coupling_network,
                 coupling_blocks,
                 coupling_channels,
                 coupling_dropout=0.0,
                 coupling_gated_conv=None,
                 coupling_depth=None,
                 coupling_mixtures=None,
                 dequant="flow",
                 dequant_steps=4,
                 dequant_context=32,
                 dequant_blocks=2,
                 augment_steps=4,
                 augment_context=32,
                 augment_blocks=2,
                 augment_size=None,
                 checkerboard_scales=[],
                 tuple_flip=True):

        if len(compression_ratio) == 1 and num_scales > 1:
            compression_ratio = [compression_ratio[0]] * (num_scales - 1)
        assert all([
            compression_ratio[s] >= 0 and compression_ratio[s] < 1
            for s in range(num_scales - 1)
        ])

        transforms = []
        current_shape = data_shape
        if dequant == 'uniform':
            transforms.append(UniformDequantization(num_bits=num_bits))
        elif dequant == 'flow':
            dequantize_flow = DequantizationFlow(
                data_shape=data_shape,
                num_bits=num_bits,
                num_steps=dequant_steps,
                coupling_network=coupling_network,
                num_context=dequant_context,
                num_blocks=dequant_blocks,
                mid_channels=coupling_channels,
                depth=coupling_depth,
                dropout=0.0,
                gated_conv=False,
                num_mixtures=coupling_mixtures,
                checkerboard=True,
                tuple_flip=tuple_flip)
            transforms.append(
                VariationalDequantization(encoder=dequantize_flow,
                                          num_bits=num_bits))

        # Change range from [0,1]^D to [-0.5, 0.5]^D
        transforms.append(ScalarAffineBijection(shift=-0.5))

        # Initial squeezing
        if current_shape[1] >= 128 and current_shape[2] >= 128:
            # H x W -> 64 x 64
            transforms.append(Squeeze2d())
            current_shape = (current_shape[0] * 4, current_shape[1] // 2,
                             current_shape[2] // 2)

        if current_shape[1] >= 64 and current_shape[2] >= 64:
            # H x W -> 32 x 32
            transforms.append(Squeeze2d())
            current_shape = (current_shape[0] * 4, current_shape[1] // 2,
                             current_shape[2] // 2)

        if 0 not in checkerboard_scales or (current_shape[1] > 32
                                            and current_shape[2] > 32):
            # Only go to 16 x 16 if not doing checkerboard splits first
            transforms.append(Squeeze2d())
            current_shape = (current_shape[0] * 4, current_shape[1] // 2,
                             current_shape[2] // 2)

        # add in augmentation channels if desired
        if augment_size is not None and augment_size > 0:
            augment_flow = AugmentFlow(data_shape=current_shape,
                                       augment_size=augment_size,
                                       num_steps=augment_steps,
                                       coupling_network=coupling_network,
                                       mid_channels=coupling_channels,
                                       num_context=augment_context,
                                       num_mixtures=coupling_mixtures,
                                       num_blocks=augment_blocks,
                                       dropout=0.0,
                                       checkerboard=True,
                                       tuple_flip=tuple_flip)
            transforms.append(
                Augment(encoder=augment_flow, x_size=current_shape[0]))
            current_shape = (current_shape[0] + augment_size, current_shape[1],
                             current_shape[2])

        for scale in range(num_scales):
            checkerboard = scale in checkerboard_scales

            for step in range(num_steps):
                flip = (step % 2 == 0) if tuple_flip else False

                if actnorm:
                    transforms.append(ActNormBijection2d(current_shape[0]))
                transforms.append(Conv1x1(current_shape[0]))

                if coupling_network == "conv":
                    transforms.append(
                        Coupling(in_channels=current_shape[0],
                                 num_blocks=coupling_blocks,
                                 mid_channels=coupling_channels,
                                 depth=coupling_depth,
                                 dropout=coupling_dropout,
                                 gated_conv=coupling_gated_conv,
                                 coupling_network=coupling_network,
                                 checkerboard=checkerboard,
                                 flip=flip))
                else:
                    transforms.append(
                        MixtureCoupling(in_channels=current_shape[0],
                                        mid_channels=coupling_channels,
                                        num_mixtures=coupling_mixtures,
                                        num_blocks=coupling_blocks,
                                        dropout=coupling_dropout,
                                        checkerboard=checkerboard,
                                        flip=flip))

            if scale < num_scales - 1:
                if pooling in ['bijective', 'none'
                               ] or compression_ratio[scale] == 0.0:
                    transforms.append(Squeeze2d())
                    current_shape = (current_shape[0] * 4,
                                     current_shape[1] // 2,
                                     current_shape[2] // 2)
                elif pooling == 'slice':
                    # slice some of the dimensions (channel-wise) out from further flow steps
                    unsliced_channels = int(
                        max(1, 4 * current_shape[0] *
                            (1.0 - sliced_ratio[scale])))
                    sliced_channels = int(4 * current_shape[0] -
                                          unsliced_channels)
                    noise_shape = (sliced_channels, current_shape[1] // 2,
                                   current_shape[2] // 2)
                    transforms.append(Squeeze2d())
                    transforms.append(
                        Slice(StandardNormal(noise_shape),
                              num_keep=unsliced_channels,
                              dim=1))
                    current_shape = (unsliced_channels, current_shape[1] // 2,
                                     current_shape[2] // 2)
                elif pooling == 'max':
                    noise_shape = (current_shape[0] * 3, current_shape[1] // 2,
                                   current_shape[2] // 2)
                    decoder = StandardHalfNormal(noise_shape)
                    transforms.append(
                        SimpleMaxPoolSurjection2d(decoder=decoder))
                    current_shape = (current_shape[0], current_shape[1] // 2,
                                     current_shape[2] // 2)
                else:
                    raise ValueError(
                        f"Pooling argument must be either slice, max or none, not: {pooling}"
                    )

            else:
                if actnorm:
                    transforms.append(ActNormBijection2d(current_shape[0]))

        # for reference save the shape output by the bijective flow
        self.flow_shape = current_shape
        self.latent_size = current_shape[0] * current_shape[1] * current_shape[
            2]

        super(PoolFlow, self).__init__(base_dist=ConvNormal2d(current_shape),
                                       transforms=transforms)