def eval_ae_mask_channels(epoch, wt_model, model, sample_loader, args, img_output_dir, model_dir):
    """Evaluate the masked autoencoder on sample batches and save results.

    For each batch: wavelet-transforms the input, zeroes the first patch,
    collates the high-frequency coefficients into channels, reconstructs with
    ``model``, and writes the reconstruction/input images plus a model
    checkpoint for this epoch.
    """
    with torch.no_grad():
        model.eval()
        for batch in sample_loader:
            batch = batch.to(model.device)
            # Wavelet-transform the input to get the coefficient grid Y
            coeffs = wt_model(batch)
            # Zero out the first (low-frequency) patch
            coeffs = zero_mask(coeffs, num_iwt=args.num_wt, cur_iwt=1)
            # Collate high-frequency patches into channel dimension;
            # layout depends on the number of WT levels
            if args.num_wt == 1:
                coeffs = hf_collate_to_channels(coeffs, device=model.device)
            elif args.num_wt == 2:
                coeffs = hf_collate_to_channels_wt2(coeffs, device=model.device)
            recon = model(coeffs.to(model.device))
            # Back to image layout for saving
            recon = hf_collate_to_img(recon)
            coeffs = hf_collate_to_img(coeffs)
            save_image(recon.cpu(), img_output_dir + f'/sample_recon{epoch}.png')
            save_image(coeffs.cpu(), img_output_dir + f'/sample{epoch}.png')
            # Checkpoint the autoencoder weights for this epoch
            torch.save(model.state_dict(), model_dir + f'/aemask_epoch{epoch}.pth')
def train_ae_mask_channels(epoch, wt_model, model, criterion, optimizer, train_loader, train_losses, args, writer):
    """Train the masked autoencoder for one epoch.

    Wavelet-transforms each batch, zeroes the first patch, collates the
    high frequencies to channels, and optimizes ``model`` to reconstruct
    them. Logs per-batch loss and gradient norms to ``writer`` (using the
    module-level ``log_idx`` counter) and appends batch losses to
    ``train_losses``.
    """
    global log_idx
    # toggle model to train mode
    model.train()
    train_loss = 0.0
    for batch_idx, data in enumerate(train_loader):
        data = data.to(model.device)
        optimizer.zero_grad()

        # Get Y (wavelet coefficients of the input)
        Y = wt_model(data)
        # Zeroing out first patch so only high frequencies remain
        Y = zero_mask(Y, num_iwt=args.num_wt, cur_iwt=1)
        if args.num_wt == 1:
            Y = hf_collate_to_channels(Y, device=model.device)
        elif args.num_wt == 2:
            Y = hf_collate_to_channels_wt2(Y, device=model.device)

        x_hat = model(Y)
        loss = model.loss_function(Y, x_hat, criterion)
        loss.backward()

        # Gradient norm before any clipping
        total_norm = calc_grad_norm_2(model)
        writer.add_scalar('Loss', loss, log_idx)
        writer.add_scalar('Gradient_norm/before', total_norm, log_idx)
        log_idx += 1

        # Gradient clipping (re-log the norm after clipping)
        if args.grad_clip > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.grad_clip, norm_type=2)
            total_norm = calc_grad_norm_2(model)
            writer.add_scalar('Gradient_norm/clipped', total_norm, log_idx)

        # BUGFIX: accumulate a Python float, not the loss tensor — summing the
        # tensor keeps every batch's loss (and its device memory) alive for the
        # whole epoch.
        batch_loss = loss.cpu().item()
        train_losses.append(batch_loss)
        train_loss += batch_loss
        optimizer.step()

        if batch_idx % args.log_interval == 0:
            logging.info(
                'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), batch_loss / len(data)))

    logging.info('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, train_loss / len(train_loader.dataset)))
def eval_iwtae_iwtmask128(epoch, wt_model, iwt_model, optimizer, iwt_fn, sample_loader, args, img_output_dir, model_dir, writer, save=True):
    """Evaluate the IWT model on WT coefficients cropped to 128x128.

    For each sample batch: applies the WT model, keeps the top-left 128x128
    coefficient block, builds (a) the IWT of the high frequencies only and
    (b) the IWT of the low-frequency patch only, runs ``iwt_model`` on the
    low-frequency image to predict a high-frequency mask, re-inserts the
    low-frequency patch into the WT of that mask, and inverts to obtain a
    full reconstruction. Saves all intermediate images and, when ``save`` is
    True, a checkpoint of ``iwt_model`` + ``optimizer`` state.

    ``writer`` is accepted for signature consistency but unused here.
    """
    with torch.no_grad():
        iwt_model.eval()
        for data in sample_loader:
            data = data.to(wt_model.device)

            # Applying WT to X to get Y
            Y = wt_model(data)
            # Keep only the top-left 128x128 coefficient block
            # NOTE(review): assumes WT output is at least 128x128 — confirm
            Y = Y[:, :, :128, :128]

            # Zeroing out first patch -> high-frequency coefficients only
            Y_mask = zero_mask(Y, args.num_iwt, 1)
            # IWT all the leftover high frequencies (evaluation target image)
            Y_mask = iwt_fn(Y_mask)

            # Getting IWT of only first patch (low-frequency image, model input)
            Y_low = zero_patches(Y, args.num_iwt)
            Y_low = iwt_fn(Y_low)

            # Run model to get mask; the latent mean/var outputs are unused here
            mask, _, _ = iwt_model(Y_low)

            # Add first patch to WT'ed mask: side length of the innermost
            # (low-frequency) patch after num_iwt decomposition levels
            mask_wt = wt_model(mask)
            inner_dim = Y.shape[2] // np.power(2, args.num_iwt)
            mask_wt[:, :, :inner_dim, :inner_dim] += Y[:, :, :inner_dim, :inner_dim]
            img_recon = iwt_fn(mask_wt)

            # Save images
            save_image(Y_low.cpu(), img_output_dir + '/y{}.png'.format(epoch))
            save_image(mask.cpu(), img_output_dir + '/recon_mask{}.png'.format(epoch))
            save_image(Y_mask.cpu(), img_output_dir + '/mask{}.png'.format(epoch))
            save_image(img_recon.cpu(), img_output_dir + '/recon_img{}.png'.format(epoch))
            save_image(data.cpu(), img_output_dir + '/img{}.png'.format(epoch))

            if save:
                torch.save(
                    {
                        'epoch': epoch,
                        'model_state_dict': iwt_model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict()
                    }, model_dir + '/iwtvae_epoch{}.pth'.format(epoch))
def train_iwtae_iwtmask(epoch, wt_model, iwt_model, optimizer, iwt_fn, train_loader, train_losses, args, writer):
    """Train the IWT model for one epoch to predict high-frequency content.

    For each batch: wavelet-transforms the input, forms the IWT of the
    high frequencies only (target) and the IWT of the low-frequency patch
    only (input), and optimizes ``iwt_model`` to map the latter to the
    former. Logs total/BCE/KLD losses and gradient norms to ``writer``
    (module-level ``log_idx`` counter) and appends per-batch loss triples
    to ``train_losses``.
    """
    global log_idx
    # toggle model to train mode
    iwt_model.train()
    train_loss = 0.0
    for batch_idx, data in enumerate(train_loader):
        data = data.to(wt_model.device)
        optimizer.zero_grad()

        # Get Y (wavelet coefficients of the input)
        Y = wt_model(data)

        # Zeroing out first patch -> high-frequency coefficients only
        Y_mask = zero_mask(Y, args.num_iwt, 1)
        # IWT all the leftover high frequencies (reconstruction target)
        Y_mask = iwt_fn(Y_mask)

        # Getting IWT of only first patch (low-frequency image, model input)
        Y_low = zero_patches(Y, args.num_iwt)
        Y_low = iwt_fn(Y_low)

        # Run model to get the predicted high-frequency mask
        mask, mu, var = iwt_model(Y_low)
        loss, loss_bce, loss_kld = iwt_model.loss_function(Y_mask, mask, mu, var)
        loss.backward()

        # Gradient norm before any clipping
        total_norm = calc_grad_norm_2(iwt_model)
        writer.add_scalar('Loss/total', loss, log_idx)
        writer.add_scalar('Loss/bce', loss_bce, log_idx)
        writer.add_scalar('Loss/kld', loss_kld, log_idx)
        writer.add_scalar('Gradient_norm/before', total_norm, log_idx)
        log_idx += 1

        # Gradient clipping (re-log the norm after clipping)
        if args.grad_clip > 0:
            torch.nn.utils.clip_grad_norm_(iwt_model.parameters(), max_norm=args.grad_clip, norm_type=2)
            total_norm = calc_grad_norm_2(iwt_model)
            writer.add_scalar('Gradient_norm/clipped', total_norm, log_idx)

        # BUGFIX: accumulate a Python float, not the loss tensor — summing the
        # tensor keeps every batch's loss (and its device memory) alive for the
        # whole epoch.
        batch_loss = loss.cpu().item()
        train_losses.append([batch_loss, loss_bce.cpu().item(), loss_kld.cpu().item()])
        train_loss += batch_loss
        optimizer.step()

        # Logging
        if batch_idx % args.log_interval == 0:
            logging.info(
                'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), batch_loss / len(data)))

    logging.info('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, train_loss / len(train_loader.dataset)))
def train_iwtvae(epoch, wt_model, iwt_model, optimizer, iwt_fn, train_loader, train_losses, args, writer):
    """Train the IWT VAE for one epoch.

    For each batch: wavelet-transforms the input (optionally zeroing all but
    the first patch when ``args.zero`` is set), runs ``iwt_model`` on the
    image plus full/zeroed coefficient grids to predict a high-frequency
    mask, zeroes the mask's first patch, and optimizes against both the
    coefficient-space and (after ``args.img_loss_epoch``) image-space
    targets. Logs losses, gradient norms, and the KL weight to ``writer``
    (module-level ``log_idx`` counter).
    """
    global log_idx
    # toggle model to train mode
    iwt_model.train()
    train_loss = 0.0
    for batch_idx, data in enumerate(train_loader):
        # Same batch on both devices: image-space ops on the IWT model's
        # device, coefficient-space ops on the WT model's device
        data0 = data.to(iwt_model.device)
        data1 = data.to(wt_model.device)
        optimizer.zero_grad()

        # Get Y (wavelet coefficients of the input)
        Y = wt_model(data1)

        # Zeroing out all other patches, if given zero arg
        Y_full = Y.clone()
        if args.zero:
            Y = zero_patches(Y, num_wt=args.num_iwt)

        # Run model to get mask (zero out first patch of mask) and x_wt_hat
        mask, mu, var = iwt_model(data0, Y_full.to(iwt_model.device), Y.to(iwt_model.device))
        # NOTE(review): zeroing under no_grad produces a detached mask, so no
        # gradient flows through it into the loss below — confirm intentional.
        with torch.no_grad():
            mask = zero_mask(mask, args.num_iwt, 1)
            # Sanity check; assumes the inner (low-frequency) patch is 128x128
            assert (mask[:, :, :128, :128] == 0).all()

        # Y only has first patch + mask
        x_wt_hat = Y + mask
        x_hat = iwt_fn(x_wt_hat)

        # Get x_wt, assuming deterministic WT model/function, and fill 0's in first patch
        x_wt = wt_model(data0)
        x_wt = zero_mask(x_wt, args.num_iwt, 1)

        # Calculate loss; image-space loss only kicks in after img_loss_epoch
        img_loss = (epoch >= args.img_loss_epoch)
        loss, loss_bce, loss_kld = iwt_model.loss_function(
            data0, x_hat, x_wt, x_wt_hat, mu, var, img_loss, kl_weight=args.kl_weight)
        loss.backward()

        # Gradient norm before any clipping
        total_norm = calc_grad_norm_2(iwt_model)
        writer.add_scalar('Loss/total', loss, log_idx)
        writer.add_scalar('Loss/bce', loss_bce, log_idx)
        writer.add_scalar('Loss/kld', loss_kld, log_idx)
        writer.add_scalar('Gradient_norm/before', total_norm, log_idx)
        writer.add_scalar('KL_weight', args.kl_weight, log_idx)
        log_idx += 1

        # Gradient clipping (re-log the norm after clipping)
        if args.grad_clip > 0:
            torch.nn.utils.clip_grad_norm_(iwt_model.parameters(), max_norm=args.grad_clip, norm_type=2)
            total_norm = calc_grad_norm_2(iwt_model)
            writer.add_scalar('Gradient_norm/clipped', total_norm, log_idx)

        # BUGFIX: accumulate a Python float, not the loss tensor — summing the
        # tensor keeps every batch's loss (and its device memory) alive for the
        # whole epoch.
        batch_loss = loss.cpu().item()
        train_losses.append([batch_loss, loss_bce.cpu().item(), loss_kld.cpu().item()])
        train_loss += batch_loss
        optimizer.step()

        # Logging
        if batch_idx % args.log_interval == 0:
            logging.info(
                'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), batch_loss / len(data)))

    logging.info('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, train_loss / len(train_loader.dataset)))
def eval_iwtvae(epoch, wt_model, iwt_model, optimizer, iwt_fn, sample_loader, args, img_output_dir, model_dir, writer):
    """Evaluate the IWT VAE: reconstruct and sample in coefficient space.

    For each sample batch: applies the WT model (optionally zeroing all but
    the first patch when ``args.zero`` is set), encodes the removed
    high-frequency content, decodes a mask from both the posterior mean and
    a random latent, zeroes each mask's first patch, adds the masks back to
    the kept coefficients, and inverts to images. Saves all intermediates
    plus a checkpoint of ``iwt_model`` + ``optimizer`` state.

    ``writer`` is accepted for signature consistency but unused here.
    """
    with torch.no_grad():
        iwt_model.eval()
        for data in sample_loader:
            data = data.to(wt_model.device)

            # Applying WT to X to get Y
            Y = wt_model(data)
            save_image(
                Y.cpu(),
                img_output_dir + '/sample_y_before_zero{}.png'.format(epoch))
            Y_full = Y.clone()

            # Zero-ing out rest of the patches
            if args.zero:
                Y = zero_patches(Y, num_wt=args.num_iwt)

            # Get sample (random latent drawn from the prior)
            z_sample = torch.randn(data.shape[0], args.z_dim).to(iwt_model.device)

            # Encoder: input is the content removed by zeroing (Y_full - Y)
            mu, var = iwt_model.encode(Y_full - Y)

            # Decoder -- two versions, real z (posterior mean) and sample z
            mask = iwt_model.decode(Y, mu)
            mask = zero_mask(mask, args.num_iwt, 1)
            # Sanity check; assumes the inner (low-frequency) patch is 128x128
            assert (mask[:, :, :128, :128] == 0).all()

            mask_sample = iwt_model.decode(Y, z_sample)
            mask_sample = zero_mask(mask_sample, args.num_iwt, 1)
            assert (mask_sample[:, :, :128, :128] == 0).all()

            # Construct x_wt_hat and x_wt_hat_sample and apply IWT to get reconstructed and sampled images
            x_wt_hat = Y + mask
            x_wt_hat_sample = Y + mask_sample
            x_hat = iwt_fn(x_wt_hat)
            x_sample = iwt_fn(x_wt_hat_sample)

            # Save images
            save_image(x_hat.cpu(), img_output_dir + '/recon_x{}.png'.format(epoch))
            save_image(x_sample.cpu(), img_output_dir + '/sample_x{}.png'.format(epoch))
            save_image(x_wt_hat.cpu(), img_output_dir + '/recon_x_wt{}.png'.format(epoch))
            save_image(x_wt_hat_sample.cpu(), img_output_dir + '/sample_x_wt{}.png'.format(epoch))
            save_image((Y_full - Y).cpu(), img_output_dir + '/encoder_input{}.png'.format(epoch))
            save_image(Y.cpu(), img_output_dir + '/y{}.png'.format(epoch))
            save_image(data.cpu(), img_output_dir + '/target{}.png'.format(epoch))

            # Checkpoint model and optimizer state for this epoch
            torch.save(
                {
                    'epoch': epoch,
                    'model_state_dict': iwt_model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict()
                }, model_dir + '/iwtvae_epoch{}.pth'.format(epoch))
def eval_iwtvae_iwtmask(epoch, wt_model, iwt_model, optimizer, iwt_fn, sample_loader, args, img_output_dir, model_dir, writer):
    """Evaluate the IWT VAE that works on IWT'ed high-frequency images.

    For each sample batch: wavelet-transforms the input, forms the IWT of
    the high frequencies only, encodes it, decodes with both the posterior
    mean and a random latent, wavelet-transforms each decoded image,
    re-inserts the kept low-frequency patch, and inverts to full images.
    Saves all intermediates plus a checkpoint of ``iwt_model`` +
    ``optimizer`` state. ``writer`` is accepted but unused.
    """
    with torch.no_grad():
        iwt_model.eval()
        for batch in sample_loader:
            batch = batch.to(wt_model.device)

            # Wavelet-transform the input; keep a full copy of the coefficients
            wt_full = wt_model(batch)
            hf_coeffs = wt_full.clone()

            # Zero the first patch, then IWT the remaining high frequencies
            hf_coeffs = zero_mask(hf_coeffs, args.num_iwt, 1)
            hf_img = iwt_fn(hf_coeffs)

            # Random latent drawn from the prior for sampling
            z_rand = torch.randn(batch.shape[0], args.z_dim).to(iwt_model.device)

            # Encode the high-frequency image
            mu, _var = iwt_model.encode(hf_img)

            # Decode twice: posterior mean and random latent
            hf_recon = iwt_model.decode(mu)
            hf_sampled = iwt_model.decode(z_rand)

            # WT each decoded image and re-insert the kept low-frequency patch
            recon_wt = wt_model(hf_recon)
            sampled_wt = wt_model(hf_sampled)
            recon_wt[:, :, :128, :128] += wt_full[:, :, :128, :128]
            sampled_wt[:, :, :128, :128] += wt_full[:, :, :128, :128]

            # Low-frequency-only image for reference
            low_only = torch.zeros(hf_img.shape, device=wt_full.device)
            low_only[:, :, :128, :128] = wt_full[:, :, :128, :128]
            img_low = iwt_fn(low_only)

            img_recon = iwt_fn(recon_wt)
            img_sample_recon = iwt_fn(sampled_wt)

            # Save all intermediate and final images
            save_image(hf_img.cpu(), img_output_dir + f'/y{epoch}.png')
            save_image(hf_recon.cpu(), img_output_dir + f'/recon_y{epoch}.png')
            save_image(hf_sampled.cpu(), img_output_dir + f'/sample_y{epoch}.png')
            save_image(img_low.cpu(), img_output_dir + f'/low_img{epoch}.png')
            save_image(img_recon.cpu(), img_output_dir + f'/recon_img{epoch}.png')
            save_image(img_sample_recon.cpu(), img_output_dir + f'/recon_sample_img{epoch}.png')
            save_image(batch.cpu(), img_output_dir + f'/target{epoch}.png')

            # Checkpoint model and optimizer state for this epoch
            torch.save(
                {
                    'epoch': epoch,
                    'model_state_dict': iwt_model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict()
                }, model_dir + f'/iwtvae_epoch{epoch}.pth')
y = wt_model.decode(z) y_sample1 = wt_model.decode(z_sample1) y_sample2 = wt_model.decode(z_sample2) y_padded = zero_pad(y, target_dim=512, device=device) y_sample_padded1 = zero_pad(y_sample1, target_dim=512, device=device) y_sample_padded2 = zero_pad(y_sample2, target_dim=512, device=device) data512_wt = wt_fn(data512) # Zero out first patch and apply IWT data512_mask = zero_mask(data512_wt, args.num_iwt, 1) data512_mask = iwt_fn(data512_mask) mask, mu, var = iwt_model(data512_mask) mask_wt = wt_fn(mask) img_low = iwt_fn(y_padded) img_low_sample1 = iwt_fn(y_sample_padded1) img_low_sample2 = iwt_fn(y_sample_padded2) img_recon = iwt_fn(y_padded + mask_wt) img_sample1_recon = iwt_fn(y_sample_padded1 + mask_wt) img_sample2_recon = iwt_fn(y_sample_padded2 + mask_wt) # Save images