Example 1
    def test_bijection_is_well_behaved(self):
        """Check that permutation-type bijections are proper bijections.

        Runs every Permute / Reverse / Shuffle configuration over a random
        (batch, 2, 3, 4) tensor via the shared well-behavedness assertion.
        """
        batch_size = 10
        event_shape = [2, 3, 4]
        inputs = torch.randn(batch_size, *event_shape)

        # One Permute built from a tensor, then list-based permutations
        # covering identity, full reversal and mixed orders on each dim.
        cases = [Permute(torch.tensor([0, 1]), 1)]
        for order, dim in [([0, 1], 1), ([1, 0], 1),
                           ([0, 1, 2], 2), ([2, 1, 0], 2),
                           ([0, 1, 2, 3], 3), ([3, 2, 1, 0], 3),
                           ([0, 2, 1], 2), ([2, 0, 1], 2),
                           ([1, 0, 2], 2), ([1, 2, 0], 2)]:
            cases.append(Permute(order, dim))
        # Reverse and Shuffle over each event dimension.
        for size, dim in [(2, 1), (3, 2), (4, 3)]:
            cases.append(Reverse(dim_size=size, dim=dim))
        for size, dim in [(2, 1), (3, 2), (4, 3)]:
            cases.append(Shuffle(dim_size=size, dim=dim))

        for bijection in cases:
            with self.subTest(bijection=bijection):
                self.assert_bijection_is_well_behaved(
                    bijection, inputs, z_shape=(batch_size, *event_shape))
Example 2
def dimwise(transforms):
    """Append one dimwise coupling step plus a Reverse permutation.

    Mutates and returns *transforms*. Relies on the module-level sizes
    ``D`` and ``P`` and the parsed ``args`` namespace.
    """
    coupling_net = nn.Sequential(
        DenseTransformer(d_input=D // 2,
                         d_output=P * D // 2,
                         d_model=args.d_model,
                         nhead=args.nhead,
                         num_layers=args.num_layers,
                         dim_feedforward=4 * args.d_model,
                         dropout=args.dropout,
                         activation=args.activation,
                         checkpoint_blocks=args.checkpoint_blocks),
        ElementwiseParams1d(P))
    # Affine coupling when requested, otherwise additive.
    if args.affine:
        coupling = AffineCouplingBijection(coupling_net, split_dim=1,
                                           scale_fn=scale_fn(args.scale_fn))
    else:
        coupling = AdditiveCouplingBijection(coupling_net, split_dim=1)
    transforms.append(coupling)
    transforms.append(Reverse(D, dim=1))
    return transforms
Example 3
    def __init__(self, num_flows, actnorm, affine, scale_fn_str, hidden_units,
                 activation, range_flow, augment_size, base_dist, cond_size):
        """Build a conditional coupling flow over 2-D data.

        Args:
            num_flows: number of coupling steps.
            actnorm: whether to insert ActNorm after each coupling.
            affine: affine couplings (2 elementwise params) vs additive (1).
            scale_fn_str: key passed to ``scale_fn`` for affine couplings.
            hidden_units: hidden sizes of the coupling MLPs.
            activation: activation name for the MLPs.
            range_flow: 'logit' or 'softplus' pre-transform when not augmenting.
            augment_size: extra Gaussian dims; 0 selects the Abs-surjection path.
            base_dist: 'uniform' for StandardUniform base, else StandardNormal.
            cond_size: dimensionality of the conditioning input.
        """
        D = 2  # Number of data dimensions
        A = D + augment_size  # Number of augmented data dimensions
        P = 2 if affine else 1  # Number of elementwise parameters

        # initialize context. Only upsample context in ContextInit if latent shape doesn't change during the flow.
        context_init = MLP(input_size=cond_size,
                           output_size=D,
                           hidden_units=hidden_units,
                           activation=activation)

        # initialize flow with either augmentation or Abs surjection
        if augment_size > 0:
            assert augment_size % 2 == 0
            transforms = [Augment(StandardNormal((augment_size, )), x_size=D)]

        else:
            # Fixed: removed a dead `transforms = []` that was immediately
            # overwritten by the assignment below (cf. the UnconditionalFlow
            # variant, which has no such line).
            transforms = [SimpleAbsSurjection()]
            if range_flow == 'logit':
                transforms += [
                    ScaleBijection(scale=torch.tensor([[1 / 4, 1 / 4]])),
                    Logit()
                ]
            elif range_flow == 'softplus':
                transforms += [SoftplusInverse()]

        # apply coupling layer flows
        for _ in range(num_flows):
            # Coupling net consumes half the (augmented) dims plus the D
            # context dims, and emits P parameters per transformed dim.
            net = nn.Sequential(
                MLP(A // 2 + D,
                    P * A // 2,
                    hidden_units=hidden_units,
                    activation=activation), ElementwiseParams(P))
            if affine:
                transforms.append(
                    ConditionalAffineCouplingBijection(
                        net, scale_fn=scale_fn(scale_fn_str)))
            else:
                transforms.append(ConditionalAdditiveCouplingBijection(net))
            if actnorm: transforms.append(ActNormBijection(D))
            transforms.append(Reverse(A))

        # Drop the trailing Reverse: no permutation is needed after the
        # final coupling step.
        transforms.pop()

        if base_dist == "uniform":
            base = StandardUniform((A, ))
        else:
            base = StandardNormal((A, ))

        super(SRFlow, self).__init__(base_dist=base,
                                     transforms=transforms,
                                     context_init=context_init)
Example 4
P = 2 if args.affine else 1  # Number of elementwise parameters

# Augment the data with args.augdim Gaussian dimensions, then stack
# coupling steps; each step optionally adds ActNorm and ends with a
# Reverse permutation (the final Reverse is popped off afterwards).
transforms = [Augment(StandardNormal((args.augdim, )), x_size=D)]
for _ in range(args.num_flows):
    coupling_net = nn.Sequential(
        MLP(A // 2,
            P * A // 2,
            hidden_units=args.hidden_units,
            activation=args.activation),
        ElementwiseParams(P))
    if args.affine:
        coupling = AffineCouplingBijection(coupling_net,
                                           scale_fn=scale_fn(args.scale_fn))
    else:
        coupling = AdditiveCouplingBijection(coupling_net)
    transforms.append(coupling)
    if args.actnorm:
        transforms.append(ActNormBijection(D))
    transforms.append(Reverse(A))
transforms.pop()  # no permutation after the last coupling

model = Flow(base_dist=StandardNormal((A, )),
             transforms=transforms).to(args.device)

#######################
## Specify optimizer ##
#######################

# NOTE(review): assumes argparse limits --optimizer to these two choices;
# any other value would leave `optimizer` undefined — confirm upstream.
if args.optimizer == 'adam':
    optimizer = Adam(model.parameters(), lr=args.lr)
elif args.optimizer == 'adamax':
    optimizer = Adamax(model.parameters(), lr=args.lr)

##############
Example 5
# Decoder: conditional flow, built only for quantized data (num_bits set).
# The context net rescales integer-valued inputs into [-1, 1].
if args.num_bits is not None:
    transforms = [Logit()]
    for _ in range(args.num_flows):
        coupling_net = nn.Sequential(
            MLP(C+I, P*O,
                hidden_units=args.hidden_units,
                activation=args.activation),
            ElementwiseParams(P))
        context_net = nn.Sequential(
            LambdaLayer(lambda x: 2*x.float()/(2**args.num_bits-1) - 1),
            MLP(D, C,
                hidden_units=args.hidden_units,
                activation=args.activation))
        if args.affine:
            step = ConditionalAffineCouplingBijection(
                coupling_net=coupling_net, context_net=context_net,
                scale_fn=scale_fn(args.scale_fn), num_condition=I)
        else:
            step = ConditionalAdditiveCouplingBijection(
                coupling_net=coupling_net, context_net=context_net,
                num_condition=I)
        transforms.append(step)
        if args.actnorm:
            transforms.append(ActNormBijection(D))
        if args.permutation == 'reverse':
            transforms.append(Reverse(D))
        elif args.permutation == 'shuffle':
            transforms.append(Shuffle(D))
    transforms.pop()  # drop the trailing permutation
    decoder = ConditionalFlow(base_dist=StandardNormal((D,)),
                              transforms=transforms).to(args.device)

# Flow: coupling flow over the data; the construction loop continues
# beyond this chunk.
# NOTE(review): I, P, O and D are module-level sizes defined earlier in
# the file — presumably conditioning/parameter/output/data dims; confirm.
transforms = []
for _ in range(args.num_flows):
    # Coupling net: maps the I conditioned dims to P parameters for each
    # of the O transformed dims.
    net = nn.Sequential(MLP(I, P*O,
                            hidden_units=args.hidden_units,
                            activation=args.activation),
                        ElementwiseParams(P))
    # Affine coupling when args.affine, otherwise additive.
    if args.affine: transforms.append(AffineCouplingBijection(net, scale_fn=scale_fn(args.scale_fn), num_condition=I))
    else:           transforms.append(AdditiveCouplingBijection(net, num_condition=I))
    # Optional per-step activation normalization.
    if args.actnorm: transforms.append(ActNormBijection(D))
    # Permute dims between coupling steps.
    if args.permutation == 'reverse':   transforms.append(Reverse(D))
Example 6
###########
## Model ##
###########


def net():
    """Return a fresh coupling network: scalar in, 2 elementwise params out."""
    layers = [nn.Linear(1, 200), nn.ReLU(),
              nn.Linear(200, 100), nn.ReLU(),
              nn.Linear(100, 2), ElementwiseParams(2)]
    return nn.Sequential(*layers)


# Four affine-coupling steps, each followed by ActNorm; a Reverse
# permutation sits between consecutive steps (none after the last one).
steps = []
for i in range(4):
    steps += [AffineCouplingBijection(net()), ActNormBijection(2)]
    if i < 3:
        steps.append(Reverse(2))
model = Flow(base_dist=StandardNormal((2, )), transforms=steps)

###########
## Optim ##
###########

optimizer = Adam(model.parameters(), lr=1e-3)
Example 7
    def __init__(self, num_flows, actnorm, affine, scale_fn_str, hidden_units,
                 activation, range_flow, augment_size, base_dist):
        """Assemble an unconditional coupling flow over 2-D data.

        With a uniform base the flow uses a classifier-guided abs
        surjection mapped into the unit box; otherwise a stack of
        (conditional-free) coupling layers over optionally augmented dims.
        """
        D = 2  # Number of data dimensions

        if base_dist == "uniform":
            # Classifier recovers the sign absorbed by the abs surjection.
            classifier = MLP(D,
                             D // 2,
                             hidden_units=hidden_units,
                             activation=activation,
                             out_lambda=lambda x: x.view(-1))
            transforms = [
                ElementAbsSurjection(classifier=classifier),
                ShiftBijection(shift=torch.tensor([[0.0, 4.0]])),
                ScaleBijection(scale=torch.tensor([[1 / 4, 1 / 8]]))
            ]
            base = StandardUniform((D, ))

        else:
            A = D + augment_size  # Number of augmented data dimensions
            P = 2 if affine else 1  # Number of elementwise parameters

            # Start the flow with either augmentation or an Abs surjection.
            if augment_size > 0:
                # Augmented dims must split evenly between coupling halves.
                assert augment_size % 2 == 0
                transforms = [
                    Augment(StandardNormal((augment_size, )), x_size=D)
                ]
            else:
                transforms = [SimpleAbsSurjection()]
                if range_flow == 'logit':
                    transforms += [
                        ScaleBijection(scale=torch.tensor([[1 / 4, 1 / 4]])),
                        Logit()
                    ]
                elif range_flow == 'softplus':
                    transforms += [SoftplusInverse()]

            # Stack the coupling steps.
            for _ in range(num_flows):
                coupling_net = nn.Sequential(
                    MLP(A // 2,
                        P * A // 2,
                        hidden_units=hidden_units,
                        activation=activation),
                    ElementwiseParams(P))
                if affine:
                    step = AffineCouplingBijection(
                        coupling_net, scale_fn=scale_fn(scale_fn_str))
                else:
                    step = AdditiveCouplingBijection(coupling_net)
                transforms.append(step)
                if actnorm:
                    transforms.append(ActNormBijection(D))
                transforms.append(Reverse(A))

            transforms.pop()  # no permutation after the final coupling
            base = StandardNormal((A, ))

        super(UnconditionalFlow, self).__init__(base_dist=base,
                                                transforms=transforms)