def train_model(model, train_loader, test_loader, device, args):
    """Train ``model`` for ``args.epochs`` epochs and collect per-epoch test losses.

    The model is moved to ``device`` and optimised with Adam using
    ``lr=args.lr`` and ``betas=(0.5, 0.999)``. After every epoch, ``test`` is
    run on ``test_loader``. Every ``args.landmark_interval`` epochs, the
    ``evaluation`` helpers produce interpolation, sampling (no prior) and
    reconstruction outputs under the ``'wae'`` prefix.

    Args:
        model: module to train; it is moved to ``device`` before training.
        train_loader: data loader handed to ``train`` each epoch.
        test_loader: data loader handed to ``test`` and to the evaluation helpers.
        device: device the model is moved to.
        args: namespace providing at least ``lr``, ``epochs`` and
            ``landmark_interval`` (plus whatever ``train``, ``test`` and
            ``evaluation`` read from it).

    Returns:
        A list with one ``[rec_loss, reg_loss, total_loss]`` entry per epoch,
        in the order ``test`` returns them.
    """
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.5, 0.999))
    history = []

    for epoch in range(1, args.epochs + 1):
        train(model, train_loader, optimizer, device, epoch, args)
        rec_loss, reg_loss, total_loss = test(model, test_loader, device, args)
        history.append([rec_loss, reg_loss, total_loss])

        # Only every `landmark_interval`-th epoch produces evaluation outputs.
        if epoch % args.landmark_interval != 0:
            continue

        evaluation.interpolation_2d(model, test_loader, device, epoch, args, prefix='wae')
        evaluation.sampling(model, device, epoch, args, prior=None, prefix='wae')
        evaluation.reconstruction(model, test_loader, device, epoch, args, prefix='wae')

    return history
def train_model(model, prior, train_loader, test_loader, device, args, lr=1e-4):
    """Jointly train ``model`` and its learnable ``prior``; collect per-epoch test losses.

    Both modules are moved to ``device`` and optimised together by a single
    Adam optimiser. After every epoch, ``test`` is run on ``test_loader``.
    Every ``args.landmark_interval`` epochs, the ``evaluation`` helpers
    produce interpolation, prior-based sampling and reconstruction outputs
    under the ``'vampprior'`` prefix.

    Args:
        model: module whose forward returns a 4-tuple; its last two entries
            are used as the latent mean and log-variance for sampling.
        prior: module called with no arguments; its output is fed to
            ``model`` to obtain the latent parameters used for sampling.
        train_loader: data loader handed to ``train`` each epoch.
        test_loader: data loader handed to ``test`` and to the evaluation helpers.
        device: device both modules are moved to.
        args: namespace providing at least ``epochs`` and ``landmark_interval``
            (plus whatever ``train``, ``test`` and ``evaluation`` read).
        lr: Adam learning rate. Defaults to ``1e-4``, the value that was
            previously hard-coded.

    Returns:
        A list with one ``[test_rec_loss, test_reg_loss, test_loss]`` entry
        per epoch, in the order ``test`` returns them.
    """
    import torch  # local import: only needed for no_grad in the landmark step

    model = model.to(device)
    prior = prior.to(device)
    loss_list = []
    # A single optimiser updates the model and the prior's parameters together.
    optimizer = optim.Adam(list(model.parameters()) + list(prior.parameters()), lr=lr)

    for epoch in range(1, args.epochs + 1):
        train(model, prior, train_loader, optimizer, device, epoch, args)
        test_rec_loss, test_reg_loss, test_loss = test(model, prior, test_loader, device, args)
        loss_list.append([test_rec_loss, test_reg_loss, test_loss])

        if epoch % args.landmark_interval == 0:
            evaluation.interpolation_2d(model, test_loader, device, epoch, args, prefix='vampprior')

            # Remember the current modes so this evaluation step does not leak
            # eval mode into later epochs or back to the caller.
            model_was_training = model.training
            prior_was_training = prior.training
            prior.eval()
            model.eval()

            # The prior's output is pushed through the model only to read the
            # latent mean / log-variance used for sampling. No gradients are
            # needed, so avoid building an autograd graph here.
            with torch.no_grad():
                prior_inputs = prior()
                _, _, z_p_mean, z_p_logvar = model(prior_inputs)

            evaluation.sampling(model, device, epoch, args,
                                prior=[z_p_mean, z_p_logvar], prefix='vampprior')
            evaluation.reconstruction(model, test_loader, device, epoch, args, prefix='vampprior')

            model.train(model_was_training)
            prior.train(prior_was_training)

    return loss_list
model_type=args.model_type) else: args.x_dim = int(64 * 64) args.z_dim = 64 args.nc = 3 model = AE_CelebA(z_dim=args.z_dim, nc=args.nc, model_type=args.model_type) src_loaders = load_datasets(args=args) loss = wae.train_model(model, src_loaders['train'], src_loaders['val'], device, args) # conditional generation model.eval() evaluation.sampling(model, device, args.epochs, args, prefix='wae', nrow=4) # t-sne visualization if args.source_data == 'MNIST': evaluation.visualization_tsne(model, src_loaders['val'], device, args, prefix='wae') else: evaluation.visualization_tsne2(model, src_loaders['val'], device, args, prefix='wae')
loss = prae.train_model(model, prior, src_loaders['train'], src_loaders['val'], device, args) else: loss = drae.train_model(model, prior, src_loaders['train'], src_loaders['val'], device, args) # conditional generation prior.eval() model.eval() z_p_mean, z_p_logvar = prior() prior_list = [z_p_mean, z_p_logvar] for i in range(args.K): evaluation.sampling( model, device, i + 1, args, prefix='rae', prior=[z_p_mean[i, :].unsqueeze(0), z_p_logvar[i, :].unsqueeze(0)], nrow=4) # t-sne visualization if args.source_data == 'MNIST': evaluation.visualization_tsne(model, src_loaders['val'], device, args, prefix='rae', prior=prior_list) else: evaluation.visualization_tsne2(model, src_loaders['val'],