Example #1
def train(epoch):
    train_loss = 0
    pbar = tqdm(celeb_loader)
    num_mini_batches = 0
    for (data, _) in pbar:
        if cmd_args.ctx == 'gpu':
            data = data.cuda()

        # Energy in z: reconstruction BCE plus the learned dual-score term.
        fz = lambda z: binary_cross_entropy(decoder(z).view(data.shape[0], -1),
                                            data.view(data.shape[0], -1)) + log_score(data, z)[0]

        # Sample encoder noise and map it to an initial latent code.
        xi = torch.randn(data.shape[0], cmd_args.latent_dim)
        if cmd_args.ctx == 'gpu':
            xi = xi.cuda()
        z_init = encoder(data, xi)

        # Refine z with a few unrolled descent steps before the outer update.
        best_z = optimize_variable(z_init, fz, EuclideanDist, nsteps=cmd_args.unroll_steps)

        optimizer.zero_grad()
        loss = fz(best_z)
        loss.backward()

        # Second term of the dual bound, evaluated on prior samples (no gradient).
        prior_z = torch.randn(data.shape[0], cmd_args.latent_dim)
        if cmd_args.ctx == 'gpu':
            prior_z = prior_z.cuda()
        cur_loss = loss.item() + torch.mean(torch.exp(nu(data, prior_z).detach())).item()

        train_loss += cur_loss
        optimizer.step()

        pbar.set_description('minibatch loss: %.4f' % cur_loss)
        num_mini_batches += 1
    msg = 'train epoch %d, average loss %.4f' % (epoch, train_loss / num_mini_batches)
    print(msg)
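The inner solver `optimize_variable` is defined elsewhere in the repository. As a rough sketch of what such an unrolled refinement could look like (the signature is taken from the calls above; the step rule, the step size `lr`, and the meaning of `eps` are assumptions, and `EuclideanDist` is assumed to reduce to a plain gradient step):

import torch

def optimize_variable(z_init, fz, dist_fn, nsteps=5, lr=1e-2, eps=0):
    # Hypothetical sketch: a few gradient steps on f(z), kept differentiable
    # end to end (create_graph=True) so the outer loss can backprop through
    # the unroll into the encoder that produced z_init.
    z = z_init
    for _ in range(nsteps):
        grad, = torch.autograd.grad(fz(z), z, create_graph=True)
        if eps > 0 and grad.norm().item() < eps:
            break  # assumed early-stop tolerance
        z = z - lr * grad  # plain step under the Euclidean metric (dist_fn unused here)
    return z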
Example #2
def do_vis(epoch):
    for i, (data, _) in tqdm(enumerate(test_loader)):
        # Dump the first 64 ground-truth images (assumes batch_size >= 64).
        t = data[:64].view(64, 3, 64, 64)
        save_image(t.cpu(), '%s/dataset-%d-' % (cmd_args.save_dir, i) + str(epoch) + '.png', nrow=8)
        if cmd_args.ctx == 'gpu':
            data = data.cuda()

        # Snapshot nu and its optimizer: unrolling may update them via log_score,
        # and visualization must not perturb the training state.
        bak_nu_dict = nu.state_dict()
        bak_opt_dict = opt_nu.state_dict()

        fz = lambda z: binary_cross_entropy(decoder(z).view(data.shape[0], -1),
                                            data.view(data.shape[0], -1)) + log_score(data, z)[0]
        xi = torch.randn(data.shape[0], cmd_args.latent_dim)
        if cmd_args.ctx == 'gpu':
            xi = xi.cuda()
        z_init = encoder(data, xi)

        if cmd_args.unroll_test:
            best_z = optimize_variable(z_init, fz, EuclideanDist, nsteps=cmd_args.unroll_steps)
        else:
            best_z = z_init

        recon_batch = decoder(best_z)

        # Side-by-side originals vs. reconstructions; use the actual batch size
        # so the view() does not fail on a final partial batch.
        n = min(data.size(0), 8)
        comparison = torch.cat([data[:n],
                                recon_batch.view(data.size(0), 3, 64, 64)[:n]])
        save_image(comparison.detach().cpu(),
                   '%s/recon-%d-' % (cmd_args.save_dir, i) + str(epoch) + '.png', nrow=n)

        # Samples decoded from the prior.
        z = torch.randn(64, cmd_args.latent_dim)
        if cmd_args.ctx == 'gpu':
            z = z.cuda()
        sample = decoder(z).view(64, 3, 64, 64)
        save_image(sample.detach().cpu(),
                   '%s/prior-%d-' % (cmd_args.save_dir, i) + str(epoch) + '.png', nrow=8)

        # Samples decoded from the (unrefined) posterior codes.
        sample = decoder(z_init)[:64].view(64, 3, 64, 64)
        save_image(sample.detach().cpu(),
                   '%s/posterior-%d-' % (cmd_args.save_dir, i) + str(epoch) + '.png', nrow=8)

        nu.load_state_dict(bak_nu_dict)
        opt_nu.load_state_dict(bak_opt_dict)
        if i + 1 >= cmd_args.vis_num:
            break
Example #3
def test(epoch):
    test_loss = 0
    encoder.eval()
    for i, (data, _) in tqdm(enumerate(test_loader)):
        # Snapshot nu and its optimizer so evaluation leaves training state untouched.
        bak_nu_dict = nu.state_dict()
        bak_opt_dict = opt_nu.state_dict()

        data = convert_data(data)

        fz = lambda z: binary_cross_entropy(decoder(z), data) + log_score(
            data, z)[0]
        z_init, mu, logvar = encoder(data)

        if cmd_args.unroll_test:
            best_z = optimize_variable(z_init,
                                       fz,
                                       EuclideanDist,
                                       nsteps=cmd_args.unroll_steps)
        else:
            best_z = z_init

        # ELBO-style objective: refined energy plus the Gaussian KL term.
        loss = fz(best_z) + kl_loss(mu, logvar)

        nu.load_state_dict(bak_nu_dict)
        opt_nu.load_state_dict(bak_opt_dict)
        recon_batch = decoder(best_z)
        test_loss += loss.item() * data.shape[0]
        if i == 0:
            n = min(data.size(0), 8)
            # Use the actual batch size so the view() survives a partial batch.
            comparison = torch.cat([
                data[:n].view(-1, 1, cmd_args.img_size, cmd_args.img_size),
                recon_batch.view(data.size(0), 1, cmd_args.img_size,
                                 cmd_args.img_size)[:n]
            ])
            save_image(comparison.detach().cpu(),
                       '%s/gauss_fenchel_reconstruction_' % cmd_args.save_dir +
                       str(epoch) + '.png',
                       nrow=n)
    test_loss /= len(test_loader.dataset)
    msg = 'test epoch %d, average loss %.4f' % (epoch, test_loss)
    print(msg)
    return test_loss
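`kl_loss(mu, logvar)` above is, in all likelihood, the standard closed-form KL divergence between the diagonal Gaussian posterior N(mu, diag(exp(logvar))) and the standard normal prior. A minimal version (whether the repository sums or averages over the batch is an assumption):

import torch

def kl_loss(mu, logvar):
    # KL( N(mu, diag(exp(logvar))) || N(0, I) ), summed over latent dimensions.
    kl_per_sample = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
    # Batch reduction: mean() here is an assumption about the repo's convention.
    return kl_per_sample.mean()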
Example #4
def train(epoch):
    encoder.train()
    pbar = tqdm(train_loader)
    loss_list = []
    for (data, _) in pbar:
        data = convert_data(data)

        optimizer.zero_grad()
        # Energy in z; update=True lets log_score adapt the critic nu on the fly.
        fz = lambda z: binary_cross_entropy(decoder(z), data) + log_score(
            data, z, update=True)[0]

        z_init, mu, logvar = encoder(data)
        # Refine the encoder's code with unrolled descent steps.
        best_z = optimize_variable(z_init,
                                   fz,
                                   EuclideanDist,
                                   nsteps=cmd_args.unroll_steps,
                                   eps=0)

        kl = kl_loss(mu, logvar)
        obj = fz(best_z)
        loss = kl + obj
        loss.backward()
        optimizer.step()

        # Track the plain VAE objective alongside the Fenchel objective.
        recon_loss = binary_cross_entropy(decoder(best_z), data)
        vae_loss = kl.item() + recon_loss.item()

        pbar.set_description('vae loss: %.4f, recon: %.4f, fenchel_obj: %.4f' %
                             (vae_loss, recon_loss.item(), obj.item()))
        loss_list.append(loss.item())
    msg = 'train epoch %d, average loss %.4f' % (epoch, np.mean(loss_list))
    print(msg)
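`log_score` and the critic network `nu` are also defined elsewhere. From the call sites above (tuple indexing with `[0]`, the `update=True` flag, the separate `exp(nu(data, prior_z))` term in the test loops), a heavily assumed sketch of its contract looks like this; `nu` and `opt_nu` are the same module-level objects used in the snippets, and the exact dual objective for the critic is an assumption:

import torch

def log_score(data, z, update=False):
    # Hypothetical sketch: score the (x, z) pair with the dual critic nu.
    score = torch.mean(nu(data, z))
    if update:
        # Assumed inner step on the critic: raise nu on posterior samples,
        # penalize exp(nu) on prior samples (a Fenchel-dual-style objective).
        prior_z = torch.randn_like(z)
        nu_obj = -torch.mean(nu(data, z.detach())) + torch.mean(torch.exp(nu(data, prior_z)))
        opt_nu.zero_grad()
        nu_obj.backward()
        opt_nu.step()
    return score, None  # the callers above only use element [0]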
Example #5
def test(epoch):
    test_loss = 0
    for i, (data, _) in enumerate(test_loader):
        # Snapshot nu and its optimizer so evaluation leaves training state untouched.
        bak_nu_dict = nu.state_dict()
        bak_opt_dict = opt_nu.state_dict()

        if cmd_args.ctx == 'gpu':
            data = data.cuda()

        # Dual correction term, evaluated on prior samples without gradients.
        with torch.no_grad():
            prior_z = torch.randn(data.shape[0], cmd_args.latent_dim)
            if cmd_args.ctx == 'gpu':
                prior_z = prior_z.cuda()
            cur_loss = torch.mean(torch.exp(nu(data, prior_z))).item()

        fz = lambda z: binary_cross_entropy(decoder(z).view(data.shape[0], -1),
                                            data.view(data.shape[0], -1)) + log_score(data, z)[0]
        xi = torch.randn(data.shape[0], cmd_args.latent_dim)
        if cmd_args.ctx == 'gpu':
            xi = xi.cuda()
        z_init = encoder(data, xi)

        if cmd_args.unroll_test:
            best_z = optimize_variable(z_init, fz, EuclideanDist, nsteps=cmd_args.unroll_steps)
        else:
            best_z = z_init

        loss = fz(best_z)
        nu.load_state_dict(bak_nu_dict)
        opt_nu.load_state_dict(bak_opt_dict)
        recon_batch = decoder(best_z)

        test_loss += (loss.item() + cur_loss) * data.shape[0]

        if i == 0:
            n = min(data.size(0), 8)
            comparison = torch.cat([data[:n],
                                    recon_batch.view(data.size(0), 3, 64, 64)[:n]])
            save_image(comparison.detach().cpu(),
                       '%s/vae_reconstruction_' % cmd_args.save_dir + str(epoch) + '.png', nrow=n)
        # Only the first batch is evaluated.
        break
Example #6
    for it in pbar:
        encoder.train()
        # Draw a batch of random one-hot vectors over 4 categories.
        idx = torch.LongTensor(cmd_args.batch_size).random_(0, 4)
        data = torch.zeros(cmd_args.batch_size, 4)
        data.scatter_(1, idx.view(-1, 1), 1)

        if cmd_args.ctx == 'gpu':
            data = data.cuda()

        fz = lambda z: binary_cross_entropy(decoder(z), data) + log_score(
            data, z)[0]
        xi = torch.randn(data.shape[0], 2)
        # Move the noise to the GPU alongside data, as the other examples do.
        if cmd_args.ctx == 'gpu':
            xi = xi.cuda()
        z_init = encoder(data, xi)

        best_z = optimize_variable(z_init,
                                   fz,
                                   EuclideanDist,
                                   nsteps=cmd_args.unroll_steps)
        optimizer.zero_grad()
        loss = fz(best_z)
        loss.backward()
        optimizer.step()

        pbar.set_description('minibatch loss: %.4f' % loss.item())

        # Periodically visualize the learned posterior.
        if it % 100 == 0:
            encoder_func = lambda x: encoder(
                x,
                torch.randn(x.shape[0], 2))
            create_scatter(x_test_list,
                           encoder_func,
                           savepath=os.path.join(cmd_args.save_dir,