if args.dimwise: transforms = dimwise(transforms)
if args.lenwise: transforms = lenwise(transforms)
if args.actnorm: transforms.append(ActNormBijection1d(2))

model = Flow(base_dist=StandardNormal((D, L)),
             transforms=transforms).to(args.device)
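# Load pretrained weights from disk when running in evaluation mode.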
if not args.train:
    state_dict = torch.load('models/{}.pt'.format(run_name))
    model.load_state_dict(state_dict)

#######################
## Specify optimizer ##
#######################

if args.optimizer == 'adam':
    optimizer = Adam(model.parameters(), lr=args.lr)
elif args.optimizer == 'adamax':
    optimizer = Adamax(model.parameters(), lr=args.lr)

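# Optional learning-rate schedules: a linear warmup stepped per iteration and an
# exponential decay stepped per epoch.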
if args.warmup is not None:
    scheduler_iter = LinearWarmupScheduler(optimizer, total_epoch=args.warmup)
else:
    scheduler_iter = None

if args.gamma is not None:
    scheduler_epoch = ExponentialLR(optimizer, gamma=args.gamma)
else:
    scheduler_epoch = None

#####################
## Define training ##
#####################
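
# A minimal sketch of the training loop this header introduces, not the original
# script: it assumes a `train_loader` yielding (batch, D, L) tensors and an
# `args.epochs` count, neither of which appears in this snippet.
for epoch in range(args.epochs):
    for x in train_loader:
        optimizer.zero_grad()
        loss = -model.log_prob(x.to(args.device)).mean()
        loss.backward()
        optimizer.step()
        if scheduler_iter is not None:
            scheduler_iter.step()   # linear warmup advances every iteration
    if scheduler_epoch is not None:
        scheduler_epoch.step()      # exponential decay advances every epoch
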
Example 2
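                 # Repeated flow steps: each chains an affine coupling layer, an
                 # ActNorm normalization, and an invertible 1x1 convolution over
                 # 24 channels (Glow-style building blocks).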
                 AffineCouplingBijection(net(24)),
                 ActNormBijection2d(24),
                 Conv1x1(24),
                 AffineCouplingBijection(net(24)),
                 ActNormBijection2d(24),
                 Conv1x1(24),
                 AffineCouplingBijection(net(24)),
                 ActNormBijection2d(24),
                 Conv1x1(24),
             ]).to(device)

###########
## Optim ##
###########

optimizer = Adam(model.parameters(), lr=1e-3)

###########
## Train ##
###########

print('Training...')
for epoch in range(10):
    l = 0.0
    for i, x in enumerate(train_loader):
        optimizer.zero_grad()
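        # Negative log-likelihood in nats, normalized to bits per dimension.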
        loss = -model.log_prob(x.to(device)).sum() / (math.log(2) * x.numel())
        loss.backward()
        optimizer.step()
        l += loss.detach().cpu().item()
        print('Epoch: {}/{}, Iter: {}/{}, Bits/dim: {:.3f}'.format(
            epoch+1, 10, i+1, len(train_loader), l/(i+1)), end='\r')
    print('')
Example 3
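# When num_bits is set, append a sigmoid squashing transform followed by a
# variational quantization layer at the requested bit depth.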
if args.num_bits is not None:
    transforms.append(Sigmoid())
    transforms.append(VariationalQuantization(decoder, num_bits=args.num_bits))


pi = Flow(base_dist=target,
          transforms=transforms).to(args.device)

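# Fixed standard-normal distribution that supplies the samples (and their
# log-probabilities) used to train pi.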
p = StandardNormal(shape).to(args.device)

#######################
## Specify optimizer ##
#######################

if args.optimizer == 'adam':
    optimizer = Adam(pi.parameters(), lr=args.lr)
elif args.optimizer == 'adamax':
    optimizer = Adamax(pi.parameters(), lr=args.lr)

##############
## Training ##
##############

print('Training...')
loss_sum = 0.0
for i in range(args.iter):
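    # Monte Carlo estimate of KL(p || pi): draw z ~ p and average
    # log p(z) - log pi(z) over the batch, then minimize it w.r.t. pi.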
    z, log_p_z = p.sample_with_log_prob(args.batch_size)
    log_pi_z = pi.log_prob(z)
    KL = (log_p_z - log_pi_z).mean()
    optimizer.zero_grad()
    loss = KL