def reconstruct(epoch):
    """Save a grid comparing random test images with their flow reconstructions.

    Draws a random batch of up to 16 test images, encodes them with
    ``fgen.encode`` and decodes the latents with ``fgen.decode``, prints the
    max / mean absolute reconstruction error (measured on the preprocessed
    tensors), and writes ``reconstruct{epoch}.png`` into ``result_path``. In
    that image each original is immediately followed by its reconstruction.

    Relies on names defined elsewhere in the script: ``fgen``, ``test_data``,
    ``test_index``, ``device``, ``n_bits``, ``result_path`` and the helpers
    ``get_batch``, ``preprocess``, ``postprocess`` and ``save_image``.

    Side effects: puts ``fgen`` in eval mode and shuffles ``test_index``
    in place.

    Args:
        epoch: only used to name the output image file.
    """
    print('reconstruct')
    fgen.eval()
    n = 16
    np.random.shuffle(test_index)
    img, _ = get_batch(test_data, test_index[:n])

    # Pure evaluation: no gradients are needed, so skip building the graph.
    with torch.no_grad():
        img = preprocess(img.to(device), n_bits)
        z, _ = fgen.encode(img)
        img_recon, _ = fgen.decode(z)
        abs_err = (img_recon - img).abs()
    print('Err: {:.4f}, {:.4f}'.format(abs_err.max().item(), abs_err.mean().item()))

    # Size the interleaving from the batch actually returned; it can be
    # smaller than n when the test split has fewer than n examples.
    batch = img.size(0)

    img = postprocess(img, n_bits)
    img_recon = postprocess(img_recon, n_bits)
    comparison = torch.cat([img, img_recon], dim=0).cpu()
    # Interleave originals and reconstructions: [x0, r0, x1, r1, ...].
    reorder_index = torch.arange(2 * batch).view(2, batch).t().reshape(-1)
    comparison = comparison[reorder_index]

    image_file = 'reconstruct{}.png'.format(epoch)
    save_image(comparison, os.path.join(result_path, image_file), nrow=16)
# Load the model configuration and store a copy alongside the model outputs.
# Context managers make sure both files are closed (and the copy flushed).
with open(args.config, 'r') as config_file:
    params = json.load(config_file)
with open(os.path.join(model_path, 'config.json'), 'w') as config_copy:
    json.dump(params, config_copy, indent=2)

# Build the flow model for the selected dequantization method.
if dequant == 'uniform':
    fgen = FlowGenModel.from_params(params).to(device)
elif dequant == 'variational':
    fgen = VDeQuantFlowGenModel.from_params(params).to(device)
else:
    raise ValueError('unknown dequantization method: %s' % dequant)

# initialize: run fgen.init on randomly sampled training batches.
fgen.eval()
# Sampling without replacement raises ValueError if the sample is larger than
# the training set, so cap the batch size at the number of training examples.
init_batch_size = min(512, len(train_index))
init_iter = 1
print('init: {} instances with {} iterations'.format(init_batch_size, init_iter))
for _ in range(init_iter):
    init_index = np.random.choice(train_index, init_batch_size, replace=False)
    init_data, _ = get_batch(train_data, init_index)
    init_data = preprocess(init_data.to(device), n_bits)
    fgen.init(init_data, init_scale=1.0)

# NOTE: an EMA shadow copy of the model (via exponentialMovingAverage with
# polyak_decay) was sketched here previously but is currently disabled.

fgen.to_device(device)

optimizer = get_optimizer(lr, fgen.parameters())


def lmbda(step):
    """Return the LR multiplier for ``step``.

    Linear warmup over ``warmups`` steps, then exponential decay by
    ``step_decay`` per step.
    """
    if step < warmups:
        return step / float(warmups)
    return step_decay ** (step - warmups)


scheduler = optim.lr_scheduler.LambdaLR(optimizer, lmbda)
# LambdaLR starts at step 0. When warmups > 0, the multiplier there is 0, so
# advance once: the first optimizer update then uses a non-zero learning rate.
# (Recent PyTorch versions warn about stepping the scheduler before the
# optimizer; here that order is intentional.)
scheduler.step()

start_epoch = 1
patient = 0  # presumably an early-stopping patience counter — TODO confirm