Esempio n. 1
0
          img_size=img_size,
          z_dim=z_dim,
          use_cuda=use_cuda,
          conf=conf)
# Train the VAE for `num_epochs` epochs.
#
# After each epoch this loop:
#   * saves a preview grid to img_<epoch>.png, laid out as rows of up to
#     8 images: random samples, train inputs, train reconstructions,
#     validation inputs, validation reconstructions;
#   * shows the four loss values in the tqdm bar and appends them to 'losses'.
tbar = tqdm(range(num_epochs))
# `with` guarantees the loss log is closed even if training raises.
with open('losses', 'w') as fid:
    for epoch in tbar:
        train_rec, train_kl = vae.one_epoch(train_loader)
        valid_rec, valid_kl = vae.evaluate(validate_loader)

        # Preview pass: eval mode, and no autograd graph for throwaway tensors.
        vae.eval()
        with torch.no_grad():
            # The builtin next() works on any Python 3 iterator. The
            # `.next()` method is a Python 2 leftover that recent DataLoader
            # iterators no longer provide.
            bx_train, _ = next(iter(train_loader))
            bx_valid, _ = next(iter(validate_loader))
            bx_train = bx_train[:8]
            bx_valid = bx_valid[:8]
            # NOTE(review): only the input batches are unnormalized below.
            # If random_sample/reconstruct_img return tensors in normalized
            # space, the grid rows end up on different scales. Confirm
            # against the VAE implementation.
            rand_samp = vae.random_sample(num_samples=8)
            rec_train = vae.reconstruct_img(bx_train)
            rec_valid = vae.reconstruct_img(bx_valid)
            bx_train = vae.unnormalize(bx_train)
            bx_valid = vae.unnormalize(bx_valid)
            show_imgs = make_grid(
                torch.cat((rand_samp, bx_train, rec_train, bx_valid, rec_valid), dim=0),
                nrow=8,
            )
        save_image(show_imgs, 'img_%d.png' % epoch)
        vae.train()

        # Format once; the same line goes to the progress bar and the log file.
        losses = '%8.3f %8.3f %8.3f %8.3f' % (train_rec, train_kl, valid_rec, valid_kl)
        tbar.set_description(losses)
        fid.write(losses + '\n')
        fid.flush()
Esempio n. 2
0
# Record the KL weight and the normalisation statistics in the VAE config,
# then build the model.
conf['kl_factor'] = kl_factor
conf['norm_mean'] = mean
conf['norm_std'] = std
vae = VAE(n_channel=3, img_size=img_size, z_dim=z_dim, use_cuda=use_cuda, conf=conf)

# Train for `num_epochs` epochs.
#
# After each epoch this loop:
#   * saves a preview grid to img_<epoch>.png, laid out as rows of up to
#     8 images: random samples, train inputs, train reconstructions,
#     validation inputs, validation reconstructions (all passed through
#     vae.unnormalize before display);
#   * shows the four loss values in the tqdm bar and appends them to 'losses'.
tbar = tqdm(range(num_epochs))
# `with` closes the loss log even if training raises. The original explicit
# fid.close() was skipped on an exception.
with open('losses', 'w') as fid:
    for epoch in tbar:
        train_rec, train_kl = vae.one_epoch(train_loader)
        valid_rec, valid_kl = vae.evaluate(validate_loader)

        # Preview pass: eval mode, and no autograd graph for throwaway tensors.
        vae.eval()
        with torch.no_grad():
            # The builtin next() works on any Python 3 iterator. The
            # `.next()` method is a Python 2 leftover that recent DataLoader
            # iterators no longer provide.
            bx_train, _ = next(iter(train_loader))
            bx_valid, _ = next(iter(validate_loader))
            bx_train = bx_train[:8]
            bx_valid = bx_valid[:8]
            rand_samp = vae.unnormalize(vae.random_sample(num_samples=8))
            rec_train = vae.unnormalize(vae.reconstruct_img(bx_train))
            rec_valid = vae.unnormalize(vae.reconstruct_img(bx_valid))
            bx_train = vae.unnormalize(bx_train)
            bx_valid = vae.unnormalize(bx_valid)
            show_imgs = make_grid(
                torch.cat((rand_samp, bx_train, rec_train, bx_valid, rec_valid), dim=0),
                nrow=8,
            )
        save_image(show_imgs, 'img_%d.png' % epoch)
        vae.train()

        # Format once; the same line goes to the progress bar and the log file.
        losses = '%8.3f %8.3f %8.3f %8.3f' % (train_rec, train_kl, valid_rec, valid_kl)
        tbar.set_description(losses)
        fid.write(losses + '\n')
        fid.flush()

vae.eval()