def analyse(data_loader, plot=False):
    """Return the mean bits-per-dimension of ``model`` over ``data_loader``.

    This function reads names it does not define: ``args``, ``model``,
    ``testing``, ``model_sample``, ``plot_reconstructions``, ``loss_type``
    and ``epoch``. NOTE(review): presumably closure variables of an
    enclosing evaluation function -- confirm against the caller.

    Args:
        data_loader: iterable of ``(data, target)`` pairs; targets are ignored.
            Each ``data`` batch is reshaped to ``(-1, *args.input_size)``.
        plot: when true and ``testing`` is false, draw 100 samples from
            ``model_sample`` and try to plot them.

    Returns:
        The mean over batches of ``torch.mean(batch_bpd)``, or NaN when the
        loader yields no batches.
    """
    bpds = []
    with torch.no_grad():
        for data, _ in data_loader:
            if args.cuda:
                data = data.cuda()
            data = data.view(-1, *args.input_size)

            # The model returns an 8-tuple; only the per-example bpd
            # (second element) is needed here.
            _loss, batch_bpd, _bpd_per_prior, _pz, _z, _pys, _ys, _ldj = \
                model(data)
            bpds.append(torch.mean(batch_bpd).item())

    # Avoid numpy's "mean of empty slice" warning on an empty loader.
    bpd = np.mean(bpds) if bpds else float('nan')

    with torch.no_grad():
        if not testing and plot:
            x_sample = model_sample.sample(n=100)
            try:
                plot_reconstructions(x_sample, bpd, loss_type, epoch, args)
            except Exception:
                # Plotting is best-effort and must not abort evaluation, but
                # Ctrl-C / SystemExit should still propagate.
                print('Not plotting')
    return bpd
def analyse(data_loader, plot=False):
    """Evaluate ``model`` over ``data_loader`` and return ``(bpd, loss)``.

    This function reads names it does not define: ``args``, ``model``,
    ``epoch``, ``testing``, ``model_sample``, ``plot_reconstructions`` and
    ``loss_type``. NOTE(review): presumably closure variables of an
    enclosing evaluation function -- confirm against the caller.

    For every batch it prints the batch loss and bpd. Batches whose bpd is
    NaN are printed but excluded from the returned averages.

    Args:
        data_loader: iterable of 1-tuples ``(data,)``. Each batch is reshaped
            to ``(-1, *args.input_size)``, cast to float64 and moved to
            ``args.device``.
        plot: when true and ``testing`` is false, draw 100 samples from
            ``model_sample`` and try to plot them.

    Returns:
        ``(bpd, loss)``: means over the non-NaN batches of the per-dimension
        score and of ``-mean(result['total_logd'])``. Both are NaN when no
        batch contributed.
    """
    # Data is reshaped to args.input_size, so this is the number of
    # dimensions per example (12288 for a 3x64x64 input).
    n_dims = int(np.prod(args.input_size))
    log_template = 'Validation Epoch: {:3d} \tLoss: {:11.6f}\tbpd: {:8.6f}'

    bpds = []
    losses = []
    with torch.no_grad():
        for (data,) in data_loader:
            data = data.view(-1, *args.input_size)
            # .to() performs both the dtype cast and the device move.
            data = data.to(dtype=torch.float64, device=args.device)
            result = model(data)

            # NOTE(review): this is "bits" per dim only if total_logd is a
            # base-2 log density; for natural logs a further division by
            # log(2) would be needed -- confirm against the model.
            neg_logd = -torch.mean(result['total_logd'])
            batch_loss = neg_logd.item()
            batch_bpd = (neg_logd / n_dims).item()

            if not np.isnan(batch_bpd):
                bpds.append(batch_bpd)
                losses.append(batch_loss)
            print(log_template.format(epoch, batch_loss, batch_bpd))

    # Every batch may have been NaN (or the loader empty); avoid numpy's
    # "mean of empty slice" warning in that case.
    bpd = np.mean(bpds) if bpds else float('nan')
    loss = np.mean(losses) if losses else float('nan')

    with torch.no_grad():
        if not testing and plot:
            x_sample = model_sample.sample(n=100)
            try:
                plot_reconstructions(x_sample, bpd, loss_type, epoch, args)
            except Exception:
                # Plotting is best-effort and must not abort evaluation, but
                # Ctrl-C / SystemExit should still propagate.
                print('Not plotting')
    return bpd, loss
def evaluate(data_loader, model, args, testing=False, file=None, epoch=0):
    """Evaluate ``model`` on ``data_loader`` and report the results.

    The per-batch loss and bpd returned by ``calculate_loss`` are summed and
    divided by the number of batches. On a non-test run, the reconstructions
    of the first batch are plotted. On a test run, the log-likelihood of the
    whole dataset is additionally computed with ``calculate_likelihood``.

    The report is printed to stdout, or appended to ``file`` when given.

    Args:
        data_loader: loader yielding ``(data, target)`` batches; targets are
            ignored. With ``testing=True`` its ``dataset`` must expose the
            full data tensor as ``data_tensor`` or ``tensors[0]``.
        model: module returning ``(x_mean, z_mu, z_var, ldj, z0, zk)``.
        args: experiment namespace; reads ``input_type``, ``cuda``,
            ``input_size`` and ``dataset``.
        testing: run as a test evaluation (log-likelihood, no plotting).
        file: optional path of a log file to append the report to.
        epoch: epoch index forwarded to ``plot_reconstructions``.

    Returns:
        ``(loss, bpd)`` when ``testing`` is false, otherwise
        ``(log_likelihood, nll_bpd)``.

    Raises:
        ZeroDivisionError: if ``data_loader`` has no batches.
    """
    model.eval()
    loss_type = 'elbo' if args.input_type == 'binary' else 'bpd'
    # Divisor turning a per-example loss (in nats) into bits per dimension.
    bpd_denominator = np.prod(args.input_size) * np.log(2.)

    loss = 0.
    bpd = 0.
    with torch.no_grad():
        for batch_idx, (data, _) in enumerate(data_loader, start=1):
            if args.cuda:
                data = data.cuda()
            data = data.view(-1, *args.input_size)

            x_mean, z_mu, z_var, ldj, z0, zk = model(data)
            batch_loss, _rec, _kl, batch_bpd = calculate_loss(
                x_mean, data, z_mu, z_var, z0, zk, ldj, args)

            bpd += batch_bpd
            loss += batch_loss.item()

            # Plot reconstructions of the first batch only.
            if batch_idx == 1 and not testing:
                plot_reconstructions(data, x_mean, batch_loss, loss_type,
                                     epoch, args)

    loss /= len(data_loader)
    bpd /= len(data_loader)

    # Compute log-likelihood
    if testing:
        dataset = data_loader.dataset
        # Pre-0.4 TensorDataset exposed ``data_tensor``; since PyTorch 0.4
        # the stored tensors live in ``dataset.tensors``.
        test_data = getattr(dataset, 'data_tensor', None)
        if test_data is None:
            test_data = dataset.tensors[0]
        if args.cuda:
            test_data = test_data.cuda()

        print('Computing log-likelihood on test set')
        # Kept from the original in case a callee switched the mode.
        model.eval()
        # Caltech uses S=2000; every other dataset uses S=5000.
        n_samples = 2000 if args.dataset == 'caltech' else 5000
        # ``Variable(..., volatile=True)`` no longer disables autograd in
        # PyTorch >= 0.4; ``torch.no_grad()`` is the supported replacement.
        with torch.no_grad():
            log_likelihood, nll_bpd = calculate_likelihood(
                test_data, model, args, S=n_samples, MB=1000)
    else:
        log_likelihood = None
        nll_bpd = None

    if args.input_type in ['multinomial']:
        bpd = loss / bpd_denominator

    # NOTE(review): the stdout and file reports differ. The file report
    # prints a validation bpd for every non-binary input type (computed from
    # ``loss``), while stdout does so only for multinomial input. The test bpd
    # label also differs ("bpd (elbo)" vs "bpd"). Kept as-is because log
    # formats may be parsed downstream.
    if file is None:
        if testing:
            print('====> Test set loss: {:.4f}'.format(loss))
            print(
                '====> Test set log-likelihood: {:.4f}'.format(log_likelihood))
            if args.input_type != 'binary':
                print('====> Test set bpd (elbo): {:.4f}'.format(bpd))
                print('====> Test set bpd (log-likelihood): {:.4f}'.format(
                    log_likelihood / bpd_denominator))
        else:
            print('====> Validation set loss: {:.4f}'.format(loss))
            if args.input_type in ['multinomial']:
                print('====> Validation set bpd: {:.4f}'.format(bpd))
    else:
        with open(file, 'a') as ff:
            if testing:
                print('====> Test set loss: {:.4f}'.format(loss), file=ff)
                print('====> Test set log-likelihood: {:.4f}'.format(
                    log_likelihood), file=ff)
                if args.input_type != 'binary':
                    print('====> Test set bpd: {:.4f}'.format(bpd), file=ff)
                    print('====> Test set bpd (log-likelihood): {:.4f}'.format(
                        log_likelihood / bpd_denominator), file=ff)
            else:
                print('====> Validation set loss: {:.4f}'.format(loss),
                      file=ff)
                if args.input_type != 'binary':
                    print('====> Validation set bpd: {:.4f}'.format(
                        loss / bpd_denominator), file=ff)

    if testing:
        return log_likelihood, nll_bpd
    return loss, bpd