示例#1
0
def get_init_posterior(data):
    """Compute the initial posterior for ``data``.

    If ``cmd_args.unroll_test`` is set, the encoder output is refined with
    ``optimize_gaussian`` against the decoding loss of ``data``. Otherwise the
    encoder's output is used directly.

    Returns:
        tuple: ``(best_z, mu, logvar)``.
    """
    def objective(z, mu, logvar):
        # Loss of decoding ``z`` against the observed ``data``.
        return loss_function(decoder(z), data, mu, logvar)

    if not cmd_args.unroll_test:
        best_z, mu, logvar = encoder(data)
    else:
        best_z, mu, logvar = optimize_gaussian(
            data,
            encoder,
            objective,
            inner_opt_class,
            nsteps=cmd_args.unroll_steps,
            training=True,
        )
    return best_z, mu, logvar
示例#2
0
def test(epoch):
    """Evaluate on ``test_loader``, save a reconstruction grid, return the mean loss.

    The inner optimizer is frozen for the duration of the evaluation. It is
    always unfrozen again, even if a batch raises.

    When ``cmd_args.unroll_test`` is set, each batch's posterior is refined by
    ``optimize_func``. The encoder's differentiable variables are snapshotted
    before the refinement and restored afterwards.

    Args:
        epoch: epoch index, used in the log message and the image file name.

    Returns:
        float: accumulated ``loss * batch_size`` divided by
        ``len(test_loader.dataset)``. This is a per-sample average assuming
        ``loss_function`` returns a batch mean (TODO confirm).
    """
    encoder.eval()
    inner_opt.set_freeze_flag(True)
    test_loss = 0
    try:
        # Wrap the loader (not the enumerate object) so tqdm knows the total length.
        for i, (data, _) in enumerate(tqdm(test_loader)):
            data = convert_data(data)
            fz = lambda z, mu, logvar: loss_function(decoder(z), data, mu, logvar)
            if cmd_args.unroll_test:
                inner_opt.zero_grad()
                bak_dict = encoder.diff_var_dict()
                try:
                    best_z, mu, logvar = optimize_func(data, encoder, fz, inner_opt, nsteps=cmd_args.unroll_steps)
                finally:
                    # Undo the unrolled updates so evaluation does not alter the encoder.
                    encoder.load_diff_var_dict(bak_dict)
            else:
                best_z, mu, logvar = encoder(data)

            loss = fz(best_z, mu, logvar)
            recon_batch = decoder(best_z)
            test_loss += loss.item() * data.shape[0]
            if i == 0:
                # Up to 8 inputs stacked above their reconstructions.
                n = min(data.size(0), 8)
                comparison = torch.cat([data[:n],
                                        recon_batch.view(-1, 1, cmd_args.img_size, cmd_args.img_size)[:n]])
                save_image(comparison.data.cpu(),
                           '%s/parametric_op_reconstruction_' % cmd_args.save_dir + str(epoch) + '.png', nrow=n)
    finally:
        inner_opt.set_freeze_flag(False)
    test_loss /= len(test_loader.dataset)
    msg = 'test epoch %d, average loss %.4f' % (epoch, test_loss)
    print(msg)
    return test_loss
示例#3
0
def train(epoch):
    """Run one training epoch over ``celeb_loader``.

    For each mini-batch, ``optimize_func`` refines the posterior with
    ``inner_opt``. The resulting loss is then backpropagated and ``optimizer``
    takes one step.

    Args:
        epoch: epoch index, used in the log message.

    Returns:
        float: mean of the per-mini-batch losses. Returns ``nan`` if the
        loader yielded no batches, instead of raising ZeroDivisionError.
    """
    encoder.train()
    train_loss = 0
    pbar = tqdm(celeb_loader)
    num_mini_batches = 0
    for (data, _) in pbar:
        data = Variable(data)
        if cmd_args.ctx == 'gpu':
            data = data.cuda()

        optimizer.zero_grad()
        inner_opt.zero_grad()

        # Flatten decoder output and targets to (batch, -1) before computing the loss.
        fz = lambda z, mu, logvar: loss_function(decoder(z).view(data.shape[0], -1), data.view(data.shape[0], -1), mu, logvar)
        best_z, mu, logvar = optimize_func(data, encoder, fz, inner_opt, nsteps=cmd_args.unroll_steps)

        loss = fz(best_z, mu, logvar)
        loss.backward()

        batch_loss = loss.item()
        train_loss += batch_loss
        optimizer.step()

        pbar.set_description('minibatch loss: %.4f' % batch_loss)
        num_mini_batches += 1

    avg_loss = train_loss / num_mini_batches if num_mini_batches else float('nan')
    print('Epoch %d, average loss %.4f' % (epoch, avg_loss))
    return avg_loss
示例#4
0
def test(epoch):
    """Evaluate the first test batch and save a reconstruction grid.

    Only the first batch of ``test_loader`` is processed (the loop ends with
    ``break``). The inner optimizer is frozen during evaluation and is always
    unfrozen again, even if an error occurs.

    When ``cmd_args.unroll_test`` is set, the posterior is refined by
    ``optimize_func``. The encoder's differentiable variables are restored
    afterwards.

    Args:
        epoch: epoch index, used in the image file name.

    Returns:
        float: accumulated ``loss * batch_size`` over the evaluated samples,
        divided by their count. This is a per-sample mean assuming
        ``loss_function`` returns a batch mean (TODO confirm). Returns ``nan``
        if the loader is empty.
    """
    encoder.eval()
    inner_opt.set_freeze_flag(True)
    test_loss = 0
    num_samples = 0
    try:
        for i, (data, _) in enumerate(test_loader):
            if cmd_args.ctx == 'gpu':
                data = data.cuda()
            data = Variable(data)

            fz = lambda z, mu, logvar: loss_function(decoder(z).view(data.shape[0], -1), data.view(data.shape[0], -1), mu, logvar)
            if cmd_args.unroll_test:
                inner_opt.zero_grad()
                bak_dict = encoder.diff_var_dict()
                try:
                    best_z, mu, logvar = optimize_func(data, encoder, fz, inner_opt, nsteps=cmd_args.unroll_steps)
                finally:
                    # Undo the unrolled updates so evaluation does not alter the encoder.
                    encoder.load_diff_var_dict(bak_dict)
            else:
                best_z, mu, logvar = encoder(data)

            loss = fz(best_z, mu, logvar)
            recon_batch = decoder(best_z)
            test_loss += loss.item() * data.shape[0]
            num_samples += data.shape[0]

            if i == 0:
                n = min(data.size(0), 8)
                # Reshape to the input's per-image shape. The leading -1 takes the
                # actual batch size, so a short batch no longer breaks the view.
                comparison = torch.cat([data[:n],
                                        recon_batch.view(-1, *data.shape[1:])[:n]])
                save_image(comparison.data.cpu(),
                           '%s/vae_reconstruction_' % cmd_args.save_dir + str(epoch) + '.png', nrow=n)
            # Only the first batch is evaluated.
            break
    finally:
        inner_opt.set_freeze_flag(False)
    return test_loss / num_samples if num_samples else float('nan')
示例#5
0
def get_init_posterior(data):
    """Compute the initial posterior ``(best_z, mu, logvar)`` for ``data``.

    When ``cmd_args.unroll_test`` is set, the posterior is refined by
    ``optimize_func`` using ``inner_opt``. The encoder's differentiable
    variables are snapshotted beforehand and always restored afterwards, even
    if the optimization raises, so the encoder is left unchanged. Otherwise
    the encoder's output is used directly.

    Returns:
        tuple: ``(best_z, mu, logvar)``.
    """
    # Flatten decoder output and targets to (batch, -1) before computing the loss.
    fz = lambda z, mu, logvar: loss_function(decoder(z).view(data.shape[0], -1), data.view(data.shape[0], -1), mu, logvar)
    if not cmd_args.unroll_test:
        best_z, mu, logvar = encoder(data)
        return best_z, mu, logvar

    inner_opt.zero_grad()
    bak_dict = encoder.diff_var_dict()
    try:
        best_z, mu, logvar = optimize_func(data, encoder, fz, inner_opt, nsteps=cmd_args.unroll_steps)
    finally:
        encoder.load_diff_var_dict(bak_dict)
    return best_z, mu, logvar
示例#6
0
def train(epoch):
    """Run one training epoch over ``train_loader`` and print the mean loss.

    Each mini-batch's posterior is refined by ``optimize_gaussian``. The
    resulting loss is backpropagated and ``optimizer`` takes one step. For the
    progress bar, the reconstruction error of ``decoder(mu)`` is also reported.

    Args:
        epoch: epoch index, used in the log message.
    """
    import torch  # local import: torch.no_grad for the monitoring-only forward pass

    train_loss = 0
    encoder.train()
    pbar = tqdm(train_loader)
    num_mini_batches = 0
    for (data, _) in pbar:
        data = convert_data(data)

        optimizer.zero_grad()
        fz = lambda z, mu, logvar: loss_function(decoder(z), data, mu, logvar)
        best_z, mu, logvar = optimize_gaussian(data, encoder, fz, inner_opt_class, nsteps=cmd_args.unroll_steps, training=True)

        loss = fz(best_z, mu, logvar)
        loss.backward()

        batch_loss = loss.item()
        train_loss += batch_loss
        optimizer.step()
        # Display-only metric: skip building an autograd graph that is never used.
        with torch.no_grad():
            recon_loss = binary_cross_entropy(decoder(mu), data)
        pbar.set_description('minibatch loss: %.4f, recon: %.4f' % (batch_loss, recon_loss.item()))
        num_mini_batches += 1
    # Guard against an empty loader instead of raising ZeroDivisionError.
    avg_loss = train_loss / num_mini_batches if num_mini_batches else float('nan')
    msg = 'train epoch %d, average loss %.4f' % (epoch, avg_loss)
    print(msg)
示例#7
0
def train(epoch):
    """Run one training epoch over ``train_loader`` and print the mean loss.

    For each mini-batch, ``optimize_func`` refines the posterior with
    ``inner_opt``. The resulting loss is backpropagated and ``optimizer``
    takes one step.

    Args:
        epoch: epoch index, used in the log message.
    """
    train_loss = 0
    encoder.train()
    pbar = tqdm(train_loader)
    num_mini_batches = 0
    for (data, _) in pbar:
        data = convert_data(data)

        optimizer.zero_grad()
        inner_opt.zero_grad()
        fz = lambda z, mu, logvar: loss_function(decoder(z), data, mu, logvar)
        best_z, mu, logvar = optimize_func(data, encoder, fz, inner_opt, nsteps=cmd_args.unroll_steps)

        loss = fz(best_z, mu, logvar)
        loss.backward()

        batch_loss = loss.item()
        train_loss += batch_loss
        optimizer.step()

        pbar.set_description('minibatch loss: %.4f' % batch_loss)
        num_mini_batches += 1
    # Guard against an empty loader instead of raising ZeroDivisionError.
    avg_loss = train_loss / num_mini_batches if num_mini_batches else float('nan')
    msg = 'train epoch %d, average loss %.4f' % (epoch, avg_loss)
    print(msg)
示例#8
0
def test(epoch):
    """Evaluate on ``test_loader``, save a reconstruction grid, return the mean loss.

    When ``cmd_args.unroll_test`` is set, each batch's posterior is refined by
    ``optimize_gaussian``. Otherwise the encoder's output is used directly.

    Args:
        epoch: epoch index, used in the log message and the image file name.

    Returns:
        float: accumulated ``loss * batch_size`` divided by
        ``len(test_loader.dataset)``. This is a per-sample average assuming
        ``loss_function`` returns a batch mean (TODO confirm).
    """
    encoder.eval()
    test_loss = 0
    # Wrap the loader (not the enumerate object) so tqdm knows the total length.
    for i, (data, _) in enumerate(tqdm(test_loader)):
        data = convert_data(data)
        fz = lambda z, mu, logvar: loss_function(decoder(z), data, mu, logvar)
        if cmd_args.unroll_test:
            best_z, mu, logvar = optimize_gaussian(data, encoder, fz, inner_opt_class, nsteps=cmd_args.unroll_steps, training=True)
        else:
            best_z, mu, logvar = encoder(data)

        loss = fz(best_z, mu, logvar)
        recon_batch = decoder(best_z)
        test_loss += loss.item() * data.shape[0]
        if i == 0:
            n = min(data.size(0), 8)
            # -1 instead of cmd_args.batch_size: the batch may be smaller than
            # the configured size (e.g. a test set shorter than one batch).
            comparison = torch.cat([data[:n],
                                    recon_batch.view(-1, 1, cmd_args.img_size, cmd_args.img_size)[:n]])
            save_image(comparison.data.cpu(),
                       '%s/unroll_gauss_reconstruction_' % cmd_args.save_dir + str(epoch) + '.png', nrow=n)
    test_loss /= len(test_loader.dataset)
    msg = 'test epoch %d, average loss %.4f' % (epoch, test_loss)
    print(msg)
    return test_loss