def test(epoch, agg):
    """Evaluate the model on the test set and log average test metrics.

    Runs ``loss_function`` over every test batch without gradient tracking and
    sums the returned loss. On the final epoch it also sums the IWAE
    marginal log-likelihood estimate. Both totals are divided by the number of
    test datapoints and appended to ``agg['test_loss']`` and
    ``agg['test_mlik']``. The first test batch is also passed to
    ``model.reconstruct``.

    Args:
        epoch: current epoch; the IWAE estimate is computed only when this
            equals ``args.epochs`` and ``args.iwae_samples > 0``.
        agg: mapping of metric name -> list, extended in place.
    """
    model.eval()
    n_test = len(test_loader.dataset)
    # The IWAE estimate (K = args.iwae_samples) is only computed on the final
    # epoch. NOTE: on every other epoch agg['test_mlik'] receives 0.0.
    compute_mlik = epoch == args.epochs and args.iwae_samples > 0
    b_loss, b_mlik = 0., 0.
    with torch.no_grad():
        for i, (data, _) in enumerate(test_loader):
            data = data.to(device)
            # Only the total loss is used; the other four outputs are discarded.
            _, _, _, _, loss = loss_function(model, data, K=args.K,
                                             beta=args.beta, components=True)
            if compute_mlik:
                mlik = objectives.iwae_objective(model, data, K=args.iwae_samples)
                b_mlik += mlik.sum(-1).item()
            b_loss += loss.item()
            if i == 0:
                # Only the first test batch is handed to model.reconstruct
                # (presumably writes reconstructions under runPath -- confirm).
                model.reconstruct(data, runPath, epoch)

    agg['test_loss'].append(b_loss / n_test)
    agg['test_mlik'].append(b_mlik / n_test)
    print('====> Test loss: {:.4f} mlik: {:.4f}'.format(
        agg['test_loss'][-1], agg['test_mlik'][-1]))
def _student_t_to_normal(z, df, loc, scale):
    """Map ``z`` element-wise through F_N^{-1}(F_t(z)) (Student-t -> Normal).

    F_t is the CDF of Student-t(df, loc, scale), evaluated with scipy. F_N^{-1}
    is the inverse CDF of Normal(loc, scale), evaluated with torch. Values
    under the Student-t prior are thus carried along the monotone (optimal
    transport) map onto the matching Normal.
    """
    u = scipy.stats.t.cdf(z.detach().cpu().numpy(),
                          df=df.detach().cpu().numpy(),
                          loc=loc.detach().cpu().numpy(),
                          scale=scale.detach().cpu().numpy())
    return torch.distributions.Normal(loc=loc, scale=scale).icdf(
        torch.tensor(u, dtype=torch.float).to(device))


def _plot_class_latent_magnitude(zs_mean, ys):
    """Plot the mean |posterior mean| of each latent dimension, per class.

    Assumes single-valued integer class labels 0..C-1 (``ys`` of shape
    (n, 1)) -- TODO confirm. The hard-coded names appear to be the
    Fashion-MNIST classes, and only the first 10 classes are plotted.
    """
    class_names = [
        'T-shirt', 'Trouser', 'Pullover', 'Dress', 'Coat',
        'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'
    ]
    classes = ys.unique()
    zs_mean_mag_avg = torch.zeros(len(classes), zs_mean.size(-1))
    for c in classes:
        idx = (ys == c).view(-1)
        # Rows are indexed by the label value itself, hence the 0..C-1 assumption.
        zs_mean_mag_avg[int(c)] = zs_mean[idx].abs().mean(0)
    plot_latent_magnitude(zs_mean_mag_avg[range(len(class_names)), :],
                          labels=class_names, path=runPath + '/plot_sparsity')


def test(beta, alpha, agg):
    """Evaluate the model on the test set and record metrics in ``agg``.

    Sums the objective's components (negative loss, reconstruction, KL,
    regulariser) over all test batches. Their values divided by the global
    ``N`` are appended to ``agg['test_loss']`` (negated), ``agg['test_recon']``,
    ``agg['test_kl']`` and ``agg['test_reg']``.

    The posterior means/stddevs and labels of every test datapoint are collected
    for ``model.posterior_plot``. Depending on the ``cmds`` flags, they also
    feed the disentanglement (with and without the OT map), sparsity and IWAE
    marginal-likelihood evaluations.

    Args:
        beta: passed to ``objective`` as ``beta``.
        alpha: passed to ``objective`` as ``alpha``.
        agg: mapping of metric name -> list, extended in place.
    """
    model.eval()
    n_data = len(test_loader.dataset)
    b_negloss, b_recon, b_kl, b_reg, b_mlike = 0., 0., 0., 0., 0.
    zs_mean = torch.zeros(n_data, D, device=device)
    zs_std = torch.zeros(n_data, D, device=device)
    zs2_mean = torch.zeros(n_data, D, device=device)  # means after the OT map
    L = test_loader.dataset[0][1].view(-1).size(-1)  # flattened label length
    ys = torch.zeros(n_data, L, device=device)

    start = 0
    # Evaluation needs no gradients. Unless the caller already disabled them,
    # the in-place writes into zs_mean/zs_std/zs2_mean would otherwise keep
    # every batch's autograd graph alive.
    with torch.no_grad():
        for data, labels in test_loader:
            data = data.to(device)
            # Use the real batch size: the last batch may hold fewer than B
            # items, which broke both the B*i offsets and view(B, D).
            end = start + data.size(0)
            qz_x, _, _ = model(data, 1)
            # NOTE(review): regs is gated on args.alpha rather than the alpha
            # argument -- confirm this is intended.
            negloss, recon, kl, reg = objective(
                model, data, K=args.K, beta=beta, alpha=alpha,
                regs=(regs if args.alpha > 0 else None), components=True)
            b_negloss += negloss.item()
            b_recon += recon.item()
            b_kl += kl.item()
            b_reg += reg.item()
            zs_mean[start:end, :] = qz_x.mean
            zs_std[start:end, :] = qz_x.stddev
            ys[start:end, :] = labels.view(-1, L)
            if cmds.disentanglement:
                # change measure if prior is not normal (along optimal transport map)
                if model.pz == torch.distributions.studentT.StudentT:
                    df, pz_mean, pz_scale = model.pz_params
                    qz_x_mean = _student_t_to_normal(qz_x.mean, df,
                                                     pz_mean, pz_scale)
                else:
                    qz_x_mean = qz_x.mean
                zs2_mean[start:end, :] = qz_x_mean.view(-1, D)
            if cmds.logp:
                b_mlike += objectives.iwae_objective(
                    model, data, cmds.iwae_samples).sum().item()
            start = end

    # NOTE(review): these averages divide by the global N, whereas test_mlik
    # below divides by the test-set size -- confirm N is the test-set size.
    agg['test_loss'].append(-b_negloss / N)
    agg['test_recon'].append(b_recon / N)
    agg['test_kl'].append(b_kl / N)
    agg['test_reg'].append(b_reg / N)
    print('Loss: {:.1f} Recon: {:.1f} KL: {:.1f} Reg: {:.3f}'.format(
        agg['test_loss'][-1], agg['test_recon'][-1], agg['test_kl'][-1],
        agg['test_reg'][-1]))
    model.posterior_plot(zs_mean, zs_std, runPath, args.epochs)

    if cmds.disentanglement:
        dis = compute_disentanglement(zs_mean, ys).item()
        agg['test_disentanglement'].append(dis)
        dis2 = compute_disentanglement(zs2_mean, ys).item()
        agg['test_disentanglement2'].append(dis2)
        # The headline number uses the OT-mapped means; "wo OT" is the raw one.
        print('Disentanglement: {:.3f} (wo OT {:.3f})'.format(
            agg['test_disentanglement2'][-1], agg['test_disentanglement'][-1]))

    if cmds.sparsity:
        agg['test_sparsity'].append(compute_sparsity(zs_mean, norm=True))
        print('Sparsity: {:.3f}'.format(agg['test_sparsity'][-1]))
        _plot_class_latent_magnitude(zs_mean, ys)

    if cmds.logp:
        agg['test_mlik'].append(b_mlike / n_data)
        print('Marginal Log Likelihood (IWAE, K = {}): {:.4f}'.format(
            cmds.iwae_samples, agg['test_mlik'][-1]))