Beispiel #1
0
    def __init__(self,
                 in_channels,
                 mid_channels,
                 num_mixtures,
                 num_blocks,
                 dropout,
                 checkerboard=False,
                 flip=False):
        """Build a logistic-mixture coupling layer.

        With checkerboard masking the split is along dim 3 (presumably the
        spatial width axis — confirm against the base class) and the full
        channel count feeds the transformer; otherwise the channel dim is
        split in half.
        """
        num_in = in_channels if checkerboard else in_channels // 2
        split_dim = 3 if checkerboard else 1

        coupling_net = nn.Sequential(
            TransformerNet(in_channels=num_in,
                           mid_channels=mid_channels,
                           num_blocks=num_blocks,
                           num_mixtures=num_mixtures,
                           dropout=dropout),
            # 2 affine params + 3 per mixture component, emitted sequentially.
            ElementwiseParams2d(2 + num_mixtures * 3, mode='sequential'))

        super(MixtureCoupling, self).__init__(coupling_net=coupling_net,
                                              num_mixtures=num_mixtures,
                                              scale_fn=scale_fn("tanh_exp"),
                                              split_dim=split_dim,
                                              flip=flip)
Beispiel #2
0
    def __init__(self,
                 x_size,
                 y_size,
                 mid_channels,
                 num_blocks,
                 num_mixtures,
                 dropout,
                 checkerboard=False,
                 flip=False):
        """Super-resolution mixture coupling over y, conditioned on context x.

        `x_size` / `y_size` are (channels, height, width) shapes; the asserts
        below pin the spatial relationship the layer expects.
        """
        if checkerboard:
            # Split along dim 3: context covers half the width of y.
            in_channels, split_dim = y_size[0], 3
            assert x_size[1] == y_size[1] and x_size[2] == y_size[2] // 2
        else:
            # Split along channels: context matches y spatially, and y needs
            # an even channel count so it halves cleanly.
            in_channels, split_dim = y_size[0] // 2, 1
            assert x_size[1] == y_size[1] and x_size[2] == y_size[2]
            assert y_size[0] % 2 == 0, \
                f"High-resolution has shape {y_size} with channels not evenly divisible"

        coupling_net = nn.Sequential(
            TransformerNet(in_channels=in_channels,
                           context_channels=x_size[0],
                           mid_channels=mid_channels,
                           num_blocks=num_blocks,
                           num_mixtures=num_mixtures,
                           dropout=dropout),
            # 2 affine params + 3 per mixture component, emitted sequentially.
            ElementwiseParams2d(2 + num_mixtures * 3, mode='sequential'))

        super(SRMixtureCoupling, self).__init__(coupling_net=coupling_net,
                                                num_mixtures=num_mixtures,
                                                scale_fn=scale_fn("tanh_exp"),
                                                split_dim=split_dim,
                                                flip=flip)
Beispiel #3
0
    def __init__(self, in_channels, num_context, num_blocks, mid_channels,
                 depth, growth, dropout, gated_conv, coupling_network):
        """Conditional affine coupling whose parameter net sees half the
        input channels plus `num_context` conditioning channels.

        Raises:
            ValueError: if `coupling_network` is neither "densenet" nor "conv".
        """
        assert in_channels % 2 == 0

        # Choose the backbone, then wrap it with the elementwise-param head.
        if coupling_network == "densenet":
            backbone = DenseNet(in_channels=in_channels // 2 + num_context,
                                out_channels=in_channels,
                                num_blocks=num_blocks,
                                mid_channels=mid_channels,
                                depth=depth,
                                growth=growth,
                                dropout=dropout,
                                gated_conv=gated_conv,
                                zero_init=True)
        elif coupling_network == "conv":
            backbone = ConvNet(in_channels=in_channels // 2 + num_context,
                               out_channels=in_channels,
                               mid_channels=mid_channels,
                               num_layers=depth,
                               activation='relu')
        else:
            raise ValueError(f"Unknown coupling network {coupling_network}")

        net = nn.Sequential(backbone,
                            ElementwiseParams2d(2, mode='sequential'))
        super(ConditionalCoupling,
              self).__init__(coupling_net=net, scale_fn=scale_fn("tanh_exp"))
Beispiel #4
0
    def __init__(self,
                 x_size,
                 y_size,
                 coupling_network,
                 mid_channels,
                 depth,
                 num_blocks=None,
                 dropout=None,
                 gated_conv=None,
                 checkerboard=False,
                 flip=False):
        """Super-resolution coupling over y with context x concatenated into
        the coupling net's input channels.

        Raises:
            ValueError: if `coupling_network` is neither "densenet" nor "conv".
        """
        if checkerboard:
            # Split along dim 3: the net sees all y channels plus the context,
            # and emits 2 params per y channel.
            in_channels = y_size[0] + x_size[0]
            out_channels = y_size[0] * 2
            split_dim = 3
            assert x_size[1] == y_size[1] and x_size[2] == y_size[2] // 2
        else:
            # Split along channels: half of y plus the context goes in.
            in_channels = y_size[0] // 2 + x_size[0]
            out_channels = y_size[0]
            split_dim = 1
            assert x_size[1] == y_size[1] and x_size[2] == y_size[2]
            assert y_size[0] % 2 == 0, \
                f"High-resolution has shape {y_size} with channels not evenly divisible"

        # Pick the backbone, then attach the elementwise-param head.
        if coupling_network == "densenet":
            backbone = DenseNet(in_channels=in_channels,
                                out_channels=out_channels,
                                num_blocks=num_blocks,
                                mid_channels=mid_channels,
                                depth=depth,
                                growth=mid_channels,
                                dropout=dropout,
                                gated_conv=gated_conv,
                                zero_init=True)
        elif coupling_network == "conv":
            backbone = ConvNet(in_channels=in_channels,
                               out_channels=out_channels,
                               mid_channels=mid_channels,
                               num_layers=depth,
                               weight_norm=True,
                               activation='relu')
        else:
            raise ValueError(f"Unknown coupling network {coupling_network}")

        coupling_net = nn.Sequential(backbone,
                                     ElementwiseParams2d(2, mode='sequential'))
        super(SRCoupling, self).__init__(coupling_net=coupling_net,
                                         scale_fn=scale_fn("tanh_exp"),
                                         split_dim=split_dim,
                                         flip=flip)
Beispiel #5
0
    def __init__(self, num_flows, actnorm, affine, scale_fn_str, hidden_units,
                 activation, range_flow, augment_size, base_dist, cond_size):
        """Conditional flow over 2-D data.

        Args:
            num_flows: number of coupling steps.
            actnorm: if True, insert an ActNorm layer after each coupling.
            affine: use affine couplings (2 params/element) vs additive (1).
            scale_fn_str: name of the scale function for affine couplings.
            hidden_units: hidden sizes for every MLP.
            activation: activation name for every MLP.
            range_flow: 'logit' or 'softplus' range transform; only used on
                the Abs-surjection branch (augment_size == 0).
            augment_size: extra augmentation dims (must be even if > 0).
            base_dist: "uniform" for a StandardUniform base, anything else
                gives StandardNormal.
            cond_size: dimensionality of the conditioning input.
        """

        D = 2  # Number of data dimensions
        A = D + augment_size  # Number of augmented data dimensions
        P = 2 if affine else 1  # Number of elementwise parameters

        # initialize context. Only upsample context in ContextInit if latent shape doesn't change during the flow.
        context_init = MLP(input_size=cond_size,
                           output_size=D,
                           hidden_units=hidden_units,
                           activation=activation)

        # initialize flow with either augmentation or Abs surjection
        if augment_size > 0:
            assert augment_size % 2 == 0
            transforms = [Augment(StandardNormal((augment_size, )), x_size=D)]

        else:
            # Fix: removed a dead `transforms = []` assignment that was
            # immediately overwritten by the line below.
            transforms = [SimpleAbsSurjection()]
            if range_flow == 'logit':
                transforms += [
                    ScaleBijection(scale=torch.tensor([[1 / 4, 1 / 4]])),
                    Logit()
                ]
            elif range_flow == 'softplus':
                transforms += [SoftplusInverse()]

        # apply coupling layer flows
        for _ in range(num_flows):
            # Net sees half the A flow dims plus the D context dims and emits
            # P parameters for each of the remaining A // 2 dims.
            net = nn.Sequential(
                MLP(A // 2 + D,
                    P * A // 2,
                    hidden_units=hidden_units,
                    activation=activation), ElementwiseParams(P))
            if affine:
                transforms.append(
                    ConditionalAffineCouplingBijection(
                        net, scale_fn=scale_fn(scale_fn_str)))
            else:
                transforms.append(ConditionalAdditiveCouplingBijection(net))
            # NOTE(review): ActNorm is sized with D while the flow state has
            # A dims — confirm this is intended.
            if actnorm: transforms.append(ActNormBijection(D))
            transforms.append(Reverse(A))

        # Drop the trailing Reverse so the flow does not end on a permutation.
        transforms.pop()

        if base_dist == "uniform":
            base = StandardUniform((A, ))
        else:
            base = StandardNormal((A, ))

        super(SRFlow, self).__init__(base_dist=base,
                                     transforms=transforms,
                                     context_init=context_init)
Beispiel #6
0
    def __init__(self,
                 in_channels,
                 num_context,
                 num_blocks,
                 mid_channels,
                 depth,
                 dropout,
                 gated_conv,
                 coupling_network,
                 checkerboard=False,
                 flip=False):
        """Conditional coupling with optional checkerboard masking.

        Raises:
            ValueError: if `coupling_network` is neither "densenet" nor "conv".
        """
        if checkerboard:
            # Split along dim 3: all channels plus context go in, 2 params
            # per channel come out.
            num_in, num_out, split_dim = in_channels + num_context, in_channels * 2, 3
        else:
            # Split along channels: half the channels plus context go in.
            num_in, num_out, split_dim = in_channels // 2 + num_context, in_channels, 1

        # Channel splits require an even channel count.
        assert in_channels % 2 == 0 or split_dim != 1, f"in_channels = {in_channels} not evenly divisible"

        # Choose the backbone, then attach the elementwise-param head.
        if coupling_network == "densenet":
            backbone = DenseNet(in_channels=num_in,
                                out_channels=num_out,
                                num_blocks=num_blocks,
                                mid_channels=mid_channels,
                                depth=depth,
                                growth=mid_channels,
                                dropout=dropout,
                                gated_conv=gated_conv,
                                zero_init=True)
        elif coupling_network == "conv":
            backbone = ConvNet(in_channels=num_in,
                               out_channels=num_out,
                               mid_channels=mid_channels,
                               num_layers=depth,
                               activation='relu')
        else:
            raise ValueError(f"Unknown coupling network {coupling_network}")

        net = nn.Sequential(backbone,
                            ElementwiseParams2d(2, mode='sequential'))
        super(ConditionalCoupling,
              self).__init__(coupling_net=net,
                             scale_fn=scale_fn("tanh_exp"),
                             split_dim=split_dim,
                             flip=flip)
Beispiel #7
0
    def __init__(self, in_channels, mid_channels, num_mixtures, num_blocks,
                 dropout):
        """Mixture coupling that splits the channel dim in half."""
        # Half the channels condition; 2 affine params + 3 params per mixture
        # component are emitted for the other half.
        coupling_net = nn.Sequential(
            TransformerNet(in_channels // 2,
                           mid_channels=mid_channels,
                           num_blocks=num_blocks,
                           num_mixtures=num_mixtures,
                           dropout=dropout),
            ElementwiseParams2d(2 + num_mixtures * 3, mode='sequential'))

        super(MixtureCoupling, self).__init__(coupling_net=coupling_net,
                                              num_mixtures=num_mixtures,
                                              scale_fn=scale_fn("tanh_exp"))
    def test_bijection_is_well_behaved(self):
        """Check ConditionalAffineCouplingBijection over every scale function,
        a 1-D and a 3-D input shape, and default vs explicit num_condition.

        For each combination a matching parameter net is built (Linear for
        vectors, Conv2d for images), the generic well-behavedness check from
        the base class is run, and the pass-through (conditioning) slice of
        the output is verified to equal the input.
        """
        batch_size = 10

        # Loosened tolerance for the float comparisons in the base-class check.
        self.eps = 5e-6
        for scale_str in ['exp', 'softplus', 'sigmoid', 'tanh_exp']:
            for shape in [(6, ), (6, 8, 8)]:
                for num_condition in [None, 1]:
                    with self.subTest(shape=shape,
                                      num_condition=num_condition,
                                      scale_str=scale_str):
                        x = torch.randn(batch_size, *shape)
                        # Context has the same shape as x (6 feature dims).
                        context = torch.randn(batch_size, *shape)
                        if num_condition is None:
                            # Default even split: 3 conditioning dims + 6
                            # context dims in, 2 params for each of the 3
                            # transformed dims out.
                            if len(shape) == 1:
                                net = nn.Sequential(nn.Linear(3 + 6, 3 * 2),
                                                    ElementwiseParams(2))
                            if len(shape) == 3:
                                net = nn.Sequential(
                                    nn.Conv2d(3 + 6,
                                              3 * 2,
                                              kernel_size=3,
                                              padding=1),
                                    ElementwiseParams2d(2))
                        else:
                            # num_condition=1: one conditioning dim + 6
                            # context dims in, 2 params for each of the 5
                            # transformed dims out.
                            if len(shape) == 1:
                                net = nn.Sequential(nn.Linear(1 + 6, 5 * 2),
                                                    ElementwiseParams(2))
                            if len(shape) == 3:
                                net = nn.Sequential(
                                    nn.Conv2d(1 + 6,
                                              5 * 2,
                                              kernel_size=3,
                                              padding=1),
                                    ElementwiseParams2d(2))
                        bijection = ConditionalAffineCouplingBijection(
                            net,
                            num_condition=num_condition,
                            scale_fn=scale_fn(scale_str))
                        # Generic invertibility / shape / Jacobian checks
                        # provided by the test base class.
                        self.assert_bijection_is_well_behaved(
                            bijection,
                            x,
                            context,
                            z_shape=(batch_size, *shape))
                        z, _ = bijection.forward(x, context)
                        # The conditioning slice must pass through unchanged.
                        if num_condition is None:
                            self.assertEqual(x[:, :3], z[:, :3])
                        else:
                            self.assertEqual(x[:, :1], z[:, :1])
def dimwise(transforms):
    """Append one dimension-wise coupling step (plus a Reverse) to
    `transforms` and return it.

    Relies on module-level `D`, `P` and `args` for sizing and configuration.
    """
    coupling = nn.Sequential(
        DenseTransformer(d_input=D // 2,
                         d_output=P * D // 2,
                         d_model=args.d_model,
                         nhead=args.nhead,
                         num_layers=args.num_layers,
                         dim_feedforward=4 * args.d_model,
                         dropout=args.dropout,
                         activation=args.activation,
                         checkpoint_blocks=args.checkpoint_blocks),
        ElementwiseParams1d(P))
    if args.affine:
        bijection = AffineCouplingBijection(coupling,
                                            split_dim=1,
                                            scale_fn=scale_fn(args.scale_fn))
    else:
        bijection = AdditiveCouplingBijection(coupling, split_dim=1)
    transforms.append(bijection)
    transforms.append(Reverse(D, dim=1))
    return transforms
Beispiel #10
0
    def __init__(self,
                 in_channels,
                 num_context,
                 mid_channels,
                 num_mixtures,
                 num_blocks,
                 dropout,
                 use_attn=True):
        """Context-conditioned mixture coupling over half the channels."""
        # 2 affine params + 3 per mixture component, emitted sequentially.
        net = nn.Sequential(
            TransformerNet(in_channels // 2,
                           context_channels=num_context,
                           mid_channels=mid_channels,
                           num_blocks=num_blocks,
                           num_mixtures=num_mixtures,
                           use_attn=use_attn,
                           dropout=dropout),
            ElementwiseParams2d(2 + num_mixtures * 3, mode='sequential'))

        super(ConditionalMixtureCoupling,
              self).__init__(coupling_net=net,
                             num_mixtures=num_mixtures,
                             scale_fn=scale_fn("tanh_exp"))
def lenwise(transforms):
    """Append one length-wise coupling step (split_dim=2) plus a permutation
    to `transforms` and return it.

    Relies on module-level `D`, `P`, `L` and `args`.
    """
    coupling = nn.Sequential(
        DenseTransformer(d_input=D,
                         d_output=P * D,
                         d_model=args.d_model,
                         nhead=args.nhead,
                         num_layers=args.num_layers,
                         dim_feedforward=4 * args.d_model,
                         dropout=args.dropout,
                         activation=args.activation,
                         checkpoint_blocks=args.checkpoint_blocks),
        ElementwiseParams1d(P))
    if args.affine:
        bijection = AffineCouplingBijection(coupling,
                                            split_dim=2,
                                            scale_fn=scale_fn(args.scale_fn))
    else:
        bijection = AdditiveCouplingBijection(coupling, split_dim=2)
    transforms.append(bijection)
    # Permute along the length axis: stochastic or a fixed shuffle.
    if args.stochperm:
        transforms.append(StochasticPermutation(dim=2))
    else:
        transforms.append(Shuffle(L, dim=2))
    return transforms
Beispiel #12
0
# Coupling split sizes: I dims condition, O dims get transformed
# (O picks up the odd dimension when D is odd).
I = D // 2
O = D // 2 + D % 2

# Decoder
# C (context size) and P (params per element) are defined elsewhere in this
# file.
if args.num_bits is not None:
    # NOTE(review): the flow starts with Logit() but the trailing
    # transforms.pop() removes whatever was appended last in the loop
    # (normally the final permutation) — confirm intended when
    # args.permutation is neither 'reverse' nor 'shuffle'.
    transforms = [Logit()]
    for _ in range(args.num_flows):
        net = nn.Sequential(MLP(C+I, P*O,
                                hidden_units=args.hidden_units,
                                activation=args.activation),
                            ElementwiseParams(P))
        # Context net rescales num_bits-quantized input to [-1, 1] before the MLP.
        context_net = nn.Sequential(LambdaLayer(lambda x: 2*x.float()/(2**args.num_bits-1) - 1),
                                    MLP(D, C,
                                        hidden_units=args.hidden_units,
                                        activation=args.activation))
        if args.affine: transforms.append(ConditionalAffineCouplingBijection(coupling_net=net, context_net=context_net, scale_fn=scale_fn(args.scale_fn), num_condition=I))
        else:           transforms.append(ConditionalAdditiveCouplingBijection(coupling_net=net, context_net=context_net, num_condition=I))
        if args.actnorm: transforms.append(ActNormBijection(D))
        if args.permutation == 'reverse':   transforms.append(Reverse(D))
        elif args.permutation == 'shuffle': transforms.append(Shuffle(D))
    transforms.pop()
    decoder = ConditionalFlow(base_dist=StandardNormal((D,)), transforms=transforms).to(args.device)

# Flow
transforms = []
for _ in range(args.num_flows):
    net = nn.Sequential(MLP(I, P*O,
                            hidden_units=args.hidden_units,
                            activation=args.activation),
                        ElementwiseParams(P))
    if args.affine: transforms.append(AffineCouplingBijection(net, scale_fn=scale_fn(args.scale_fn), num_condition=I))
# NOTE(review): this flow-construction loop appears truncated in this excerpt
# (no else/actnorm/permutation branches are visible).
train_loader, test_loader = get_data(args)

###################
## Specify model ##
###################

D = 2 # Number of data dimensions
P = 2 if args.affine else 1 # Number of elementwise parameters

# Stack of coupling steps over 2-D data: each net sees 1 conditioning dim
# and emits P parameters for the other dim.
transforms = []
for _ in range(args.num_flows):
    net = nn.Sequential(MLP(1, P,
                            hidden_units=args.hidden_units,
                            activation=args.activation),
                        ElementwiseParams(P))
    if args.affine: transforms.append(AffineCouplingBijection(net, scale_fn=scale_fn(args.scale_fn)))
    else:           transforms.append(AdditiveCouplingBijection(net))
    if args.actnorm: transforms.append(ActNormBijection(D))
    transforms.append(Reverse(D))
# Drop the trailing Reverse so the flow does not end on a permutation.
transforms.pop()


model = Flow(base_dist=StandardNormal((2,)),
             transforms=transforms).to(args.device)

#######################
## Specify optimizer ##
#######################

if args.optimizer == 'adam':
    optimizer = Adam(model.parameters(), lr=args.lr)
Beispiel #14
0
# Multi-scale image flow: dequantize, shift to [-0.5, 0.5], then two
# Squeeze2d levels of (ActNorm, 1x1 conv, affine coupling) steps.
# NOTE(review): the transforms list is truncated in this excerpt — the
# closing bracket of Flow(...) is not visible; kept byte-identical.
model = Flow(base_dist=ConvNormal2d((16,7,7)),
             transforms=[
                 UniformDequantization(num_bits=8),
                 # Augment(StandardUniform((1, 28, 28)), x_size=1),
                 # AffineCouplingBijection(net(2)), ActNormBijection2d(2), Conv1x1(2),
                 # AffineCouplingBijection(net(2)), ActNormBijection2d(2), Conv1x1(2),
                 # Squeeze2d(), Slice(StandardNormal((4, 14, 14)), num_keep=4),
                 # AffineCouplingBijection(net(4)), ActNormBijection2d(4), Conv1x1(4),
                 # AffineCouplingBijection(net(4)), ActNormBijection2d(4), Conv1x1(4),
                 # Squeeze2d(), Slice(StandardNormal((8, 7, 7)), num_keep=8),
                 # AffineCouplingBijection(net(8)), ActNormBijection2d(8), Conv1x1(8),
                 # AffineCouplingBijection(net(8)), ActNormBijection2d(8), Conv1x1(8),
                 #Logit(0.05),
                 ScalarAffineBijection(shift=-0.5),
                 Squeeze2d(),
                 # First level: 8 coupling steps on 4 channels.
                 ActNormBijection2d(4), Conv1x1(4), AffineCouplingBijection(net(4), scale_fn=scale_fn("tanh_exp")),
                 ActNormBijection2d(4), Conv1x1(4), AffineCouplingBijection(net(4), scale_fn=scale_fn("tanh_exp")),
                 ActNormBijection2d(4), Conv1x1(4), AffineCouplingBijection(net(4), scale_fn=scale_fn("tanh_exp")),
                 ActNormBijection2d(4), Conv1x1(4), AffineCouplingBijection(net(4), scale_fn=scale_fn("tanh_exp")),
                 ActNormBijection2d(4), Conv1x1(4), AffineCouplingBijection(net(4), scale_fn=scale_fn("tanh_exp")),
                 ActNormBijection2d(4), Conv1x1(4), AffineCouplingBijection(net(4), scale_fn=scale_fn("tanh_exp")),
                 ActNormBijection2d(4), Conv1x1(4), AffineCouplingBijection(net(4), scale_fn=scale_fn("tanh_exp")),
                 ActNormBijection2d(4), Conv1x1(4), AffineCouplingBijection(net(4), scale_fn=scale_fn("tanh_exp")),
                 Squeeze2d(),
                 # Second level: steps on 16 channels.
                 ActNormBijection2d(16), Conv1x1(16), AffineCouplingBijection(net(16), scale_fn=scale_fn("tanh_exp")),
                 ActNormBijection2d(16), Conv1x1(16), AffineCouplingBijection(net(16), scale_fn=scale_fn("tanh_exp")),
                 ActNormBijection2d(16), Conv1x1(16), AffineCouplingBijection(net(16), scale_fn=scale_fn("tanh_exp")),
                 ActNormBijection2d(16), Conv1x1(16), AffineCouplingBijection(net(16), scale_fn=scale_fn("tanh_exp")),
                 ActNormBijection2d(16), Conv1x1(16), AffineCouplingBijection(net(16), scale_fn=scale_fn("tanh_exp")),
                 ActNormBijection2d(16), Conv1x1(16), AffineCouplingBijection(net(16), scale_fn=scale_fn("tanh_exp")),
                 ActNormBijection2d(16), Conv1x1(16), AffineCouplingBijection(net(16), scale_fn=scale_fn("tanh_exp")),
Beispiel #15
0
                       dropout=0.2), ElementwiseParams2d(2 + k * 3))


#model = Flow(base_dist=StandardNormal((16,7,7)),
# Multi-scale mixture-coupling flow; `net` and `k` (number of mixtures) are
# defined elsewhere in the file.
# NOTE(review): this Flow(...) call is truncated in this excerpt — kept
# byte-identical.
model = Flow(
    base_dist=ConvNormal2d((16, 7, 7)),
    transforms=[
        UniformDequantization(num_bits=8),
        #Logit(),
        ScalarAffineBijection(shift=-0.5),
        Squeeze2d(),
        # First level: couplings on 4 channels.
        ActNormBijection2d(4),
        Conv1x1(4),
        LogisticMixtureAffineCouplingBijection(net(4),
                                               num_mixtures=k,
                                               scale_fn=scale_fn("tanh_exp")),
        ActNormBijection2d(4),
        Conv1x1(4),
        LogisticMixtureAffineCouplingBijection(net(4),
                                               num_mixtures=k,
                                               scale_fn=scale_fn("tanh_exp")),
        Squeeze2d(),
        # Second level: couplings on 16 channels.
        ActNormBijection2d(16),
        Conv1x1(16),
        LogisticMixtureAffineCouplingBijection(net(16),
                                               num_mixtures=k,
                                               scale_fn=scale_fn("tanh_exp")),
        ActNormBijection2d(16),
        Conv1x1(16),
        LogisticMixtureAffineCouplingBijection(net(16),
                                               num_mixtures=k,
Beispiel #16
0
    def __init__(self, num_flows, actnorm, affine, scale_fn_str, hidden_units,
                 activation, range_flow, augment_size, base_dist):
        """Unconditional flow over 2-D data.

        A "uniform" base gives a fixed Abs-surjection / shift / scale stack;
        any other base gives a sequence of (optionally augmented) coupling
        steps onto a standard normal.
        """
        D = 2  # Number of data dimensions

        if base_dist == "uniform":
            # Element-wise Abs surjection driven by a classifier over the
            # input (presumably recovering signs — see ElementAbsSurjection),
            # followed by a fixed shift and rescale.
            classifier = MLP(D,
                             D // 2,
                             hidden_units=hidden_units,
                             activation=activation,
                             out_lambda=lambda x: x.view(-1))
            transforms = [
                ElementAbsSurjection(classifier=classifier),
                ShiftBijection(shift=torch.tensor([[0.0, 4.0]])),
                ScaleBijection(scale=torch.tensor([[1 / 4, 1 / 8]]))
            ]
            base = StandardUniform((D, ))

        else:
            A = D + augment_size  # Number of augmented data dimensions
            P = 2 if affine else 1  # Number of elementwise parameters

            # Start from either an augmentation or an Abs surjection.
            if augment_size > 0:
                assert augment_size % 2 == 0
                transforms = [
                    Augment(StandardNormal((augment_size, )), x_size=D)
                ]
            else:
                transforms = [SimpleAbsSurjection()]
                if range_flow == 'logit':
                    transforms += [
                        ScaleBijection(scale=torch.tensor([[1 / 4, 1 / 4]])),
                        Logit()
                    ]
                elif range_flow == 'softplus':
                    transforms += [SoftplusInverse()]

            # Stack the coupling steps.
            for _ in range(num_flows):
                param_net = nn.Sequential(
                    MLP(A // 2,
                        P * A // 2,
                        hidden_units=hidden_units,
                        activation=activation), ElementwiseParams(P))
                if affine:
                    step = AffineCouplingBijection(
                        param_net, scale_fn=scale_fn(scale_fn_str))
                else:
                    step = AdditiveCouplingBijection(param_net)
                transforms.append(step)
                # NOTE(review): ActNorm is sized with D while the flow state
                # has A dims — confirm intended.
                if actnorm:
                    transforms.append(ActNormBijection(D))
                transforms.append(Reverse(A))

            # Drop the trailing Reverse.
            transforms.pop()
            base = StandardNormal((A, ))

        super(UnconditionalFlow, self).__init__(base_dist=base,
                                                transforms=transforms)
Beispiel #17
0
# Augmented flow over 2-D data: concatenate args.augdim Gaussian dims, then
# stack coupling steps over the augmented space.
assert args.augdim % 2 == 0

D = 2  # Number of data dimensions
A = 2 + args.augdim  # Number of augmented data dimensions
P = 2 if args.affine else 1  # Number of elementwise parameters

transforms = [Augment(StandardNormal((args.augdim, )), x_size=D)]
for _ in range(args.num_flows):
    # Net sees half the A dims and emits P params for each remaining dim.
    net = nn.Sequential(
        MLP(A // 2,
            P * A // 2,
            hidden_units=args.hidden_units,
            activation=args.activation), ElementwiseParams(P))
    if args.affine:
        transforms.append(
            AffineCouplingBijection(net, scale_fn=scale_fn(args.scale_fn)))
    else:
        transforms.append(AdditiveCouplingBijection(net))
    # NOTE(review): ActNorm is sized with D while the flow state has A dims —
    # confirm intended.
    if args.actnorm: transforms.append(ActNormBijection(D))
    transforms.append(Reverse(A))
# Drop the trailing Reverse so the flow does not end on a permutation.
transforms.pop()

model = Flow(base_dist=StandardNormal((A, )),
             transforms=transforms).to(args.device)

#######################
## Specify optimizer ##
#######################

if args.optimizer == 'adam':
    optimizer = Adam(model.parameters(), lr=args.lr)