def create_approximate_posterior():
    """Build the approximate posterior q(z|x) selected by ``args``.

    Returns either a conditional diagonal Gaussian (conv-encoded) or a
    flow-based posterior whose base distribution is a conditional diagonal
    Gaussian parameterized by a linear context encoder.
    """
    if args.approximate_posterior_type == 'diagonal-normal':
        # Simple choice: amortized diagonal Gaussian with a conv encoder.
        encoder = nn_.ConvEncoder(
            context_features=args.context_features,
            channels_multiplier=16,
            dropout_probability=args.dropout_probability_encoder_decoder)
        return distributions_.ConditionalDiagonalNormal(
            shape=[args.latent_features], context_encoder=encoder)

    # Flow-based posterior: a linear layer produces the (mean, log-std)
    # context for the base Gaussian, and a stack of flow steps refines it.
    encoder = nn.Linear(args.context_features, 2 * args.latent_features)
    base_distribution = distributions_.ConditionalDiagonalNormal(
        shape=[args.latent_features], context_encoder=encoder)

    # Each step is a linear transform followed by a context-conditioned
    # base transform; build them in order so parameter init order matches.
    steps = []
    for step_index in range(args.num_flow_steps):
        steps.append(transforms.CompositeTransform([
            create_linear_transform(),
            create_base_transform(
                step_index, context_features=args.context_features),
        ]))

    # Cap the stack with one final linear transform.
    full_transform = transforms.CompositeTransform([
        transforms.CompositeTransform(steps),
        create_linear_transform(),
    ])
    return flows.Flow(
        transforms.InverseTransform(full_transform), base_distribution)
def eval_reconstruct(num_bits, batch_size, seed, num_reconstruct_batches,
                     _log, output_path=''):
    """Check numerical invertibility of the trained flow.

    Pushes training batches through forward-then-inverse of the flow's
    transform (which should be the identity), saves a side-by-side image
    grid for the first batch, and prints the max absolute reconstruction
    error over all evaluated batches.
    """
    torch.set_grad_enabled(False)
    device = set_device()
    torch.manual_seed(seed)
    np.random.seed(seed)

    train_dataset, _, (c, h, w) = get_train_valid_data()
    flow = create_flow(c, h, w).to(device)
    flow.eval()

    loader = DataLoader(dataset=train_dataset,
                        batch_size=batch_size,
                        shuffle=True)

    # Forward transform composed with its own inverse: ideally the identity.
    round_trip = transforms.CompositeTransform(
        [flow._transform, transforms.InverseTransform(flow._transform)])

    per_batch_errors = []
    for batch_index, (images, _) in enumerate(
            tqdm(load_num_batches(loader, num_reconstruct_batches),
                 total=num_reconstruct_batches)):
        images = images.to(device)
        reconstructions, _ = round_trip(images)
        per_batch_errors.append(torch.abs(reconstructions - images))

        if batch_index == 0:
            # Save a 6x6 grid of originals and reconstructions for visual
            # comparison (undo the bit-depth preprocessing first).
            originals_vis = Preprocess(num_bits).inverse(images[:36, ...])
            reconstructions_vis = Preprocess(num_bits).inverse(
                reconstructions[:36, ...])
            save_image(originals_vis.cpu(),
                       os.path.join(output_path, 'invertibility_orig.png'),
                       nrow=6,
                       padding=0)
            save_image(reconstructions_vis.cpu(),
                       os.path.join(output_path, 'invertibility_rec.png'),
                       nrow=6,
                       padding=0)

    all_errors = torch.cat(per_batch_errors)
    print('max abs diff: {:.4f}'.format(torch.max(all_errors).item()))
def train_flow(flow, train_dataset, val_dataset, dataset_dims, device,
               batch_size, num_steps, learning_rate, cosine_annealing,
               warmup_fraction, temperatures, num_bits, num_workers,
               intervals, multi_gpu, actnorm, optimizer_checkpoint,
               start_step, eta_min, _log):
    """Train a normalizing flow with periodic sampling/eval/checkpointing.

    Args:
        flow: the flow model to train (moved to ``device`` here).
        train_dataset, val_dataset: training data and optional validation
            data (``val_dataset`` may be falsy to skip validation).
        dataset_dims: (channels, height, width) of one data item.
        device: torch device to train on.
        batch_size, num_steps, learning_rate: usual optimization knobs.
        cosine_annealing: if True, use cosine LR schedule (with optional
            linear warmup when ``warmup_fraction > 0``).
        warmup_fraction: fraction of ``num_steps`` used for LR warmup.
        temperatures: sampling temperatures for the periodic sample grid.
        num_bits: data bit depth, used to invert preprocessing for images.
        num_workers: DataLoader worker count.
        intervals: dict with 'log', 'sample', 'eval', 'save', 'reconstruct'
            step intervals.
        multi_gpu: if True, data-parallel log-prob across all visible GPUs.
        actnorm: whether the flow uses actnorm (needs special init batch).
        optimizer_checkpoint: optional path to a saved optimizer state.
        start_step: step to resume from (0 for a fresh run).
        eta_min: minimum LR for cosine annealing.
        _log: sacred-style logger.

    Side effects: writes TensorBoard summaries and saves
    ``flow_best.pt`` / ``flow_last.pt`` / ``optimizer_last.pt`` under the
    run directory.
    """
    run_dir = fso.dir

    flow = flow.to(device)

    summary_writer = SummaryWriter(run_dir, max_queue=100)

    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=batch_size,
                              num_workers=num_workers)

    if val_dataset:
        val_loader = DataLoader(dataset=val_dataset,
                                batch_size=batch_size,
                                num_workers=num_workers)
    else:
        val_loader = None

    # Random batch and identity transform for reconstruction evaluation.
    random_batch, _ = next(
        iter(
            DataLoader(
                dataset=train_dataset,
                batch_size=batch_size,
                num_workers=0  # Faster than starting all workers just to get a single batch.
            )))
    identity_transform = transforms.CompositeTransform(
        [flow._transform, transforms.InverseTransform(flow._transform)])

    optimizer = torch.optim.Adam(flow.parameters(), lr=learning_rate)

    if optimizer_checkpoint is not None:
        optimizer.load_state_dict(torch.load(optimizer_checkpoint))
        _log.info(
            'Optimizer state loaded from {}'.format(optimizer_checkpoint))

    if cosine_annealing:
        if warmup_fraction == 0.:
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer=optimizer,
                T_max=num_steps,
                last_epoch=-1 if start_step == 0 else start_step,
                eta_min=eta_min)
        else:
            scheduler = optim.CosineAnnealingWarmUpLR(
                optimizer=optimizer,
                warm_up_epochs=int(warmup_fraction * num_steps),
                total_epochs=num_steps,
                last_epoch=-1 if start_step == 0 else start_step,
                eta_min=eta_min)
    else:
        scheduler = None

    def nats_to_bits_per_dim(x):
        # Normalize a nats value by the data dimensionality.
        c, h, w = dataset_dims
        return autils.nats_to_bits_per_dim(x, c, h, w)

    _log.info('Starting training...')

    best_val_log_prob = None
    start_time = None
    num_batches = num_steps - start_step

    for step, (batch, _) in enumerate(load_num_batches(
            loader=train_loader, num_batches=num_batches),
                                      start=start_step):
        # BUGFIX: use start_step, not 0 — on resume (start_step > 0) the
        # original never set start_time and crashed computing elapsed_time.
        if step == start_step:
            start_time = time.time()  # Runtime estimate is more accurate if set here.

        flow.train()
        optimizer.zero_grad()

        batch = batch.to(device)

        if multi_gpu:
            if actnorm and step == 0:
                # If using actnorm, data-dependent initialization doesn't work with
                # data_parallel, so pass a single batch on a single GPU before the
                # first step.
                flow.log_prob(
                    batch[:batch.shape[0] // torch.cuda.device_count(), ...])

            # Split along the batch dimension and put each split on a separate
            # GPU. All available GPUs are used.
            log_density = nn.parallel.data_parallel(LogProbWrapper(flow),
                                                    batch)
        else:
            log_density = flow.log_prob(batch)

        loss = -nats_to_bits_per_dim(torch.mean(log_density))
        loss.backward()
        optimizer.step()

        if scheduler is not None:
            scheduler.step()
            # get_last_lr() is the supported accessor; get_lr() is deprecated
            # and returns a mid-step value on recent PyTorch versions.
            summary_writer.add_scalar('learning_rate',
                                      scheduler.get_last_lr()[0], step)

        summary_writer.add_scalar('loss', loss.item(), step)

        # BUGFIX: explicit None check — a best value of exactly 0.0 is valid
        # and should still be logged.
        if best_val_log_prob is not None:
            summary_writer.add_scalar('best_val_log_prob', best_val_log_prob,
                                      step)

        flow.eval()  # Everything beyond this point is evaluation.

        if step % intervals['log'] == 0:
            elapsed_time = time.time() - start_time
            progress = autils.progress_string(elapsed_time, step, num_steps)
            _log.info("It: {}/{} loss: {:.3f} [{}]".format(
                step, num_steps, loss, progress))

        if step % intervals['sample'] == 0:
            # squeeze=False keeps axs 2-D so axs.flat works even when only a
            # single temperature is given (the original crashed in that case).
            fig, axs = plt.subplots(1,
                                    len(temperatures),
                                    figsize=(4 * len(temperatures), 4),
                                    squeeze=False)
            for temperature, ax in zip(temperatures, axs.flat):
                with torch.no_grad():
                    noise = flow._distribution.sample(64) * temperature
                    samples, _ = flow._transform.inverse(noise)
                    samples = Preprocess(num_bits).inverse(samples)
                autils.imshow(make_grid(samples, nrow=8), ax)
                ax.set_title('T={:.2f}'.format(temperature))
            summary_writer.add_figure(tag='samples',
                                      figure=fig,
                                      global_step=step)
            plt.close(fig)

        if step > 0 and step % intervals['eval'] == 0 and (val_loader
                                                           is not None):
            if multi_gpu:

                def log_prob_fn(batch):
                    return nn.parallel.data_parallel(LogProbWrapper(flow),
                                                     batch.to(device))
            else:

                def log_prob_fn(batch):
                    return flow.log_prob(batch.to(device))

            val_log_prob = autils.eval_log_density(log_prob_fn=log_prob_fn,
                                                   data_loader=val_loader)
            val_log_prob = nats_to_bits_per_dim(val_log_prob).item()

            _log.info("It: {}/{} val_log_prob: {:.3f}".format(
                step, num_steps, val_log_prob))
            summary_writer.add_scalar('val_log_prob', val_log_prob, step)

            if best_val_log_prob is None or val_log_prob > best_val_log_prob:
                best_val_log_prob = val_log_prob
                torch.save(flow.state_dict(),
                           os.path.join(run_dir, 'flow_best.pt'))
                _log.info(
                    'It: {}/{} best val_log_prob improved, saved flow_best.pt'.
                    format(step, num_steps))

        if step > 0 and (step % intervals['save'] == 0
                         or step == (num_steps - 1)):
            torch.save(optimizer.state_dict(),
                       os.path.join(run_dir, 'optimizer_last.pt'))
            torch.save(flow.state_dict(),
                       os.path.join(run_dir, 'flow_last.pt'))
            _log.info(
                'It: {}/{} saved optimizer_last.pt and flow_last.pt'.format(
                    step, num_steps))

        if step > 0 and step % intervals['reconstruct'] == 0:
            with torch.no_grad():
                random_batch_ = random_batch.to(device)
                random_batch_rec, logabsdet = identity_transform(
                    random_batch_)

                max_abs_diff = torch.max(
                    torch.abs(random_batch_rec - random_batch_))
                max_logabsdet = torch.max(logabsdet)

                summary_writer.add_scalar(tag='max_reconstr_abs_diff',
                                          scalar_value=max_abs_diff.item(),
                                          global_step=step)
                summary_writer.add_scalar(tag='max_reconstr_logabsdet',
                                          scalar_value=max_logabsdet.item(),
                                          global_step=step)
def __init__(self, squashing_transform, cdf_transform):
    """Conjugate the CDF transform by the squashing transform.

    Builds the composite: squash, apply the CDF transform, then undo the
    squashing — so the overall transform acts in squashed space.
    """
    chain = [
        squashing_transform,
        cdf_transform,
        transforms.InverseTransform(squashing_transform),
    ]
    super().__init__(chain)