Example 1
    def test_distribution_is_well_behaved(self):
        batch_size = 10
        shape = [2, 3, 4]
        x = torch.randn(batch_size, *shape)
        distribution = StandardUniform(shape)

        self.assert_distribution_is_well_behaved(distribution,
                                                 x,
                                                 expected_shape=(batch_size,
                                                                 *shape))
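
The body of assert_distribution_is_well_behaved is not shown here; below is a minimal sketch of what such a helper typically checks, assuming the survae-style Distribution API (log_prob returning one value per batch element, sample taking a sample count):

    def assert_distribution_is_well_behaved(self, distribution, x, expected_shape):
        # log_prob yields one finite log-density per batch element
        log_prob = distribution.log_prob(x)
        self.assertEqual(log_prob.shape, (x.shape[0],))
        self.assertTrue(torch.isfinite(log_prob).all())
        # sampling returns a tensor of the requested shape
        samples = distribution.sample(x.shape[0])
        self.assertEqual(samples.shape, expected_shape)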
Example 2
def net(channels):
    return nn.Sequential(
        DenseNet(in_channels=channels // 2,
                 out_channels=channels,
                 num_blocks=1,
                 mid_channels=64,
                 depth=8,
                 growth=16,
                 dropout=0.0,
                 gated_conv=True,
                 zero_init=True), ElementwiseParams2d(2))


model = Flow(base_dist=StandardNormal((24, 8, 8)),
             transforms=[
                 UniformDequantization(num_bits=8),
                 Augment(StandardUniform((3, 32, 32)), x_size=3),
                 AffineCouplingBijection(net(6)),
                 ActNormBijection2d(6),
                 Conv1x1(6),
                 AffineCouplingBijection(net(6)),
                 ActNormBijection2d(6),
                 Conv1x1(6),
                 AffineCouplingBijection(net(6)),
                 ActNormBijection2d(6),
                 Conv1x1(6),
                 AffineCouplingBijection(net(6)),
                 ActNormBijection2d(6),
                 Conv1x1(6),
                 Squeeze2d(),
                 Slice(StandardNormal((12, 16, 16)), num_keep=12),
                 AffineCouplingBijection(net(12)),
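
Once the transform list is complete, training and sampling follow the usual survae Flow pattern; a hedged sketch, where the 8-bit image batch below is an illustrative assumption:

x = torch.randint(0, 256, (16, 3, 32, 32)).float()  # fake batch of 8-bit images
loss = -model.log_prob(x).mean()                    # negative log-likelihood in nats
samples = model.sample(16)                          # sample the base dist, invert the transforms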
Example 3
    def __init__(self, data_shape, num_bits,
                 base_distribution, num_scales, num_steps, actnorm, 
                 vae_hidden_units,
                 coupling_network,
                 dequant, dequant_steps, dequant_context,
                 coupling_blocks, coupling_channels, coupling_dropout,
                 coupling_gated_conv=None, coupling_depth=None, coupling_mixtures=None):

        assert len(base_distribution) == 1, "Only a single base distribution is supported"
        transforms = []
        current_shape = data_shape
        if num_steps == 0: num_scales = 0
        
        if dequant == 'uniform' or num_steps == 0 or num_scales == 0:
            # with no bijective flows, default to plain uniform dequantization
            transforms.append(UniformDequantization(num_bits=num_bits))
        elif dequant == 'flow':            
            dequantize_flow = DequantizationFlow(data_shape=data_shape,
                                                 num_bits=num_bits,
                                                 num_steps=dequant_steps,
                                                 coupling_network=coupling_network,
                                                 num_context=dequant_context,
                                                 num_blocks=coupling_blocks,
                                                 mid_channels=coupling_channels,
                                                 depth=coupling_depth,
                                                 dropout=coupling_dropout,
                                                 gated_conv=coupling_gated_conv,
                                                 num_mixtures=coupling_mixtures)
            transforms.append(VariationalDequantization(encoder=dequantize_flow, num_bits=num_bits))

        # Change range from [0,1]^D to [-0.5, 0.5]^D
        transforms.append(ScalarAffineBijection(shift=-0.5))

        for scale in range(num_scales):

            # squeeze: trade spatial resolution (height and width) for more channels
            transforms.append(Squeeze2d())
            current_shape = (current_shape[0] * 4,
                             current_shape[1] // 2,
                             current_shape[2] // 2)

            # Dimension preserving components
            for step in range(num_steps):
                if actnorm: transforms.append(ActNormBijection2d(current_shape[0]))
                transforms.append(Conv1x1(current_shape[0]))
                if coupling_network == "conv":
                    transforms.append(
                        Coupling(in_channels=current_shape[0],
                                 num_blocks=coupling_blocks,
                                 mid_channels=coupling_channels,
                                 depth=coupling_depth,
                                 dropout=coupling_dropout,
                                 gated_conv=coupling_gated_conv,
                                 coupling_network=coupling_network))
                else:
                    transforms.append(
                        MixtureCoupling(in_channels=current_shape[0],
                                        mid_channels=coupling_channels,
                                        num_mixtures=coupling_mixtures,
                                        num_blocks=coupling_blocks,
                                        dropout=coupling_dropout))
 
            # Non-dimension-preserving flows: halve the dimensionality of the data (channel-wise)
            if actnorm: transforms.append(ActNormBijection2d(current_shape[0]))
            assert current_shape[0] % 2 == 0, f"Channel dimension {current_shape[0]} must be divisible by two"
            latent_size = (current_shape[0] * current_shape[1] * current_shape[2]) // 2
            
            encoder = ConditionalNormal(
                ConvEncoderNet(in_channels=current_shape[0],
                               out_channels=latent_size,
                               mid_channels=vae_hidden_units,
                               max_pool=True, batch_norm=True),
                split_dim=1)
            decoder = ConditionalNormal(
                ConvDecoderNet(in_channels=latent_size,
                               out_shape=(current_shape[0] * 2, current_shape[1], current_shape[2]),
                               mid_channels=list(reversed(vae_hidden_units)),
                               batch_norm=True,
                               in_lambda=lambda x: x.view(x.shape[0], x.shape[1], 1, 1)),
                split_dim=1)
            
            transforms.append(VAE(encoder=encoder, decoder=decoder))
            current_shape = (current_shape[0] // 2,
                             current_shape[1],
                             current_shape[2])

            if scale < num_scales - 1:
                # reshape latent sample to have height and width
                transforms.append(Reshape(input_shape=(latent_size,), output_shape=current_shape))
            
        # Base distribution for dimension preserving portion of flow
        if base_distribution == "n":
            base_dist = StandardNormal((latent_size,))
        elif base_distribution == "c":
            base_dist = ConvNormal2d((latent_size,))
        elif base_distribution == "u":
            base_dist = StandardUniform((latent_size,))
        else:
            raise ValueError("Base distribution must be one of n=Normal, u=Uniform, or c=ConvNormal")

        # for reference save the shape output by the bijective flow
        self.latent_size = latent_size
        self.flow_shape = current_shape

        super(MultilevelCompressiveFlow, self).__init__(base_dist=[None, base_dist], transforms=transforms)
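
A quick shape trace for a single scale makes the bookkeeping concrete; the starting shape (3, 32, 32) is a hypothetical example, not taken from the source:

shape = (3, 32, 32)                                   # hypothetical data_shape
shape = (shape[0] * 4, shape[1] // 2, shape[2] // 2)  # Squeeze2d: -> (12, 16, 16)
latent_size = (shape[0] * shape[1] * shape[2]) // 2   # VAE latent: 3072 // 2 = 1536
shape = (shape[0] // 2, shape[1], shape[2])           # next scale sees (6, 16, 16)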
Example 4
##################
## Specify data ##
##################

train_loader, test_loader = get_data(args)

###################
## Specify model ##
###################

classifier = MLP(2, 1,
                 hidden_units=args.hidden_units,
                 activation=args.activation,
                 out_lambda=lambda x: x.view(-1))

model = Flow(base_dist=StandardUniform((2,)),
             transforms=[
                 ElementAbsSurjection(classifier=classifier),
                 ShiftBijection(shift=torch.tensor([[0.0, 4.0]])),
                 ScaleBijection(scale=torch.tensor([[1/4, 1/8]]))
             ]).to(args.device)

#######################
## Specify optimizer ##
#######################

if args.optimizer == 'adam':
    optimizer = Adam(model.parameters(), lr=args.lr)
elif args.optimizer == 'adamax':
    optimizer = Adamax(model.parameters(), lr=args.lr)
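
A minimal training step to round out the script; this sketch assumes train_loader yields plain 2-D coordinate batches:

for x in train_loader:
    optimizer.zero_grad()
    loss = -model.log_prob(x.to(args.device)).mean()  # negative log-likelihood
    loss.backward()
    optimizer.step()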
Example 5
    def __init__(self,
                 data_shape,
                 num_bits,
                 base_distributions,
                 num_scales,
                 num_steps,
                 actnorm,
                 vae_hidden_units,
                 latent_size,
                 vae_activation,
                 coupling_network,
                 dequant,
                 dequant_steps,
                 dequant_context,
                 coupling_blocks,
                 coupling_channels,
                 coupling_dropout,
                 coupling_growth=None,
                 coupling_gated_conv=None,
                 coupling_depth=None,
                 coupling_mixtures=None):

        transforms = []
        current_shape = data_shape
        if num_steps == 0: num_scales = 0

        if dequant == 'uniform' or num_steps == 0 or num_scales == 0:
            # with no bijective flows, default to plain uniform dequantization
            transforms.append(UniformDequantization(num_bits=num_bits))
        elif dequant == 'flow':
            dequantize_flow = DequantizationFlow(
                data_shape=data_shape,
                num_bits=num_bits,
                num_steps=dequant_steps,
                coupling_network=coupling_network,
                num_context=dequant_context,
                num_blocks=coupling_blocks,
                mid_channels=coupling_channels,
                depth=coupling_depth,
                growth=coupling_growth,
                dropout=coupling_dropout,
                gated_conv=coupling_gated_conv,
                num_mixtures=coupling_mixtures)
            transforms.append(
                VariationalDequantization(encoder=dequantize_flow,
                                          num_bits=num_bits))

        # Change range from [0,1]^D to [-0.5, 0.5]^D
        transforms.append(ScalarAffineBijection(shift=-0.5))

        # Initial squeeze
        transforms.append(Squeeze2d())
        current_shape = (current_shape[0] * 4, current_shape[1] // 2,
                         current_shape[2] // 2)

        # Dimension preserving flows
        for scale in range(num_scales):
            for step in range(num_steps):
                if actnorm:
                    transforms.append(ActNormBijection2d(current_shape[0]))
                transforms.append(Conv1x1(current_shape[0]))
                if coupling_network in ["conv", "densenet"]:
                    transforms.append(
                        Coupling(in_channels=current_shape[0],
                                 num_blocks=coupling_blocks,
                                 mid_channels=coupling_channels,
                                 depth=coupling_depth,
                                 growth=coupling_growth,
                                 dropout=coupling_dropout,
                                 gated_conv=coupling_gated_conv,
                                 coupling_network=coupling_network))
                else:
                    transforms.append(
                        MixtureCoupling(in_channels=current_shape[0],
                                        mid_channels=coupling_channels,
                                        num_mixtures=coupling_mixtures,
                                        num_blocks=coupling_blocks,
                                        dropout=coupling_dropout))

            if scale < num_scales - 1:
                transforms.append(Squeeze2d())
                current_shape = (current_shape[0] * 4, current_shape[1] // 2,
                                 current_shape[2] // 2)
            else:
                if actnorm:
                    transforms.append(ActNormBijection2d(current_shape[0]))

        # Base distribution for dimension preserving portion of flow
        if len(base_distributions) > 1:
            if base_distributions[0] == "n":
                base0 = StandardNormal(current_shape)
            elif base_distributions[0] == "c":
                base0 = ConvNormal2d(current_shape)
            elif base_distributions[0] == "u":
                base0 = StandardUniform(current_shape)
            else:
                raise ValueError(
                    "Base distribution must be one of n=Noraml, u=Uniform, or c=ConvNormal"
                )
        else:
            base0 = None

        # for reference save the shape output by the bijective flow
        self.flow_shape = current_shape

        # Non-dimension preserving flows
        flat_dim = current_shape[0] * current_shape[1] * current_shape[2]
        encoder = ConditionalNormal(
            MLP(flat_dim,
                2 * latent_size,
                hidden_units=vae_hidden_units,
                activation=vae_activation,
                in_lambda=lambda x: x.view(x.shape[0], flat_dim)))
        decoder = ConditionalNormal(
            MLP(latent_size,
                2 * flat_dim,
                hidden_units=list(reversed(vae_hidden_units)),
                activation=vae_activation,
                out_lambda=lambda x: x.view(x.shape[0], current_shape[0] * 2,
                                            current_shape[1], current_shape[2])),
            split_dim=1)

        transforms.append(VAE(encoder=encoder, decoder=decoder))

        # Base distribution for non-dimension preserving portion of flow
        #self.latent_size = latent_size
        if base_distributions[-1] == "n":
            base1 = StandardNormal((latent_size, ))
        elif base_distributions[-1] == "c":
            base1 = ConvNormal2d((latent_size, ))
        elif base_distributions[-1] == "u":
            base1 = StandardUniform((latent_size, ))
        else:
            raise ValueError(
                "Base distribution must be one of n=Noraml, u=Uniform, or c=ConvNormal"
            )

        super(VAECompressiveFlow, self).__init__(base_dist=[base0, base1],
                                                 transforms=transforms)
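
Two base distributions are passed because the VAE surjection changes dimensionality: base0 covers the bijective part of the flow and base1 the compressed latent. A hedged evaluation sketch, assuming model is an instance of this class built for hypothetical (3, 32, 32) data:

import math

x = torch.randint(0, 256, (8, 3, 32, 32)).float()       # hypothetical 8-bit batch
bpd = -model.log_prob(x) / (math.log(2) * 3 * 32 * 32)  # convert nats to bits per dim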
Example 6
    def __init__(self, num_flows, actnorm, affine, scale_fn_str, hidden_units,
                 activation, range_flow, augment_size, base_dist):

        D = 2  # Number of data dimensions

        if base_dist == "uniform":
            classifier = MLP(D,
                             D // 2,
                             hidden_units=hidden_units,
                             activation=activation,
                             out_lambda=lambda x: x.view(-1))
            transforms = [
                ElementAbsSurjection(classifier=classifier),
                ShiftBijection(shift=torch.tensor([[0.0, 4.0]])),
                ScaleBijection(scale=torch.tensor([[1 / 4, 1 / 8]]))
            ]
            base = StandardUniform((D, ))

        else:
            A = D + augment_size  # Number of augmented data dimensions
            P = 2 if affine else 1  # Number of elementwise parameters

            # initialize flow with either augmentation or Abs surjection
            if augment_size > 0:
                assert augment_size % 2 == 0
                transforms = [
                    Augment(StandardNormal((augment_size, )), x_size=D)
                ]

            else:
                transforms = [SimpleAbsSurjection()]
                if range_flow == 'logit':
                    transforms += [
                        ScaleBijection(scale=torch.tensor([[1 / 4, 1 / 4]])),
                        Logit()
                    ]
                elif range_flow == 'softplus':
                    transforms += [SoftplusInverse()]

            # apply coupling layer flows
            for _ in range(num_flows):
                net = nn.Sequential(
                    MLP(A // 2,
                        P * A // 2,
                        hidden_units=hidden_units,
                        activation=activation), ElementwiseParams(P))

                if affine:
                    transforms.append(
                        AffineCouplingBijection(
                            net, scale_fn=scale_fn(scale_fn_str)))
                else:
                    transforms.append(AdditiveCouplingBijection(net))
                # actnorm must match the (possibly augmented) dimensionality A
                if actnorm: transforms.append(ActNormBijection(A))
                transforms.append(Reverse(A))

            # drop the trailing Reverse so the flow does not end on a permutation
            transforms.pop()
            base = StandardNormal((A, ))

        super(UnconditionalFlow, self).__init__(base_dist=base,
                                                transforms=transforms)
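
A hypothetical instantiation, with every argument value chosen purely for illustration:

model = UnconditionalFlow(num_flows=4, actnorm=True, affine=True,
                          scale_fn_str='exp', hidden_units=[64, 64],
                          activation='relu', range_flow='logit',
                          augment_size=2, base_dist='normal')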
Example 7
def reduction_layer(channels, items):
    return [
        *perm_norm_bi(channels),
        *perm_norm_bi(channels),
        *perm_norm_bi(channels),
        Squeeze2d(4),
        Slice(StandardNormal((channels * 2, items)), num_keep=channels * 2),
    ]
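
Each call below hands reduction_layer a doubled channel count and a quartered item count, and the Slice keeps 2*c channels while routing the rest to a StandardNormal prior; a short consistency check of that bookkeeping (assuming base_channels and n_items are defined earlier in the script, as the calls imply):

c, items = base_channels * (2**1), n_items // (4**1)  # arguments to the first call
kept = 2 * c                                          # Slice keeps num_keep = 2*c channels
assert kept == base_channels * (2**2)                 # matches the next call's channel count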


model = Flow(
    base_dist=StandardNormal((base_channels * (2**5), n_items // (4**4))),
    transforms=[
        UniformDequantization(num_bits=8),
        Augment(StandardUniform((base_channels * 1, n_items)),
                x_size=base_channels),
        *reduction_layer(base_channels * (2**1), n_items // (4**1)),
        *reduction_layer(base_channels * (2**2), n_items // (4**2)),
        *reduction_layer(base_channels * (2**3), n_items // (4**3)),
        *reduction_layer(base_channels * (2**4), n_items // (4**4)),
        # *reduction_layer(base_channels*(2**5), n_items//(4**4)),
        *perm_norm_bi(base_channels * (2**5))

        # AffineCouplingBijection(net(base_channels*2)), ActNormBijection2d(base_channels*2), Conv1x1(base_channels*2),
        # AffineCouplingBijection(net(base_channels*2)), ActNormBijection2d(base_channels*2), Conv1x1(base_channels*2),
        # AffineCouplingBijection(net(base_channels*2)), ActNormBijection2d(base_channels*2), Conv1x1(base_channels*2),
        # Squeeze2d(), Slice(StandardNormal((base_channels*2, n_items//4)), num_keep=base_channels*2),
        # AffineCouplingBijection(net(base_channels*2)), ActNormBijection2d(base_channels*2), Conv1x1(base_channels*2),
        # AffineCouplingBijection(net(base_channels*2)), ActNormBijection2d(base_channels*2), Conv1x1(base_channels*2),
        # AffineCouplingBijection(net(base_channels*2)), ActNormBijection2d(base_channels*2), Conv1x1(base_channels*2),