Example #1
0
def main(args):
    """Train a VAE on MNIST and save per-epoch sample grids plus a loss curve.

    Args:
        args: parsed CLI namespace providing ``seed``, ``cuda``,
            ``batch_size`` and ``epochs``.
    """
    # Create the output folder if needed. makedirs(exist_ok=True) is
    # race-free and handles nested paths, unlike a bare exists()/mkdir() pair.
    os.makedirs('./vae_results/', exist_ok=True)

    # Load data (pin_memory/workers only make sense when copying to GPU)
    torch.manual_seed(args.seed)
    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('./data', train=True, download=True,
                       transform=transforms.ToTensor()),
        batch_size=args.batch_size, shuffle=True, **kwargs)

    # Load model; move to GPU when one is present
    model = VAE().cuda() if torch.cuda.is_available() else VAE()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    # Train, then decode a fixed batch of random latents every epoch so
    # sample quality can be inspected over time.
    loss_list = []
    for epoch in range(1, args.epochs + 1):
        model.train()
        _loss = train(epoch, model, train_loader, optimizer)
        loss_list.append(_loss)
        model.eval()
        # 64 latent vectors of dimension 20 (assumed VAE latent size --
        # TODO confirm against the VAE definition).
        sample = torch.randn(64, 20)
        sample = Variable(sample).cuda() if torch.cuda.is_available() else Variable(sample)
        sample = model.decode(sample).cpu()
        save_image(sample.view(64, 1, 28, 28).data, 'vae_results/sample_' + str(epoch) + '.png')
    plt.plot(range(len(loss_list)), loss_list, '-o')
    plt.savefig('vae_results/vae_loss_curve.png')
Example #2
0
def main(config):
    """Decode a reproducible grid of random latents from a trained VAE.

    Args:
        config: namespace providing ``ckp_path`` (checkpoint to load) and
            ``save_path`` (output image file path).
    """
    # NOTE: the original built an unused transforms.Compose pipeline here;
    # it was dead code (never referenced) and has been removed.
    model = VAE().cuda()

    # Restore trained weights (checkpoint stores them under 'state_dict').
    state = torch.load(config.ckp_path)
    model.load_state_dict(state['state_dict'])

    os.makedirs(os.path.dirname(config.save_path), exist_ok=True)

    # Fix every RNG source so the generated grid is reproducible.
    torch.manual_seed(66666)
    np.random.seed(66666)
    random.seed(66666)
    # 32 latent vectors of dimension 512 -- presumably the VAE's z_dim;
    # verify against the VAE definition.
    z = torch.randn((32, 512)).cuda()
    predict = model.decode(z)
    torchvision.utils.save_image(predict.data,
                                 config.save_path,
                                 nrow=8,
                                 normalize=True)
        test_loss /= len(test_loader)
        print('====> Test set loss: {:.4f}'.format(test_loss))
        return test_loss

    best_loss = sys.maxint
    for epoch in range(1, args.epochs + 1):
        train(epoch)
        loss = test()

        is_best = loss < best_loss
        best_loss = min(loss, best_loss)

        save_checkpoint(
            {
                'state_dict': vae.state_dict(),
                'best_loss': best_loss,
                'n_latents': args.n_latents,
                'optimizer': optimizer.state_dict(),
            },
            is_best,
            folder='./trained_models')

        if is_best:
            sample = Variable(torch.randn(64, 20))
            if args.cuda:
                sample = sample.cuda()

            sample = vae.decode(sample).cpu()
            save_image(sample.data.view(64, 1, 28, 28),
                       './results/sample_image.png')
Example #4
0
    # visualize reconst and free sample
    print("plotting imgs...")
    with torch.no_grad():
        val_iter = val_loader.__iter__()

        # reconstruct 25 imgs
        imgs = val_iter._get_batch()[1][0][:25]
        if args.cuda:
            imgs = imgs.cuda()
        imgs_reconst, mu, logvar = model(imgs)

        # sample 25 imgs
        noises = torch.randn(25, model.nz, 1, 1)
        if args.cuda:
            noises = noises.cuda()
        samples = model.decode(noises)

        def write_image(tag, images):
            """
            write the resulting imgs to tensorboard.
            :param tag: The tag for tensorboard
            :param images: the torch tensor with range (-1, 1). [9, 3, 256, 256]
            """
            # make it from 0 to 255
            images = (images + 1) / 2
            grid = make_grid(images, nrow=5, padding=20)
            writer.add_image(tag, grid.detach(), global_step=epoch + 1)

        write_image("origin", imgs)
        write_image("reconst", imgs_reconst)
        write_image("samples", samples)
Example #5
0
        xi = tf.pad(x_ul_raw[i,:,:,:], [[2,2],[2,2],[0,0]])
        xi = tf.random_crop(xi, [32,32,3])
        xi = tf.image.random_flip_left_right(xi)
        x_ul.append(xi)
    x_ul = tf.stack(x_ul, axis=0)
else:
    x = x_raw
    x_ul = x_ul_raw

# Build the VAE / classifier graph over the labeled (x) and unlabeled (x_ul)
# batches. The unlabeled forward pass does not update batch-norm statistics.
vae = VAE(args.latent_dim)
net = Net()
out = net.classifier('net', x, keep_prob=args.keep_prob, is_training=True, update_batch_stats=True)
out_ul = net.classifier('net', x_ul, keep_prob=args.keep_prob, is_training=True, update_batch_stats=False)
# Encode unlabeled images, reparameterize, and reconstruct.
mu, logvar = vae.encode(x_ul, False)
z = vae.reparamenterize(mu, logvar, False)
x_recon = vae.decode(z, False)

# Zero perturbation placeholder: decoding z + r0 equals decoding z, but
# keeping r0 in the graph lets tf.gradients below differentiate the
# reconstruction difference with respect to the perturbation.
r0 = tf.zeros_like(z, name='zero_holder')
x_recon_r0 = vae.decode(z+r0, False)
# Half squared-error between the two reconstructions, summed per sample.
diff2 = 0.5 * tf.reduce_sum((x_recon - x_recon_r0)**2, axis=[1,2,3])
diffJaco = tf.gradients(diff2, r0)[0]
def normalizevector(r):
    """Normalize each sample of ``r`` to (approximately) unit L2 norm.

    The tensor is flattened per-sample, pre-scaled by its max absolute
    value for numerical stability, divided by its L2 norm, then reshaped
    back to the input shape.
    """
    shape = tf.shape(r)
    r = tf.reshape(r, [shape[0], -1])
    # Pre-scale by the per-row max |r| so the squared sum below cannot overflow.
    r /= (1e-12 + tf.reduce_max(tf.abs(r), axis=1, keepdims=True))
    # BUG FIX: the original wrote `r / tf.sqrt(...)` and discarded the result,
    # so the returned vector was only max-normalized, never L2-normalized.
    # Use in-place division so the norm is actually applied.
    r /= tf.sqrt(tf.reduce_sum(r**2, axis=1, keepdims=True) + 1e-6)
    return tf.reshape(r, shape)

# power method
r_adv = normalizevector(tf.random_normal(shape=tf.shape(z)))
for j in range(1):
Example #6
0
                    data = data.cuda()
                data = Variable(data, volatile=True)
                recon_batch, mu, logvar, z = model(data)
                BCE = criterion_mse(recon_batch.view(-1, 2, args.window_size), data.view(-1, 2, args.window_size))
                errors.append(BCE.cpu().numpy())
                batches.append(recon_batch.cpu().numpy())
                zs.append(z.cpu().numpy())
        batches = np.concatenate(batches, 0)
        zs = np.concatenate(zs, 0)
        errors = np.concatenate(errors, 0)
        np.save(os.path.expandvars(args.out_dir)+"/out_all_epoch_"+str(epoch)+".npy", batches)
        np.save(os.path.expandvars(args.out_dir)+"/out_all_zs_epoch_"+str(epoch)+".npy", zs)
        np.save(os.path.expandvars(args.out_dir)+"/out_all_errors_epoch_"+str(epoch)+".npy", errors)

# Top-level driver: train/test once per epoch, decode a random-latent sample
# batch on the final epoch, and checkpoint + export activations every 2 epochs.
losses = []
for epoch in range(1, args.epochs + 1):
    tl = train(epoch)
    testl = test(epoch)
    if epoch == args.epochs:
        # On the last epoch only: decode 64 random latents and save them.
        sample = Variable(torch.randn(64, args.hidden_size))
        if args.cuda:
            sample = sample.cuda()
        sample = model.decode(sample).cpu()
        np.save("sample.npy", sample.detach().numpy())
    losses.append([tl, testl])
    if epoch % 2 == 0:
        # Checkpoint weights and dump per-sample outputs/latents/errors
        # (all_out is defined elsewhere in this file).
        torch.save(model.state_dict(), os.path.expandvars(args.out_dir)+"/model_epoch_"+str(epoch)+".pth")
        all_out(epoch)
# Persist the [train_loss, test_loss] history for plotting.
np.save(os.path.expandvars(args.out_dir)+"/losses.npy", np.array(losses))

Example #7
0
def test_rnn(epi):
    """Roll out one episode entirely inside the learned world model.

    The VAE decodes latents into frames, the RNN (mixture-density model)
    predicts the next-latent distribution, and the controller picks actions
    from (hx, cx, z). The rendered rollout is written to ``rnn_<epi>.avi``.

    :param epi: episode index, used only for logging and the output filename.
    """
    # Initial latent statistics gathered offline (see load_init_z).
    mus, logvars = load_init_z()

    vae = VAE()
    vae.load_state_dict(torch.load(cfg.vae_save_ckpt)['model'])

    model = RNNModel()
    model.load_state_dict(torch.load(cfg.rnn_save_ckpt)['model'])

    controller = Controller()
    controller.load_state_dict(torch.load(cfg.ctrl_save_ckpt)['model'])

    model.reset()
    z = sample_init_z(mus, logvars)
    frames = []

    for step in range(cfg.max_steps):
        z = torch.from_numpy(z).float().unsqueeze(0)
        # Decode the current latent into a frame (NCHW -> HWC, scaled to 0-255).
        curr_frame = vae.decode(z).detach().numpy()

        frames.append(curr_frame.transpose(0, 2, 3, 1)[0] * 255.0)
        # cv2.imshow('game', frames[-1])
        # k = cv2.waitKey(33)

        # Controller input: RNN hidden/cell state concatenated with z.
        inp = torch.cat((model.hx.detach(), model.cx.detach(), z), dim=1)
        y = controller(inp)
        y = y.item()
        action = encode_action(y)

        # One world-model step: mixture-density parameters for the next latent.
        logmix, mu, logstd, done_p = model.step(z.unsqueeze(0),
                                                action.unsqueeze(0))

        # Normalize mixture log-weights with a numerically stable log-sum-exp.
        # logmix = logmix - reduce_logsumexp(logmix)
        logmix_max = logmix.max(dim=1, keepdim=True)[0]
        logmix_reduce_logsumexp = (logmix - logmix_max).exp().sum(
            dim=1, keepdim=True).log() + logmix_max
        logmix = logmix - logmix_reduce_logsumexp

        # Adjust temperature (higher temperature -> flatter mixture weights)
        logmix = logmix / cfg.temperature
        logmix -= logmix.max(dim=1, keepdim=True)[0]
        logmix = F.softmax(logmix, dim=1)

        # Sample one mixture component per latent dimension...
        m = Categorical(logmix)
        idx = m.sample()

        # ...then draw the next latent from that component's Gaussian,
        # scaling the noise by sqrt(temperature).
        new_mu = torch.FloatTensor([mu[i, j] for i, j in enumerate(idx)])
        new_logstd = torch.FloatTensor(
            [logstd[i, j] for i, j in enumerate(idx)])
        z_next = new_mu + new_logstd.exp() * torch.randn_like(
            new_mu) * np.sqrt(cfg.temperature)

        z = z_next.detach().numpy()
        # done_p > 0 is treated as episode termination -- TODO confirm the
        # intended threshold against RNNModel's output semantics.
        if done_p.squeeze().item() > 0:
            break

    frames = [cv2.resize(frame, (256, 256)) for frame in frames]

    # `step` here is the number of survived steps, reported as the reward.
    print('Episode {}: RNN Reward {}'.format(epi, step))
    write_video(frames, 'rnn_{}.avi'.format(epi), (256, 256))
    os.system('mv rnn_{}.avi /home/bzhou/Dropbox/share'.format(epi))
Example #8
0
class Solver(object):
    """VAE training driver.

    Owns the model/optimizer, the checkpoint directory layout, and the
    epoch/iteration loop; periodically saves checkpoints and decodes a
    fixed noise batch so sample quality can be tracked over training.
    """

    def __init__(self, trainset_loader, config):
        """Cache hyper-parameters from ``config`` and create output dirs.

        :param trainset_loader: DataLoader yielding (image, label) batches.
        :param config: namespace providing nz, n_epochs, resume_iters, lr,
            beta1, beta2, kld_factor, name, ckp_dir, log_interval,
            save_interval and use_wandb.
        """
        self.use_cuda = torch.cuda.is_available()
        self.device = torch.device('cuda' if self.use_cuda else 'cpu')
        self.trainset_loader = trainset_loader
        self.nz = config.nz                  # latent dimension
        self.n_epochs = config.n_epochs
        self.resume_iters = config.resume_iters
        self.lr = config.lr
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.kld_factor = config.kld_factor  # weight of the KL term in the loss
        self.exp_name = config.name
        # ckp_dir/<exp_name>/ holds checkpoints; .../output/ holds sample grids.
        os.makedirs(config.ckp_dir, exist_ok=True)
        self.ckp_dir = os.path.join(config.ckp_dir, self.exp_name)
        os.makedirs(self.ckp_dir, exist_ok=True)
        self.example_dir = os.path.join(self.ckp_dir, "output")
        os.makedirs(self.example_dir, exist_ok=True)
        self.log_interval = config.log_interval
        self.save_interval = config.save_interval
        self.use_wandb = config.use_wandb

        self.build_model()

    def build_model(self):
        """Instantiate the VAE on the configured device and its Adam optimizer."""
        self.model = VAE(z_dim=self.nz).to(self.device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr,
                                          betas=(self.beta1, self.beta2))

    def save_checkpoint(self, step):
        """Save model + optimizer state as ``<step + 1>-vae.pth``."""
        state = {'state_dict': self.model.state_dict(),
                 'optimizer': self.optimizer.state_dict()}
        new_checkpoint_path = os.path.join(self.ckp_dir, '{}-vae.pth'.format(step + 1))
        torch.save(state, new_checkpoint_path)
        print('model saved to %s' % new_checkpoint_path)

    def load_checkpoint(self, resume_iters):
        """Restore model + optimizer state saved as ``<resume_iters>-vae.pth``."""
        print('Loading the trained models from step {}...'.format(resume_iters))
        new_checkpoint_path = os.path.join(self.ckp_dir, '{}-vae.pth'.format(resume_iters))
        state = torch.load(new_checkpoint_path)
        self.model.load_state_dict(state['state_dict'])
        self.optimizer.load_state_dict(state['optimizer'])
        print('model loaded from %s' % new_checkpoint_path)

    def train(self):
        """Run the full training loop, resuming from a checkpoint when
        ``resume_iters`` is set."""
        iteration = 0
        torch.manual_seed(66666)
        # BUG FIX: fixed_noise was created with an unconditional .cuda(),
        # which crashes on CPU-only machines even though everything else in
        # this class honors self.device. Use the configured device instead.
        fixed_noise = torch.randn((32, self.nz)).to(self.device)

        if self.resume_iters:
            print("resuming step %d ..." % self.resume_iters)
            iteration = self.resume_iters
            self.load_checkpoint(self.resume_iters)

        for ep in range(self.n_epochs):
            self.model.train()  # set training mode

            mse_loss_t = 0.0
            kld_loss_t = 0.0

            for batch_idx, (data, _) in enumerate(self.trainset_loader):
                data = data.to(self.device)
                self.optimizer.zero_grad()
                rec, mu, logvar = self.model(data)
                mse_loss = F.mse_loss(rec, data)
                # KL(q(z|x) || N(0, I)) in closed form.
                kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

                mse_loss_t += mse_loss.item()
                kld_loss_t += kld_loss.item()

                loss = mse_loss + self.kld_factor * kld_loss
                loss.backward()

                self.optimizer.step()

                if (iteration + 1) % self.log_interval == 0:
                    print('Epoch: {:3d} [{:5d}/{:5d} ({:3.0f}%)]\tIteration: {:5d}\tMSE: {:.6f}\tKLD: {:.6f}'.format(
                        ep, (batch_idx + 1) * len(data), len(self.trainset_loader.dataset),
                        100. * (batch_idx + 1) / len(self.trainset_loader), iteration + 1, mse_loss.item(), kld_loss.item()))

                if (iteration + 1) % self.save_interval == 0 and iteration > 0:
                    # Checkpoint and render the fixed-noise sample grid.
                    self.save_checkpoint(iteration)
                    g_example = self.model.decode(fixed_noise)
                    g_example_path = os.path.join(self.example_dir, '%d.png' % (iteration + 1))
                    torchvision.utils.save_image(g_example.data, g_example_path, nrow=8, normalize=True)

                iteration += 1

            # Per-epoch summary: losses averaged over the number of batches.
            print('Epoch: {:3d} [{:5d}/{:5d} ({:3.0f}%)]\tIteration: {:5d}\tMSE: {:.6f}\tKLD: {:.6f}\n'.format(
                ep, len(self.trainset_loader.dataset), len(self.trainset_loader.dataset), 100., iteration,
                mse_loss_t / len(self.trainset_loader), kld_loss_t / len(self.trainset_loader)))

            if self.use_wandb:
                import wandb
                wandb.log({"MSE": mse_loss_t / len(self.trainset_loader),
                           "KLD": kld_loss_t / len(self.trainset_loader)})

        # Final checkpoint + sample grid after the last epoch.
        self.save_checkpoint(iteration)
        g_example = self.model.decode(fixed_noise)
        g_example_path = os.path.join(self.example_dir, '%d.png' % (iteration + 1))
        torchvision.utils.save_image(g_example.data, g_example_path, nrow=8, normalize=True)
Example #9
0
        # Backprop and optimize
        loss = reconst_loss + kl_div
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % 100 == 0:
            print(
                "Epoch[{}/{}], Step [{}/{}], Reconst Loss: {:.4f}, KL Div: {:.4f}"
                .format(epoch + 1, num_epochs, i + 1, len(data_loader),
                        reconst_loss.item(), kl_div.item()))

    with torch.no_grad():
        # Save the sampled images
        z = torch.randn(batch_size, z_dim).cuda()
        out = model.decode(z).view(-1, 1, 28, 28)
        save_image(
            out, os.path.join(sample_dir, 'sampled-{}.png'.format(epoch + 1)))

        # Save the reconstructed images
        out, _, _ = model(x)
        x_concat = torch.cat([x.view(-1, 1, 28, 28),
                              out.view(-1, 1, 28, 28)],
                             dim=3)
        save_image(
            x_concat,
            os.path.join(sample_dir, 'reconst-{}.png'.format(epoch + 1)))

torch.save(model.state_dict(), 'models/MNIST_EnD.pth')
Example #10
0
def run():
    """Train a VAE on augmented MNIST with an optional GRU-based
    discriminator regularizer; periodically dump per-latent traversal
    grids and random-sample grids.

    Reads all hyper-parameters from the module-level ``args`` namespace.
    """
    dataset = AugMNISTDataset()
    dataloader = torch.utils.data.DataLoader(dataset,
                                             shuffle=True,
                                             batch_size=args.batch_size,
                                             num_workers=8)
    vae = VAE(latent_dim=args.latent_dim).to(args.device)
    # GRU over flattened 784-dim images; its last hidden state feeds the
    # per-latent-dimension classifier head.
    discrim = nn.GRU(784, args.discrim_hidden, 1).to(args.device)
    discrim.flatten_parameters()
    classifier = nn.Sequential(nn.ReLU(),
                               nn.Linear(args.discrim_hidden,
                                         args.latent_dim)).to(args.device)
    summary(vae, (784, ))

    # Separate optimizers: discriminator+classifier, encoder, decoder.
    discrim_opt = torch.optim.Adam(chain(discrim.parameters(),
                                         classifier.parameters()),
                                   lr=args.discrim_lr)
    enc_opt = torch.optim.Adam(vae.e_params(), lr=args.lr)
    dec_opt = torch.optim.Adam(vae.d_params(), lr=args.lr)
    step = 0
    for epoch in range(args.epochs):
        for idx, sample in enumerate(dataloader):
            image = sample['image'].to(args.device).view(-1, 784)
            mu, logvar = vae.encode(image)
            z = vae.reparameterize(mu, logvar)
            x_hat = vae.decode(z)
            mse, kld = loss_function(x_hat, image, mu, logvar)
            vae_loss = mse + args.beta * kld

            vae_loss.backward()
            # FIX: clip_grad_norm was deprecated (and later removed) in favor
            # of the in-place clip_grad_norm_; both return the total norm.
            enc_grad = torch.nn.utils.clip_grad_norm_(vae.e_params(), 100)
            enc_opt.step()
            enc_opt.zero_grad()

            if args.reg_coef:
                # Discriminator phase: train it on detached decoder samples.
                samples, labels, z = generate_subset_samples(
                    args.batch_size // 4, args.n_samples, args.n_sample_dims,
                    args.latent_dim, vae, args.device)
                discrim_out = discrim(samples.detach())[1][-1]
                logits = classifier(discrim_out)
                discrim_loss = F.binary_cross_entropy_with_logits(
                    logits, labels)
                discrim_loss.backward()
                discrim_grad = torch.nn.utils.clip_grad_norm_(
                    chain(discrim.parameters(), classifier.parameters()), 100)
                discrim_opt.step()
                discrim_opt.zero_grad()

                dec_pre_grad = torch.nn.utils.clip_grad_norm_(
                    vae.d_params(), 100)

                # Decoder phase: the (reg_coef-scaled) adversarial gradient
                # flows into the decoder via non-detached samples; the
                # discriminator's own gradients are discarded by zero_grad.
                discrim_out = discrim(samples)[1][-1]
                logits = classifier(discrim_out)
                discrim_loss = args.reg_coef * F.binary_cross_entropy_with_logits(
                    logits, labels)
                discrim_loss.backward()
                discrim_opt.zero_grad()
                dec_post_grad = torch.nn.utils.clip_grad_norm_(
                    vae.d_params(), 100)
                dec_opt.step()
                dec_opt.zero_grad()
            else:
                dec_post_grad = torch.nn.utils.clip_grad_norm_(
                    vae.d_params(), 100)
                dec_opt.step()
                dec_opt.zero_grad()

            if step % 500 == 0:
                # Latent traversal grids: resample one dimension at a time.
                for k in range(args.latent_dim):
                    z = resample_kth_z(k, args.n_show, args.latent_dim,
                                       args.device)
                    save_samples(vae, z, f'z{k}_vae_output.png')

                z = torch.randn(args.n_show,
                                args.latent_dim,
                                device=args.device)
                save_samples(vae, z, 'vae_output.png')

                original_grid = tv.utils.make_grid(
                    sample['image'][:args.n_show], int(args.n_show**0.5))
                tv.utils.save_image(original_grid, 'images.png')

                print(f'step := {step}')
                if args.reg_coef:
                    discrim_acc = ((logits > 0) == labels).float().mean()
                    print(f'discrim_loss := {discrim_loss.item()}')
                    print(f'discrim_acc := {discrim_acc}')
                    print(f'grad/discrim := {discrim_grad}')
                    print(f'grad/dec_pre := {dec_pre_grad}')
                print(f'vae_loss := {vae_loss}')
                print(f'mse := {mse}')
                print(f'kld := {kld}')
                print(f'grad/enc := {enc_grad}')
                print(f'grad/dec_post := {dec_post_grad}')

            step += 1
Example #11
0
class Trainer(object):
    """End-to-end VAE training harness for MNIST.

    ``__init__`` builds the train/test loaders, model and optimizer;
    ``train()`` runs the full epoch loop, writing reconstruction and
    random-sample images under ``./results/``.

    NOTE(review): this code targets pre-0.4 PyTorch (``Variable``,
    ``volatile=True``, ``loss.data[0]``); on modern PyTorch these would be
    plain tensors, ``torch.no_grad()`` and ``loss.item()`` respectively.
    """
    def __init__(self, args):
        self.args = args

        # Seed the CPU RNG, and the GPU RNG when CUDA is requested.
        torch.manual_seed(self.args.seed)
        if self.args.cuda:
            torch.cuda.manual_seed(self.args.seed)

        # pin_memory/workers only help when batches are copied to the GPU.
        kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
        train_loader = torch.utils.data.DataLoader(
            datasets.MNIST('./data',
                           train=True,
                           download=True,
                           transform=transforms.ToTensor()),
            batch_size=self.args.batch_size,
            shuffle=True,
            **kwargs)
        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST('./data',
                           train=False,
                           transform=transforms.ToTensor()),
            batch_size=self.args.batch_size,
            shuffle=True,
            **kwargs)
        self.train_loader = train_loader
        self.test_loader = test_loader

        self.model = VAE()
        if self.args.cuda:
            self.model.cuda()

        self.optimizer = optim.Adam(self.model.parameters(), lr=1e-3)

    def loss_function(self, recon_x, x, mu, logvar):
        """Return BCE reconstruction loss plus a scaled KL divergence.

        The KL term is divided by (batch_size * 784) to put it on the same
        per-pixel scale as the averaged BCE term.
        """
        BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784))
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        KLD /= self.args.batch_size * 784
        return BCE + KLD

    def train_one_epoch(self, epoch):
        """Run one optimization pass over the training set, logging every
        ``args.log_interval`` batches."""
        train_loader = self.train_loader
        args = self.args

        self.model.train()
        train_loss = 0
        for batch_idx, (data, _) in enumerate(train_loader):
            data = Variable(data)
            if args.cuda:
                data = data.cuda()
            self.optimizer.zero_grad()
            recon_batch, mu, logvar = self.model(data)
            loss = self.loss_function(recon_batch, data, mu, logvar)
            loss.backward()
            # .data[0] extracts the scalar loss (pre-0.4 idiom).
            train_loss += loss.data[0]
            self.optimizer.step()
            if batch_idx % args.log_interval == 0:
                print('Train Epoch: {} [{}/{} ({:0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader),
                    loss.data[0] / len(data)))
        print('=====> Epoch: {} Average loss: {:.4f}'.format(
            epoch, train_loss / len(train_loader.dataset)))

    def test(self, epoch):
        """Evaluate on the test set; for the first batch of each epoch save
        an original-vs-reconstruction comparison grid."""
        test_loader = self.test_loader
        args = self.args

        self.model.eval()
        test_loss = 0
        for i, (data, _) in enumerate(test_loader):
            if args.cuda:
                data = data.cuda()
            # volatile=True disables autograd history (old-style inference).
            data = Variable(data, volatile=True)
            recon_batch, mu, logvar = self.model(data)
            test_loss += self.loss_function(recon_batch, data, mu,
                                            logvar).data[0]
            if i == 0:
                # Top rows: up to 8 inputs; bottom rows: their reconstructions.
                n = min(data.size(0), 8)
                comparison = torch.cat([
                    data[:n],
                    recon_batch.view(args.batch_size, 1, 28, 28)[:n]
                ])
                fname = 'results/reconstruction_' + str(epoch) + '.png'
                save_image(comparison.data.cpu(), fname, nrow=n)

        test_loss /= len(test_loader.dataset)
        print('=====> Test set loss: {:.4f}'.format(test_loss))

    def train(self):
        """Full training loop: one train + test pass per epoch, then decode
        64 random latents (dim 20 -- TODO confirm against the VAE definition)
        into a sample-grid image."""
        args = self.args
        for epoch in range(1, args.epochs + 1):
            self.train_one_epoch(epoch)
            self.test(epoch)
            sample = Variable(torch.randn(64, 20))
            if args.cuda:
                sample = sample.cuda()
            sample = self.model.decode(sample).cpu()
            save_image(sample.data.view(64, 1, 28, 28),
                       './results/sample_' + str(epoch) + '.png')
Example #12
0
    # minibatch optimization with Adam
    for data in dataloader:
        img, _ = data

        # change the images to be 1D
        img = img.view(img.size(0), -1)

        # get output from network
        out, mu, log_var = vae(img)

        # calculate loss and update network
        loss = F.binary_cross_entropy(out, img) + KL(mu, log_var)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # save images periodically
    if epoch % 10 == 0:
        img = out.data.view(out.size(0), 1, 28, 28)
        save_image(img, './img/' + str(epoch) + '_epochs.png')

    # plot loss
    update_viz(epoch, loss.item())

# generate new random images from latent noise
# FIX: the local was named `input`, shadowing the builtin; renamed to `z`.
z = torch.randn(96, 10)
out = vae.decode(z)
img = out.data.view(96, 1, 28, 28)
save_image(img, './generated_img/img.png')
Example #13
0
    vae.eval()
    imgs, *_ = vae(sample_batch)
    sample_batch_recon = vutils.make_grid(imgs,
                                          nrow=10,
                                          padding=2,
                                          normalize=True)
    writer.add_image('Train/Recon', sample_batch_recon, step)
    imgs, *_ = vae(test_batch)
    test_batch_recon = vutils.make_grid(imgs,
                                        nrow=10,
                                        padding=2,
                                        normalize=True)
    writer.add_image('Test/Recon', test_batch_recon, step)

    # output random samples
    imgs = vae.decode(rand_z)
    rand_samples = vutils.make_grid(imgs, nrow=10, padding=2, normalize=True)
    writer.add_image('Random Z', rand_samples, step)

    if epoch % params['save_epoch'] == 0:
        # save model
        torch.save(
            {
                # 'epoch': epoch,
                'model_state_dict': vae.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss_tot': loss_tot,
                'loss_kl': loss_kl,
                'loss_rec': loss_rec,
            },
            params['save_path'])
        lb = Variable( torch.LongTensor( [label[i]] ) )
        test_labels.append(labelembed.index_select(0, lb))
    break

if args.model == "VAE":
    for i in xrange(iteration):
        img = test_imgs[i]
        mu, var = model.encode(img.view(1, 784))
        center = test_centers[i]

        maxidx = 1000
        minloss = 10000000
        for j in range(maxidx):
            #z_mu = model.reparametrize(mu, var)
            z = Variable(torch.FloatTensor(1,20).normal_())
            recon = model.decode(z).view(1, 28, -1)
            recon1 = occludeimg_with_center(recon.view(1, 28, -1), center)
            loss = reconstruction_function(recon1.view(-1), img.view(-1))
            if loss < minloss:
                minloss = loss
                min_recon = recon

        compare(test_originals[i].data, img.data, min_recon.data)

elif args.model == "VAE_INC":
    for i in xrange(iteration):
        maxidx = 1000
        minloss = 10000000
        img = test_imgs[i]
        mu, var = model.encode(img.view(1, 784))
        center = test_centers[i]