Esempio n. 1
0
# Flow with three channel stages (6, 12 and 24 channels). Every stage applies
# four (affine coupling, actnorm, 1x1 conv) steps; the second and third stages
# are preceded by a Squeeze2d followed by a Slice.

# The base distribution is created first, mirroring the keyword evaluation
# order of a single Flow(...) call.
final_base = StandardNormal((24, 8, 8))

flow_layers = [
    UniformDequantization(num_bits=8),
    Augment(StandardUniform((3, 32, 32)), x_size=3),
]

# (channel count, shape of the Slice distribution entering the stage or None)
stage_specs = (
    (6, None),
    (12, (12, 16, 16)),
    (24, (24, 8, 8)),
)
for stage_channels, slice_shape in stage_specs:
    if slice_shape is not None:
        flow_layers.append(Squeeze2d())
        flow_layers.append(
            Slice(StandardNormal(slice_shape), num_keep=stage_channels))
    for _ in range(4):
        flow_layers.append(AffineCouplingBijection(net(stage_channels)))
        flow_layers.append(ActNormBijection2d(stage_channels))
        flow_layers.append(Conv1x1(stage_channels))

model = Flow(base_dist=final_base, transforms=flow_layers).to(device)
            AffineCouplingBijection(net,
                                    split_dim=2,
                                    scale_fn=scale_fn(args.scale_fn)))
    else:
        transforms.append(AdditiveCouplingBijection(net, split_dim=2))
    if args.stochperm: transforms.append(StochasticPermutation(dim=2))
    else: transforms.append(Shuffle(L, dim=2))
    return transforms


# Extend the running transform list once per requested flow step.
for _ in range(args.num_flows):
    # The helpers hand back the list to continue with, so rebind it.
    if args.dimwise:
        transforms = dimwise(transforms)
    if args.lenwise:
        transforms = lenwise(transforms)
    # Optional activation normalisation, constructed with size 2.
    if args.actnorm:
        transforms.append(ActNormBijection1d(2))

# Assemble the flow over a (D, L) standard-normal base and move it to the
# configured device.
model = Flow(base_dist=StandardNormal((D, L)),
             transforms=transforms).to(args.device)
if not args.train:
    # Not training: restore previously saved weights for this run.
    # map_location remaps the stored tensors onto args.device, so a checkpoint
    # written on a different device (e.g. CUDA) can still be opened here.
    # NOTE(review): torch.load unpickles the file -- only load trusted
    # checkpoints.
    state_dict = torch.load('models/{}.pt'.format(run_name),
                            map_location=args.device)
    model.load_state_dict(state_dict)

#######################
## Specify optimizer ##
#######################

# Pick the optimizer class by name; both share the same learning rate.
if args.optimizer == 'adam':
    optimizer = Adam(model.parameters(), lr=args.lr)
elif args.optimizer == 'adamax':
    optimizer = Adamax(model.parameters(), lr=args.lr)
else:
    # Fail here with a clear message instead of a NameError on the first
    # later use of the unbound `optimizer`.
    raise ValueError('Unknown optimizer: {!r}'.format(args.optimizer))

# A per-iteration warmup scheduler is only created when a warmup length
# was supplied.
if args.warmup is not None:
    scheduler_iter = LinearWarmupScheduler(optimizer, total_epoch=args.warmup)
Esempio n. 3
0
# Encoder network: 784 inputs -> 2 * latent_size outputs. The input hook
# flattens each sample to 784 values, casts to float and applies 2 * x - 1.
encoder_net = MLP(
    784, 2*latent_size,
    hidden_units=[512, 256],
    activation=None,
    in_lambda=lambda x: 2 * x.view(x.shape[0], 784).float() - 1,
)
encoder = ConditionalNormal(encoder_net)

# Decoder network: latent_size inputs -> 784 * 2 outputs, reshaped to
# (batch, 2, 28, 28); ConditionalNormal is told to split along dim 1.
decoder_net = MLP(
    latent_size, 784 * 2,
    hidden_units=[256, 512],
    activation=None,
    out_lambda=lambda x: x.view(x.shape[0], 2, 28, 28),
)
decoder = ConditionalNormal(decoder_net, split_dim=1)
# decoder = ConditionalBernoulli(MLP(latent_size, 784,
#                                    hidden_units=[256, 512],
#                                    activation=None,
#                                    out_lambda=lambda x: x.view(x.shape[0], 1, 28, 28)))

# Base distribution over the latent vector; created before the transforms.
latent_base = StandardNormal((latent_size,))
flow_layers = [
    UniformDequantization(num_bits=8),
    VAE(encoder=encoder, decoder=decoder),
]
model = Flow(base_dist=latent_base, transforms=flow_layers)
model = model.to(device)

print(model)

###########
## Optim ##
###########

# Adam over every model parameter with a fixed learning rate of 1e-3.
optimizer = Adam(model.parameters(), lr=1e-3)

###########
## Train ##
###########
##################

# Data loaders for the train and test splits.
train_loader, test_loader = get_data(args)

###################
## Specify model ##
###################

# MLP from 2 inputs to 1 output; out_lambda flattens the output to 1-D.
classifier = MLP(
    2, 1,
    hidden_units=args.hidden_units,
    activation=args.activation,
    out_lambda=lambda x: x.view(-1),
)

# Built in the same order as in a single Flow(...) call: base first, then
# the transforms from first to last.
uniform_base = StandardUniform((2,))
abs_layer = ElementAbsSurjection(classifier=classifier)
shift_layer = ShiftBijection(shift=torch.tensor([[0.0, 4.0]]))
scale_layer = ScaleBijection(scale=torch.tensor([[1/4, 1/8]]))

model = Flow(
    base_dist=uniform_base,
    transforms=[abs_layer, shift_layer, scale_layer],
).to(args.device)

#######################
## Specify optimizer ##
#######################

# Pick the optimizer class by name; both share the same learning rate.
if args.optimizer == 'adam':
    optimizer = Adam(model.parameters(), lr=args.lr)
elif args.optimizer == 'adamax':
    optimizer = Adamax(model.parameters(), lr=args.lr)
else:
    # Fail here with a clear message instead of a NameError on the first
    # later use of the unbound `optimizer`.
    raise ValueError('Unknown optimizer: {!r}'.format(args.optimizer))

##############
## Training ##
##############
Esempio n. 5
0
# Append one coupling step per flow, each followed by a Reverse(A).
for _ in range(args.num_flows):
    # Network handed to the coupling layer: an MLP with A // 2 inputs and
    # P * A // 2 outputs, followed by ElementwiseParams(P).
    net = nn.Sequential(
        MLP(A // 2,
            P * A // 2,
            hidden_units=args.hidden_units,
            activation=args.activation), ElementwiseParams(P))
    if args.affine:
        transforms.append(
            AffineCouplingBijection(net, scale_fn=scale_fn(args.scale_fn)))
    else:
        transforms.append(AdditiveCouplingBijection(net))
    # NOTE(review): every other size in this loop is A; confirm that D is the
    # intended feature count here.
    if args.actnorm: transforms.append(ActNormBijection(D))
    transforms.append(Reverse(A))

# Drop the Reverse appended by the final iteration. Only do so when the loop
# actually ran: with num_flows == 0 there is no trailing Reverse, and an
# unconditional pop would remove an unrelated transform or raise IndexError
# on an empty list.
if args.num_flows > 0:
    transforms.pop()

# Standard-normal base of shape (A,) under the assembled transforms, moved to
# the configured device.
normal_base = StandardNormal((A, ))
model = Flow(base_dist=normal_base, transforms=transforms)
model = model.to(args.device)

#######################
## Specify optimizer ##
#######################

# Pick the optimizer class by name; both share the same learning rate.
if args.optimizer == 'adam':
    optimizer = Adam(model.parameters(), lr=args.lr)
elif args.optimizer == 'adamax':
    optimizer = Adamax(model.parameters(), lr=args.lr)
else:
    # Fail here with a clear message instead of a NameError on the first
    # later use of the unbound `optimizer`.
    raise ValueError('Unknown optimizer: {!r}'.format(args.optimizer))

##############
## Training ##
##############

print('Training...')
Esempio n. 6
0
## Model ##
###########


def net():
    """Return a fresh 1 -> 200 -> 100 -> 2 ReLU MLP ending in ElementwiseParams(2)."""
    stack = [
        nn.Linear(1, 200),
        nn.ReLU(),
        nn.Linear(200, 100),
        nn.ReLU(),
        nn.Linear(100, 2),
        ElementwiseParams(2),
    ]
    return nn.Sequential(*stack)


# Four (affine coupling, actnorm) blocks over a 2-element standard-normal
# base; a Reverse(2) sits between consecutive blocks but not after the last.
normal_base = StandardNormal((2, ))

NUM_BLOCKS = 4
flow_layers = []
for block_idx in range(NUM_BLOCKS):
    flow_layers.append(AffineCouplingBijection(net()))
    flow_layers.append(ActNormBijection(2))
    if block_idx < NUM_BLOCKS - 1:
        flow_layers.append(Reverse(2))

model = Flow(base_dist=normal_base, transforms=flow_layers)

###########
## Optim ##
###########

# Adam over every model parameter with a fixed learning rate of 1e-3.
optimizer = Adam(model.parameters(), lr=1e-3)

###########
## Train ##
Esempio n. 7
0
    net = nn.Sequential(MLP(I, P*O,
                            hidden_units=args.hidden_units,
                            activation=args.activation),
                        ElementwiseParams(P))
    if args.affine: transforms.append(AffineCouplingBijection(net, scale_fn=scale_fn(args.scale_fn), num_condition=I))
    else:           transforms.append(AdditiveCouplingBijection(net, num_condition=I))
    if args.actnorm: transforms.append(ActNormBijection(D))
    if args.permutation == 'reverse':   transforms.append(Reverse(D))
    elif args.permutation == 'shuffle': transforms.append(Shuffle(D))
# Remove the most recently appended transform.
transforms.pop()

# With a bit depth configured, finish with a Sigmoid followed by a
# variational quantization layer.
if args.num_bits is not None:
    sigmoid_layer = Sigmoid()
    transforms.append(sigmoid_layer)
    quantizer = VariationalQuantization(decoder, num_bits=args.num_bits)
    transforms.append(quantizer)


# `pi`: the flow on top of `target`, moved to the configured device.
pi = Flow(base_dist=target, transforms=transforms)
pi = pi.to(args.device)

# `p`: a standard normal of the given shape on the same device.
p = StandardNormal(shape)
p = p.to(args.device)

#######################
## Specify optimizer ##
#######################

# Pick the optimizer class by name; both optimise the parameters of `pi`
# with the same learning rate.
if args.optimizer == 'adam':
    optimizer = Adam(pi.parameters(), lr=args.lr)
elif args.optimizer == 'adamax':
    optimizer = Adamax(pi.parameters(), lr=args.lr)
else:
    # Fail here with a clear message instead of a NameError on the first
    # later use of the unbound `optimizer`.
    raise ValueError('Unknown optimizer: {!r}'.format(args.optimizer))

##############
## Training ##
##############
Esempio n. 8
0
            hidden_units=args.hidden_units,
            activation=args.activation), ElementwiseParams(P))
    if args.affine:
        transforms.append(
            AffineCouplingBijection(net, scale_fn=scale_fn(args.scale_fn)))
    else:
        transforms.append(AdditiveCouplingBijection(net))
    if args.actnorm: transforms.append(ActNormBijection(D))
    if args.permutation == 'reverse': transforms.append(Reverse(D))
    elif args.permutation == 'shuffle': transforms.append(Shuffle(D))
# Discard the last transform that was appended.
transforms.pop()

# When a bit depth is given, close the chain with Sigmoid and then
# VariationalQuantization.
if args.num_bits is not None:
    squash = Sigmoid()
    transforms.append(squash)
    quantization = VariationalQuantization(decoder, num_bits=args.num_bits)
    transforms.append(quantization)

# Flow over `target`, placed on the configured device.
pi = Flow(base_dist=target, transforms=transforms)
pi = pi.to(args.device)

# Standard normal of the given shape, on the same device.
p = StandardNormal(shape)
p = p.to(args.device)

##############
## Training ##
##############

# Restore the saved weights of `pi`. map_location remaps the stored tensors
# onto args.device (the device `pi` was moved to), so a checkpoint written on
# a different device can still be opened here.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
state_dict = torch.load(path_check, map_location=args.device)
pi.load_state_dict(state_dict)

print('Running MCMC...')
samples, rate = metropolis_hastings(
    pi=pi,
    num_dims=args.num_dims,
    num_chains=eval_args.num_chains,
Esempio n. 9
0
#model = Flow(base_dist=StandardNormal((16,7,7)),
# Two stages (4 then 16 channels). Each stage starts with a Squeeze2d and then
# applies two (actnorm, 1x1 conv, logistic-mixture affine coupling) steps.
# The base distribution is created first, as in a single Flow(...) call.
conv_base = ConvNormal2d((16, 7, 7))

flow_layers = [
    UniformDequantization(num_bits=8),
    #Logit(),
    ScalarAffineBijection(shift=-0.5),
]
for stage_channels in (4, 16):
    flow_layers.append(Squeeze2d())
    for _ in range(2):
        flow_layers.append(ActNormBijection2d(stage_channels))
        flow_layers.append(Conv1x1(stage_channels))
        flow_layers.append(
            LogisticMixtureAffineCouplingBijection(
                net(stage_channels),
                num_mixtures=k,
                scale_fn=scale_fn("tanh_exp")))

model = Flow(base_dist=conv_base, transforms=flow_layers).to(device)
Esempio n. 10
0
# Two stages (4 then 16 channels). Each stage starts with a Squeeze2d and then
# applies eight (actnorm, 1x1 conv, affine coupling) steps.
# The base distribution is created first, as in a single Flow(...) call.
conv_base = ConvNormal2d((16,7,7))

flow_layers = [
    UniformDequantization(num_bits=8),
    # Disabled alternative kept for reference:
    # Augment(StandardUniform((1, 28, 28)), x_size=1),
    # AffineCouplingBijection(net(2)), ActNormBijection2d(2), Conv1x1(2),
    # AffineCouplingBijection(net(2)), ActNormBijection2d(2), Conv1x1(2),
    # Squeeze2d(), Slice(StandardNormal((4, 14, 14)), num_keep=4),
    # AffineCouplingBijection(net(4)), ActNormBijection2d(4), Conv1x1(4),
    # AffineCouplingBijection(net(4)), ActNormBijection2d(4), Conv1x1(4),
    # Squeeze2d(), Slice(StandardNormal((8, 7, 7)), num_keep=8),
    # AffineCouplingBijection(net(8)), ActNormBijection2d(8), Conv1x1(8),
    # AffineCouplingBijection(net(8)), ActNormBijection2d(8), Conv1x1(8),
    #Logit(0.05),
    ScalarAffineBijection(shift=-0.5),
]
STEPS_PER_STAGE = 8
for stage_channels in (4, 16):
    flow_layers.append(Squeeze2d())
    for _ in range(STEPS_PER_STAGE):
        flow_layers.append(ActNormBijection2d(stage_channels))
        flow_layers.append(Conv1x1(stage_channels))
        flow_layers.append(
            AffineCouplingBijection(net(stage_channels),
                                    scale_fn=scale_fn("tanh_exp")))

model = Flow(base_dist=conv_base, transforms=flow_layers).to(device)
Esempio n. 11
0
                                in_lambda=lambda x: 2 * x.view(x.shape[0], 784).float() - 1))
# First-level decoder: latent_sizes[0] inputs -> 784 outputs, reshaped to
# (batch, 1, 28, 28).
decoder_net = MLP(
    latent_sizes[0], 784,
    hidden_units=[512,256],
    activation='relu',
    out_lambda=lambda x: x.view(x.shape[0], 1, 28, 28),
)
decoder = ConditionalBernoulli(decoder_net)

# Second-level encoder: latent_sizes[0] inputs -> 2 * latent_sizes[1] outputs.
encoder2_net = MLP(
    latent_sizes[0], 2*latent_sizes[1],
    hidden_units=[256,128],
    activation='relu',
)
encoder2 = ConditionalNormal(encoder2_net)

# Second-level decoder: latent_sizes[1] inputs -> 2 * latent_sizes[0] outputs.
decoder2_net = MLP(
    latent_sizes[1], 2*latent_sizes[0],
    hidden_units=[256,128],
    activation='relu',
)
decoder2 = ConditionalNormal(decoder2_net)

# Two stacked VAE transforms; the base distribution is sized by the last
# entry of latent_sizes and is created before the transforms.
top_base = StandardNormal((latent_sizes[-1],))
vae_layers = [
    VAE(encoder=encoder, decoder=decoder),
    VAE(encoder=encoder2, decoder=decoder2),
]
model = Flow(base_dist=top_base, transforms=vae_layers)
model = model.to(device)

###########
## Optim ##
###########

# Adam over every model parameter with a fixed learning rate of 1e-3.
optimizer = Adam(model.parameters(), lr=1e-3)

###########
## Train ##
###########

print('Training...')
for epoch in range(20):
Esempio n. 12
0
        Squeeze2d(4),
        Slice(StandardNormal((channels * 2, items)), num_keep=channels * 2),
    ]


# Four reduction levels: level i uses base_channels * 2**i channels and
# n_items // 4**i items. A final perm_norm_bi group follows them.
# The base distribution is created first, as in a single Flow(...) call.
final_base = StandardNormal((base_channels * (2**5), n_items // (4**4)))

flow_layers = [
    UniformDequantization(num_bits=8),
    Augment(StandardUniform((base_channels * 1, n_items)),
            x_size=base_channels),
]
for level in range(1, 5):
    flow_layers.extend(
        reduction_layer(base_channels * (2**level), n_items // (4**level)))
# *reduction_layer(base_channels*(2**5), n_items//(4**4)),
flow_layers.extend(perm_norm_bi(base_channels * (2**5)))

# Disabled alternative kept for reference:
# AffineCouplingBijection(net(base_channels*2)), ActNormBijection2d(base_channels*2), Conv1x1(base_channels*2),
# AffineCouplingBijection(net(base_channels*2)), ActNormBijection2d(base_channels*2), Conv1x1(base_channels*2),
# AffineCouplingBijection(net(base_channels*2)), ActNormBijection2d(base_channels*2), Conv1x1(base_channels*2),
# Squeeze2d(), Slice(StandardNormal((base_channels*2, n_items//4)), num_keep=base_channels*2),
# AffineCouplingBijection(net(base_channels*2)), ActNormBijection2d(base_channels*2), Conv1x1(base_channels*2),
# AffineCouplingBijection(net(base_channels*2)), ActNormBijection2d(base_channels*2), Conv1x1(base_channels*2),
# AffineCouplingBijection(net(base_channels*2)), ActNormBijection2d(base_channels*2), Conv1x1(base_channels*2),

model = Flow(base_dist=final_base, transforms=flow_layers).to(device)

# Take the first batch from the training loader, read its 'X' entry and
# insert a new axis at position 1.
first_batch = next(iter(train_loader))
x = first_batch['X'].unsqueeze(1)