def eval(data_loader, k):
    fgen.eval()
    nent = 0
    nll_mc = 0
    nll_iw = 0
    num_insts = 0
    for i, (data, _) in enumerate(data_loader):
        data = data.to(device, non_blocking=True)
        # [batch, channels, H, W]
        batch, c, h, w = data.size()
        x = preprocess(data, n_bits)
        # [batch, k]
        noise, log_probs_posterior = fgen.dequantize(x, nsamples=k)
        # [batch, k, channels, H, W]
        data = preprocess(data, n_bits, noise)
        # [batch * k, channels, H, W] -> [batch * k] -> [batch, k]
        log_probs = fgen.log_probability(data.view(batch * k, c, h, w)).view(batch, k)
        # [batch, k]
        log_iw = log_probs - log_probs_posterior

        num_insts += batch
        nent += log_probs_posterior.mean(dim=1).sum().item()
        nll_mc -= log_iw.mean(dim=1).sum().item()
        nll_iw += (math.log(k) - torch.logsumexp(log_iw, dim=1)).sum().item()

    nent = nent / num_insts
    nepd = nent / (nx * np.log(2.0))
    nll_mc = nll_mc / num_insts + np.log(n_bins / 2.) * nx
    bpd_mc = nll_mc / (nx * np.log(2.0))
    nll_iw = nll_iw / num_insts + np.log(n_bins / 2.) * nx
    bpd_iw = nll_iw / (nx * np.log(2.0))

    print('Avg NLL: {:.2f}, NENT: {:.2f}, IW: {:.2f}, BPD: {:.4f}, NEPD: {:.4f}, BPD_IW: {:.4f}'.format(
        nll_mc, nent, nll_iw, bpd_mc, nepd, bpd_iw))
    return nll_mc, nent, nll_iw, bpd_mc, nepd, bpd_iw
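
# A minimal, self-contained sketch (not part of the training script) of the two
# NLL estimates computed in eval() above: the plain Monte Carlo estimate
# -E_q[log p - log q] and the tighter IWAE-style importance-weighted estimate
# log k - logsumexp(log p - log q). The tensors `log_p` and `log_q` are
# illustrative stand-ins for fgen.log_probability and the dequantizer's
# posterior log-probability.
def _sketch_iw_nll(batch=4, k=8):
    import math
    import torch
    log_p = torch.randn(batch, k)   # log p(x + u) under the flow
    log_q = torch.randn(batch, k)   # log q(u | x) under the dequantizer
    log_iw = log_p - log_q          # per-sample log importance weights
    nll_mc = -log_iw.mean(dim=1)                           # Monte Carlo estimate
    nll_iw = math.log(k) - torch.logsumexp(log_iw, dim=1)  # importance-weighted estimate
    # By Jensen's inequality the IW estimate is never looser than MC:
    assert torch.all(nll_iw <= nll_mc + 1e-5)
    return nll_mc, nll_iw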
def reconstruct(epoch):
    print('reconstruct')
    fgen.eval()
    n = 16
    np.random.shuffle(test_index)
    img, _ = get_batch(test_data, test_index[:n])
    img = preprocess(img.to(device), n_bits)

    z, _ = fgen.encode(img)
    img_recon, _ = fgen.decode(z)

    abs_err = img_recon.add(img * -1).abs()
    print('Err: {:.4f}, {:.4f}'.format(abs_err.max().item(), abs_err.mean().item()))

    img = postprocess(img, n_bits)
    img_recon = postprocess(img_recon, n_bits)
    comparison = torch.cat([img, img_recon], dim=0).cpu()
    reorder_index = torch.from_numpy(np.array([[i + j * n for j in range(2)] for i in range(n)])).view(-1)
    comparison = comparison[reorder_index]
    image_file = 'reconstruct{}.png'.format(epoch)
    save_image(comparison, os.path.join(result_path, image_file), nrow=16)
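
# A minimal sketch (illustrative, not called by reconstruct() above) of the
# reorder_index trick: with n originals followed by n reconstructions stacked
# in one tensor, the index [[i, i + n] for i in range(n)], flattened,
# interleaves the rows so that save_image places each original directly next
# to its reconstruction in the output grid.
def _sketch_reorder(n=4):
    import numpy as np
    reorder_index = np.array([[i + j * n for j in range(2)] for i in range(n)]).reshape(-1)
    print(reorder_index)  # with n=4: [0 4 1 5 2 6 3 7]
    return reorder_index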
def train(epoch, k):
    print('Epoch: %d (lr=%.6f (%s), patient=%d)' % (epoch, lr, opt, patient))
    fgen.train()
    nll = 0
    nent = 0
    num_insts = 0
    num_back = 0
    num_nans = 0
    start_time = time.time()
    for batch_idx, (data, _) in enumerate(train_loader):
        batch_size = len(data)
        optimizer.zero_grad()
        data = data.to(device, non_blocking=True)

        nll_batch = 0
        nent_batch = 0
        data_list = [data, ] if batch_steps == 1 else data.chunk(batch_steps, dim=0)
        for data in data_list:
            x = preprocess(data, n_bits)
            # [batch, k]
            noise, log_probs_posterior = fgen.dequantize(x, nsamples=k)
            # [batch, k] -> [1]
            log_probs_posterior = log_probs_posterior.mean(dim=1).sum()
            # [batch, k, channels, H, W] -> [batch, channels, H, W]
            data = preprocess(data, n_bits, noise[:, 0:1]).squeeze(1)
            log_probs = fgen.log_probability(data).sum()
            loss = (log_probs_posterior - log_probs) / batch_size
            loss.backward()
            with torch.no_grad():
                nll_batch -= log_probs.item()
                nent_batch += log_probs_posterior.item()

        if grad_clip > 0:
            grad_norm = clip_grad_norm_(fgen.parameters(), grad_clip)
        else:
            grad_norm = total_grad_norm(fgen.parameters())
        if math.isnan(grad_norm):
            num_nans += 1
        else:
            optimizer.step()
            scheduler.step()
            # exponentialMovingAverage(fgen, fgen_shadow, polyak_decay)

        num_insts += batch_size
        nll += nll_batch
        nent += nent_batch
        if batch_idx % args.log_interval == 0:
            sys.stdout.write("\b" * num_back)
            sys.stdout.write(" " * num_back)
            sys.stdout.write("\b" * num_back)
            train_nent = nent / num_insts
            train_nll = nll / num_insts + train_nent + np.log(n_bins / 2.) * nx
            bits_per_pixel = train_nll / (nx * np.log(2.0))
            nent_per_pixel = train_nent / (nx * np.log(2.0))
            curr_lr = scheduler.get_lr()[0]
            log_info = '[{}/{} ({:.0f}%) lr={:.6f}, {}] NLL: {:.2f}, BPD: {:.4f}, NENT: {:.2f}, NEPD: {:.4f}'.format(
                batch_idx * batch_size, len(train_index), 100. * batch_idx * batch_size / len(train_index),
                curr_lr, num_nans, train_nll, bits_per_pixel, train_nent, nent_per_pixel)
            sys.stdout.write(log_info)
            sys.stdout.flush()
            num_back = len(log_info)

    sys.stdout.write("\b" * num_back)
    sys.stdout.write(" " * num_back)
    sys.stdout.write("\b" * num_back)
    train_nent = nent / num_insts
    train_nll = nll / num_insts + train_nent + np.log(n_bins / 2.) * nx
    bits_per_pixel = train_nll / (nx * np.log(2.0))
    nent_per_pixel = train_nent / (nx * np.log(2.0))
    print('Average NLL: {:.2f}, BPD: {:.4f}, NENT: {:.2f}, NEPD: {:.4f}, time: {:.1f}s'.format(
        train_nll, bits_per_pixel, train_nent, nent_per_pixel, time.time() - start_time))
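
# A minimal sketch (with a toy linear model, not the script's fgen) of the
# chunked gradient accumulation used in train() above: splitting a batch into
# batch_steps chunks and dividing each chunk's loss by the *full* batch size
# makes the accumulated gradient equal the full-batch gradient, at lower peak
# memory.
def _sketch_grad_accumulation(batch_steps=4):
    import torch
    model = torch.nn.Linear(8, 1)
    data = torch.randn(16, 8)

    model.zero_grad()
    for chunk in data.chunk(batch_steps, dim=0):
        loss = model(chunk).pow(2).sum() / data.size(0)
        loss.backward()  # gradients accumulate across chunks
    accumulated = [p.grad.clone() for p in model.parameters()]

    # Reference: one backward pass over the full batch.
    model.zero_grad()
    model(data).pow(2).sum().div(data.size(0)).backward()
    for g, p in zip(accumulated, model.parameters()):
        assert torch.allclose(g, p.grad, atol=1e-5)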
json.dump(params, open(os.path.join(model_path, 'config.json'), 'w'), indent=2)
if dequant == 'uniform':
    fgen = FlowGenModel.from_params(params).to(device)
elif dequant == 'variational':
    fgen = VDeQuantFlowGenModel.from_params(params).to(device)
else:
    raise ValueError('unknown dequantization method: %s' % dequant)

# initialize
fgen.eval()
init_batch_size = 512
init_iter = 1
print('init: {} instances with {} iterations'.format(init_batch_size, init_iter))
for _ in range(init_iter):
    init_index = np.random.choice(train_index, init_batch_size, replace=False)
    init_data, _ = get_batch(train_data, init_index)
    init_data = preprocess(init_data.to(device), n_bits)
    fgen.init(init_data, init_scale=1.0)

# create shadow model for ema
# params = json.load(open(args.config, 'r'))
# fgen_shadow = FlowGenModel.from_params(params).to(device)
# exponentialMovingAverage(fgen, fgen_shadow, polyak_decay, init=True)

fgen.to_device(device)
optimizer = get_optimizer(lr, fgen.parameters())
lmbda = lambda step: step / float(warmups) if step < warmups else step_decay ** (step - warmups)
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lmbda)
scheduler.step()

start_epoch = 1
patient = 0
best_epoch = 0
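
# A minimal, standalone sketch (illustrative hyperparameter values, not the
# script's warmups/step_decay) of the warmup-then-decay schedule built above:
# LambdaLR multiplies the base lr by step / warmups while step < warmups, then
# by step_decay ** (step - warmups) afterwards.
def _sketch_lr_schedule(base_lr=1e-3, warmups=5, step_decay=0.5, steps=8):
    import torch
    param = torch.nn.Parameter(torch.zeros(1))
    sgd = torch.optim.SGD([param], lr=base_lr)
    lmbda = lambda step: step / float(warmups) if step < warmups else step_decay ** (step - warmups)
    sched = torch.optim.lr_scheduler.LambdaLR(sgd, lmbda)
    for step in range(steps):
        print(step, sched.get_last_lr()[0])  # lr ramps up linearly, then decays
        sgd.step()
        sched.step()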