def eval_ae_mask_channels(epoch, wt_model, model, sample_loader, args,
                          img_output_dir, model_dir):
    with torch.no_grad():
        model.eval()

        for data in sample_loader:
            data = data.to(model.device)

            # Get Y
            Y = wt_model(data)

            # Zeroing out first patch
            Y = zero_mask(Y, num_iwt=args.num_wt, cur_iwt=1)
            if args.num_wt == 1:
                Y = hf_collate_to_channels(Y, device=model.device)
            elif args.num_wt == 2:
                Y = hf_collate_to_channels_wt2(Y, device=model.device)

            x_hat = model(Y.to(model.device))
            x_hat = hf_collate_to_img(x_hat)
            Y = hf_collate_to_img(Y)

            save_image(x_hat.cpu(),
                       img_output_dir + '/sample_recon{}.png'.format(epoch))
            save_image(Y.cpu(), img_output_dir + '/sample{}.png'.format(epoch))

    torch.save(model.state_dict(),
               model_dir + '/aemask_epoch{}.pth'.format(epoch))
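
# The patch-masking helpers zero_mask and zero_patches (as well as the
# hf_collate_* utilities) are defined elsewhere in the repository and are not
# part of this excerpt. The sketches below are only hypothetical illustrations
# of the behaviour implied by the surrounding comments and by assertions such
# as `(mask[:, :, :128, :128] == 0).all()`; the real implementations may differ.

def zero_mask_sketch(y, num_iwt, cur_iwt):
    # Zero the low-frequency (top-left) patch of a num_iwt-level WT output.
    # cur_iwt is accepted only for signature compatibility in this sketch.
    inner = y.shape[2] // 2 ** num_iwt
    out = y.clone()
    out[:, :, :inner, :inner] = 0
    return out

def zero_patches_sketch(y, num_wt):
    # Keep only the low-frequency patch and zero out every high-frequency patch.
    inner = y.shape[2] // 2 ** num_wt
    out = torch.zeros_like(y)
    out[:, :, :inner, :inner] = y[:, :, :inner, :inner]
    return out
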
def train_ae_mask_channels(epoch, wt_model, model, criterion, optimizer,
                           train_loader, train_losses, args, writer):
    # toggle model to train mode
    model.train()
    train_loss = 0

    for batch_idx, data in enumerate(train_loader):

        data = data.to(model.device)

        optimizer.zero_grad()

        # Get Y
        Y = wt_model(data)

        # Zeroing out first patch
        Y = zero_mask(Y, num_iwt=args.num_wt, cur_iwt=1)
        if args.num_wt == 1:
            Y = hf_collate_to_channels(Y, device=model.device)
        elif args.num_wt == 2:
            Y = hf_collate_to_channels_wt2(Y, device=model.device)

        x_hat = model(Y)
        loss = model.loss_function(Y, x_hat, criterion)
        loss.backward()

        # Calculating gradient norm before clipping
        total_norm = calc_grad_norm_2(model)

        # Logging loss and gradient norm to TensorBoard
        global log_idx
        writer.add_scalar('Loss', loss, log_idx)
        writer.add_scalar('Gradient_norm/before', total_norm, log_idx)
        log_idx += 1

        # Gradient clipping
        if args.grad_clip > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           max_norm=args.grad_clip,
                                           norm_type=2)
            # Re-calculating total norm after gradient clipping
            total_norm = calc_grad_norm_2(model)
            writer.add_scalar('Gradient_norm/clipped', total_norm, log_idx)

        train_losses.append(loss.cpu().item())
        # Accumulate as a float so the autograd graph is not retained across batches
        train_loss += loss.item()

        optimizer.step()
        if batch_idx % args.log_interval == 0:
            logging.info(
                'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), loss / len(data)))

            n = min(data.size(0), 8)

    logging.info('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, train_loss / len(train_loader.dataset)))
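
# calc_grad_norm_2 is also defined outside this excerpt. A minimal sketch
# consistent with how the helper is used above (the L2 norm over all parameter
# gradients); the actual implementation may differ:

def calc_grad_norm_2_sketch(model):
    # Sum of squared L2 norms of every gradient, then the square root.
    total_sq = 0.0
    for p in model.parameters():
        if p.grad is not None:
            total_sq += p.grad.detach().norm(2).item() ** 2
    return total_sq ** 0.5
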
def eval_iwtae_iwtmask128(epoch,
                          wt_model,
                          iwt_model,
                          optimizer,
                          iwt_fn,
                          sample_loader,
                          args,
                          img_output_dir,
                          model_dir,
                          writer,
                          save=True):
    with torch.no_grad():
        iwt_model.eval()

        for data in sample_loader:
            data = data.to(wt_model.device)

            # Applying WT to X to get Y
            Y = wt_model(data)
            Y = Y[:, :, :128, :128]

            # Zeroing out first patch
            Y_mask = zero_mask(Y, args.num_iwt, 1)
            # IWT all the leftover high frequencies
            Y_mask = iwt_fn(Y_mask)

            # Getting IWT of only first patch
            Y_low = zero_patches(Y, args.num_iwt)
            Y_low = iwt_fn(Y_low)

            # Run model on the low-frequency image to predict the high-frequency mask
            mask, _, _ = iwt_model(Y_low)

            # Add first patch to WT'ed mask
            mask_wt = wt_model(mask)
            inner_dim = Y.shape[2] // np.power(2, args.num_iwt)
            mask_wt[:, :, :inner_dim, :inner_dim] += \
                Y[:, :, :inner_dim, :inner_dim]

            img_recon = iwt_fn(mask_wt)

            # Save images
            save_image(Y_low.cpu(), img_output_dir + '/y{}.png'.format(epoch))
            save_image(mask.cpu(),
                       img_output_dir + '/recon_mask{}.png'.format(epoch))
            save_image(Y_mask.cpu(),
                       img_output_dir + '/mask{}.png'.format(epoch))
            save_image(img_recon.cpu(),
                       img_output_dir + '/recon_img{}.png'.format(epoch))
            save_image(data.cpu(), img_output_dir + '/img{}.png'.format(epoch))

    if save:
        torch.save(
            {
                'epoch': epoch,
                'model_state_dict': iwt_model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()
            }, model_dir + '/iwtvae_epoch{}.pth'.format(epoch))
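
# The checkpoint saved above stores 'epoch', 'model_state_dict', and
# 'optimizer_state_dict'. A minimal resume sketch using those keys (the
# map_location choice and the helper name are assumptions, not taken from
# this excerpt):

def load_iwt_checkpoint_sketch(iwt_model, optimizer, model_dir, resume_epoch):
    checkpoint = torch.load(
        model_dir + '/iwtvae_epoch{}.pth'.format(resume_epoch),
        map_location='cpu')
    iwt_model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # Resume training at the epoch after the saved one.
    return checkpoint['epoch'] + 1
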
def train_iwtae_iwtmask(epoch, wt_model, iwt_model, optimizer, iwt_fn,
                        train_loader, train_losses, args, writer):
    # toggle model to train mode
    iwt_model.train()
    train_loss = 0

    for batch_idx, data in enumerate(train_loader):

        data = data.to(wt_model.device)

        optimizer.zero_grad()

        # Get Y
        Y = wt_model(data)

        # Zeroing out first patch
        Y_mask = zero_mask(Y, args.num_iwt, 1)
        # IWT all the leftover high frequencies
        Y_mask = iwt_fn(Y_mask)

        # Getting IWT of only first patch
        Y_low = zero_patches(Y, args.num_iwt)
        Y_low = iwt_fn(Y_low)

        # Run model on the low-frequency image to get the predicted mask, mu, and var
        mask, mu, var = iwt_model(Y_low)

        loss, loss_bce, loss_kld = iwt_model.loss_function(
            Y_mask, mask, mu, var)
        loss.backward()

        # Calculating gradient norm before clipping
        total_norm = calc_grad_norm_2(iwt_model)

        # Logging losses and gradient norm to TensorBoard
        global log_idx
        writer.add_scalar('Loss/total', loss, log_idx)
        writer.add_scalar('Loss/bce', loss_bce, log_idx)
        writer.add_scalar('Loss/kld', loss_kld, log_idx)
        writer.add_scalar('Gradient_norm/before', total_norm, log_idx)
        log_idx += 1

        # Gradient clipping
        if args.grad_clip > 0:
            torch.nn.utils.clip_grad_norm_(iwt_model.parameters(),
                                           max_norm=args.grad_clip,
                                           norm_type=2)
            total_norm = calc_grad_norm_2(iwt_model)
            writer.add_scalar('Gradient_norm/clipped', total_norm, log_idx)

        train_losses.append(
            [loss.cpu().item(),
             loss_bce.cpu().item(),
             loss_kld.cpu().item()])
        train_loss += loss.item()

        optimizer.step()

        # Logging
        if batch_idx % args.log_interval == 0:
            logging.info(
                'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), loss / len(data)))

            n = min(data.size(0), 8)

    logging.info('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, train_loss / len(train_loader.dataset)))
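
# A hypothetical driver pairing the two functions above. The epoch count, save
# cadence, and the initialization of log_idx / train_losses are illustrative
# assumptions; model, loader, writer, and args construction is not shown here.

log_idx = 0      # global TensorBoard step counter used by the training functions
train_losses = []

def run_iwtmask_training_sketch(wt_model, iwt_model, optimizer, iwt_fn,
                                train_loader, sample_loader, args,
                                img_output_dir, model_dir, writer,
                                num_epochs=100):
    for epoch in range(1, num_epochs + 1):
        train_iwtae_iwtmask(epoch, wt_model, iwt_model, optimizer, iwt_fn,
                            train_loader, train_losses, args, writer)
        eval_iwtae_iwtmask128(epoch, wt_model, iwt_model, optimizer, iwt_fn,
                              sample_loader, args, img_output_dir, model_dir,
                              writer, save=(epoch % 10 == 0))
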
def train_iwtvae(epoch, wt_model, iwt_model, optimizer, iwt_fn, train_loader,
                 train_losses, args, writer):
    # toggle model to train mode
    iwt_model.train()
    train_loss = 0

    # iwt_fn is constructed by the caller, e.g.:
    #   iwt_fn = IWT(iwt=iwt, num_iwt=num_iwt)
    #   iwt_fn.set_filters(filters)

    for batch_idx, data in enumerate(train_loader):

        data0 = data.to(iwt_model.device)
        data1 = data.to(wt_model.device)

        optimizer.zero_grad()

        # Get Y
        Y = wt_model(data1)

        # Zeroing out all other patches, if given zero arg
        Y_full = Y.clone()
        if args.zero:
            Y = zero_patches(Y, num_wt=args.num_iwt)

        # Run model to get mask (zero out first patch of mask) and x_wt_hat
        mask, mu, var = iwt_model(data0, Y_full.to(iwt_model.device),
                                  Y.to(iwt_model.device))
        with torch.no_grad():
            mask = zero_mask(mask, args.num_iwt, 1)
            assert (mask[:, :, :128, :128] == 0).all()

        # Y only has first patch + mask
        x_wt_hat = Y + mask
        x_hat = iwt_fn(x_wt_hat)

        # Get x_wt, assuming deterministic WT model/function, and fill 0's in first patch
        x_wt = wt_model(data0)
        x_wt = zero_mask(x_wt, args.num_iwt, 1)

        # Calculate loss
        img_loss = (epoch >= args.img_loss_epoch)
        loss, loss_bce, loss_kld = iwt_model.loss_function(
            data0,
            x_hat,
            x_wt,
            x_wt_hat,
            mu,
            var,
            img_loss,
            kl_weight=args.kl_weight)
        loss.backward()

        # Calculating gradient norm before clipping
        total_norm = calc_grad_norm_2(iwt_model)

        # Logging losses, KL weight, and gradient norm to TensorBoard
        global log_idx
        writer.add_scalar('Loss/total', loss, log_idx)
        writer.add_scalar('Loss/bce', loss_bce, log_idx)
        writer.add_scalar('Loss/kld', loss_kld, log_idx)
        writer.add_scalar('Gradient_norm/before', total_norm, log_idx)
        writer.add_scalar('KL_weight', args.kl_weight, log_idx)
        log_idx += 1

        # Gradient clipping
        if args.grad_clip > 0:
            torch.nn.utils.clip_grad_norm_(iwt_model.parameters(),
                                           max_norm=args.grad_clip,
                                           norm_type=2)
            total_norm = calc_grad_norm_2(iwt_model)
            writer.add_scalar('Gradient_norm/clipped', total_norm, log_idx)

        train_losses.append(
            [loss.cpu().item(),
             loss_bce.cpu().item(),
             loss_kld.cpu().item()])
        train_loss += loss.item()

        optimizer.step()

        # Logging
        if batch_idx % args.log_interval == 0:
            logging.info(
                'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), loss / len(data)))

            n = min(data.size(0), 8)

    logging.info('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, train_loss / len(train_loader.dataset)))
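
# The training and evaluation functions in this excerpt read a fixed set of
# attributes from args. A hypothetical argparse setup covering them (flag names
# match the attributes used here, but every default is an assumption):

def build_arg_parser_sketch():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_wt', type=int, default=1)       # WT levels (AE-mask path)
    parser.add_argument('--num_iwt', type=int, default=1)      # IWT levels (VAE paths)
    parser.add_argument('--grad_clip', type=float, default=0)  # <= 0 disables clipping
    parser.add_argument('--log_interval', type=int, default=10)
    parser.add_argument('--zero', action='store_true')         # zero non-first patches of Y
    parser.add_argument('--img_loss_epoch', type=int, default=0)
    parser.add_argument('--kl_weight', type=float, default=1.0)
    parser.add_argument('--z_dim', type=int, default=100)
    return parser
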
def eval_iwtvae(epoch, wt_model, iwt_model, optimizer, iwt_fn, sample_loader,
                args, img_output_dir, model_dir, writer):
    with torch.no_grad():
        iwt_model.eval()

        for data in sample_loader:
            data = data.to(wt_model.device)

            # Applying WT to X to get Y
            Y = wt_model(data)
            save_image(
                Y.cpu(),
                img_output_dir + '/sample_y_before_zero{}.png'.format(epoch))
            Y_full = Y.clone()

            # Zeroing out the rest of the patches
            if args.zero:
                Y = zero_patches(Y, num_wt=args.num_iwt)

            # Get sample
            z_sample = torch.randn(data.shape[0],
                                   args.z_dim).to(iwt_model.device)

            # Encoder
            mu, var = iwt_model.encode(Y_full - Y)

            # Decoder -- two versions: real z and sampled z
            mask = iwt_model.decode(Y, mu)
            mask = zero_mask(mask, args.num_iwt, 1)
            assert (mask[:, :, :128, :128] == 0).all()

            mask_sample = iwt_model.decode(Y, z_sample)
            mask_sample = zero_mask(mask_sample, args.num_iwt, 1)
            assert (mask_sample[:, :, :128, :128] == 0).all()

            # Construct x_wt_hat and x_wt_hat_sample and apply IWT to get reconstructed and sampled images
            x_wt_hat = Y + mask
            x_wt_hat_sample = Y + mask_sample

            x_hat = iwt_fn(x_wt_hat)
            x_sample = iwt_fn(x_wt_hat_sample)

            # Save images
            save_image(x_hat.cpu(),
                       img_output_dir + '/recon_x{}.png'.format(epoch))
            save_image(x_sample.cpu(),
                       img_output_dir + '/sample_x{}.png'.format(epoch))
            save_image(x_wt_hat.cpu(),
                       img_output_dir + '/recon_x_wt{}.png'.format(epoch))
            save_image(x_wt_hat_sample.cpu(),
                       img_output_dir + '/sample_x_wt{}.png'.format(epoch))
            save_image((Y_full - Y).cpu(),
                       img_output_dir + '/encoder_input{}.png'.format(epoch))
            save_image(Y.cpu(), img_output_dir + '/y{}.png'.format(epoch))
            save_image(data.cpu(),
                       img_output_dir + '/target{}.png'.format(epoch))

    torch.save(
        {
            'epoch': epoch,
            'model_state_dict': iwt_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()
        }, model_dir + '/iwtvae_epoch{}.pth'.format(epoch))
def eval_iwtvae_iwtmask(epoch, wt_model, iwt_model, optimizer, iwt_fn,
                        sample_loader, args, img_output_dir, model_dir,
                        writer):
    with torch.no_grad():
        iwt_model.eval()

        for data in sample_loader:
            data = data.to(wt_model.device)

            # Applying WT to X to get Y
            Y = wt_model(data)
            Y_full = Y.clone()

            # Zeroing out first patch
            Y = zero_mask(Y, args.num_iwt, 1)

            # IWT all the leftover high frequencies
            Y = iwt_fn(Y)

            # Get sample
            z_sample = torch.randn(data.shape[0],
                                   args.z_dim).to(iwt_model.device)

            # Encoder
            mu, var = iwt_model.encode(Y)

            # Decoder -- two versions: real z and sampled z
            mask = iwt_model.decode(mu)
            mask_sample = iwt_model.decode(z_sample)

            mask_wt = wt_model(mask)
            mask_sample_wt = wt_model(mask_sample)

            mask_wt[:, :, :128, :128] += Y_full[:, :, :128, :128]
            mask_sample_wt[:, :, :128, :128] += Y_full[:, :, :128, :128]
            padded = torch.zeros(Y.shape, device=Y_full.device)
            padded[:, :, :128, :128] = Y_full[:, :, :128, :128]

            img_low = iwt_fn(padded)
            img_recon = iwt_fn(mask_wt)
            img_sample_recon = iwt_fn(mask_sample_wt)

            # Save images
            save_image(Y.cpu(), img_output_dir + '/y{}.png'.format(epoch))
            save_image(mask.cpu(),
                       img_output_dir + '/recon_y{}.png'.format(epoch))
            save_image(mask_sample.cpu(),
                       img_output_dir + '/sample_y{}.png'.format(epoch))
            save_image(img_low.cpu(),
                       img_output_dir + '/low_img{}.png'.format(epoch))
            save_image(img_recon.cpu(),
                       img_output_dir + '/recon_img{}.png'.format(epoch))
            save_image(
                img_sample_recon.cpu(),
                img_output_dir + '/recon_sample_img{}.png'.format(epoch))
            save_image(data.cpu(),
                       img_output_dir + '/target{}.png'.format(epoch))

    torch.save(
        {
            'epoch': epoch,
            'model_state_dict': iwt_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()
        }, model_dir + '/iwtvae_epoch{}.pth'.format(epoch))
            y = wt_model.decode(z)
            y_sample1 = wt_model.decode(z_sample1)
            y_sample2 = wt_model.decode(z_sample2)

            y_padded = zero_pad(y, target_dim=512, device=device)
            y_sample_padded1 = zero_pad(y_sample1,
                                        target_dim=512,
                                        device=device)
            y_sample_padded2 = zero_pad(y_sample2,
                                        target_dim=512,
                                        device=device)

            data512_wt = wt_fn(data512)
            # Zero out first patch and apply IWT
            data512_mask = zero_mask(data512_wt, args.num_iwt, 1)
            data512_mask = iwt_fn(data512_mask)

            mask, mu, var = iwt_model(data512_mask)

            mask_wt = wt_fn(mask)

            img_low = iwt_fn(y_padded)
            img_low_sample1 = iwt_fn(y_sample_padded1)
            img_low_sample2 = iwt_fn(y_sample_padded2)

            img_recon = iwt_fn(y_padded + mask_wt)
            img_sample1_recon = iwt_fn(y_sample_padded1 + mask_wt)
            img_sample2_recon = iwt_fn(y_sample_padded2 + mask_wt)

            # Save images