import logging

import numpy as np
import torch
from torchvision.utils import save_image


def eval_iwtae_iwtmask128(epoch,
                          wt_model,
                          iwt_model,
                          optimizer,
                          iwt_fn,
                          sample_loader,
                          args,
                          img_output_dir,
                          model_dir,
                          writer,
                          save=True):
    with torch.no_grad():
        iwt_model.eval()

        for data in sample_loader:
            data = data.to(wt_model.device)

            # Applying WT to X to get Y
            Y = wt_model(data)
            Y = Y[:, :, :128, :128]

            # Zero out the first (low-frequency) patch, keeping the high frequencies
            Y_mask = zero_mask(Y, args.num_iwt, 1)
            # Apply IWT to the remaining high frequencies
            Y_mask = iwt_fn(Y_mask)

            # Getting IWT of only first patch
            Y_low = zero_patches(Y, args.num_iwt)
            Y_low = iwt_fn(Y_low)

            # Run model on the low-frequency reconstruction to predict the
            # high-frequency mask (mu and var are unused at eval time)
            mask, _, _ = iwt_model(Y_low)

            # Add first patch to WT'ed mask
            mask_wt = wt_model(mask)
            inner_dim = Y.shape[2] // np.power(2, args.num_iwt)
            mask_wt[:, :, :inner_dim, :inner_dim] += Y[:, :, :inner_dim, :inner_dim]

            img_recon = iwt_fn(mask_wt)

            # Save images
            save_image(Y_low.cpu(), img_output_dir + '/y{}.png'.format(epoch))
            save_image(mask.cpu(),
                       img_output_dir + '/recon_mask{}.png'.format(epoch))
            save_image(Y_mask.cpu(),
                       img_output_dir + '/mask{}.png'.format(epoch))
            save_image(img_recon.cpu(),
                       img_output_dir + '/recon_img{}.png'.format(epoch))
            save_image(data.cpu(), img_output_dir + '/img{}.png'.format(epoch))

    if save:
        torch.save(
            {
                'epoch': epoch,
                'model_state_dict': iwt_model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()
            }, model_dir + '/iwtvae_epoch{}.pth'.format(epoch))
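These snippets call the helpers zero_mask and zero_patches without defining them. A minimal sketch of both, assuming the wavelet transform packs the low-frequency patch into the top-left corner (the layout implied by the slicing above); the cur_iwt argument is accepted only for signature compatibility:

def zero_mask(x, num_iwt, cur_iwt):
    # Zero out the top-left (low-frequency) patch, keeping the high frequencies
    out = x.clone()
    inner = x.shape[2] // 2 ** num_iwt
    out[:, :, :inner, :inner] = 0
    return out


def zero_patches(x, num_wt=1):
    # Keep only the top-left (low-frequency) patch, zeroing the high frequencies
    out = torch.zeros_like(x)
    inner = x.shape[2] // 2 ** num_wt
    out[:, :, :inner, :inner] = x[:, :, :inner, :inner]
    return out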
Example #2
def train_iwtae_iwtmask(epoch, wt_model, iwt_model, optimizer, iwt_fn,
                        train_loader, train_losses, args, writer):
    # toggle model to train mode
    iwt_model.train()
    train_loss = 0

    for batch_idx, data in enumerate(train_loader):

        data = data.to(wt_model.device)

        optimizer.zero_grad()

        # Get Y
        Y = wt_model(data)

        # Zero out the first (low-frequency) patch, keeping the high frequencies
        Y_mask = zero_mask(Y, args.num_iwt, 1)
        # Apply IWT to the remaining high frequencies
        Y_mask = iwt_fn(Y_mask)

        # Getting IWT of only first patch
        Y_low = zero_patches(Y, args.num_iwt)
        Y_low = iwt_fn(Y_low)

        # Run model on the low-frequency reconstruction to predict the
        # high-frequency mask and the latent statistics
        mask, mu, var = iwt_model(Y_low)

        loss, loss_bce, loss_kld = iwt_model.loss_function(
            Y_mask, mask, mu, var)
        loss.backward()

        # Calculate the gradient norm before clipping
        total_norm = calc_grad_norm_2(iwt_model)

        # TensorBoard logging
        global log_idx
        writer.add_scalar('Loss/total', loss, log_idx)
        writer.add_scalar('Loss/bce', loss_bce, log_idx)
        writer.add_scalar('Loss/kld', loss_kld, log_idx)
        writer.add_scalar('Gradient_norm/before', total_norm, log_idx)

        # Gradient clipping
        if args.grad_clip > 0:
            torch.nn.utils.clip_grad_norm_(iwt_model.parameters(),
                                           max_norm=args.grad_clip,
                                           norm_type=2)
            total_norm = calc_grad_norm_2(iwt_model)
            writer.add_scalar('Gradient_norm/clipped', total_norm, log_idx)

        # Advance the global step once all scalars for this batch are logged
        log_idx += 1

        train_losses.append(
            [loss.cpu().item(),
             loss_bce.cpu().item(),
             loss_kld.cpu().item()])
        train_loss += loss.item()  # accumulate a Python float, not a tensor

        optimizer.step()

        # Periodic console logging
        if batch_idx % args.log_interval == 0:
            logging.info(
                'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader),
                    loss.item() / len(data)))

    logging.info('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, train_loss / len(train_loader.dataset)))
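calc_grad_norm_2 is likewise not shown. A minimal sketch, assuming it returns the global 2-norm over all parameter gradients (the same quantity torch.nn.utils.clip_grad_norm_ computes):

def calc_grad_norm_2(model):
    # Global L2 norm across all parameter gradients
    total = 0.0
    for p in model.parameters():
        if p.grad is not None:
            total += p.grad.detach().norm(2).item() ** 2
    return total ** 0.5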
Example #3
def train_iwtvae(epoch, wt_model, iwt_model, optimizer, iwt_fn, train_loader,
                 train_losses, args, writer):
    # toggle model to train mode
    iwt_model.train()
    train_loss = 0

    # iwt_fn = IWT(iwt=iwt, num_iwt=self.num_iwt)
    # iwt_fn.set_filters(filters)

    for batch_idx, data in enumerate(train_loader):

        data0 = data.to(iwt_model.device)
        data1 = data.to(wt_model.device)

        optimizer.zero_grad()

        # Get Y
        Y = wt_model(data1)

        # Zeroing out all other patches, if given zero arg
        Y_full = Y.clone()
        if args.zero:
            Y = zero_patches(Y, num_wt=args.num_iwt)

        # Run model to predict the high-frequency mask and latent statistics
        mask, mu, var = iwt_model(data0, Y_full.to(iwt_model.device),
                                  Y.to(iwt_model.device))
        # Zero out the first (low-frequency) patch of the mask
        with torch.no_grad():
            mask = zero_mask(mask, args.num_iwt, 1)
            assert (mask[:, :, :128, :128] == 0).all()

        # Y holds only the first patch; the mask supplies the high frequencies
        x_wt_hat = Y + mask
        x_hat = iwt_fn(x_wt_hat)

        # Get x_wt, assuming deterministic WT model/function, and fill 0's in first patch
        x_wt = wt_model(data0)
        x_wt = zero_mask(x_wt, args.num_iwt, 1)

        # Calculate loss
        img_loss = (epoch >= args.img_loss_epoch)
        loss, loss_bce, loss_kld = iwt_model.loss_function(
            data0,
            x_hat,
            x_wt,
            x_wt_hat,
            mu,
            var,
            img_loss,
            kl_weight=args.kl_weight)
        loss.backward()

        # Calculate the gradient norm before clipping
        total_norm = calc_grad_norm_2(iwt_model)

        # TensorBoard logging
        global log_idx
        writer.add_scalar('Loss/total', loss, log_idx)
        writer.add_scalar('Loss/bce', loss_bce, log_idx)
        writer.add_scalar('Loss/kld', loss_kld, log_idx)
        writer.add_scalar('Gradient_norm/before', total_norm, log_idx)
        writer.add_scalar('KL_weight', args.kl_weight, log_idx)

        # Gradient clipping
        if args.grad_clip > 0:
            torch.nn.utils.clip_grad_norm_(iwt_model.parameters(),
                                           max_norm=args.grad_clip,
                                           norm_type=2)
            total_norm = calc_grad_norm_2(iwt_model)
            writer.add_scalar('Gradient_norm/clipped', total_norm, log_idx)

        # Advance the global step once all scalars for this batch are logged
        log_idx += 1

        train_losses.append(
            [loss.cpu().item(),
             loss_bce.cpu().item(),
             loss_kld.cpu().item()])
        train_loss += loss.item()  # accumulate a Python float, not a tensor

        optimizer.step()

        # Periodic console logging
        if batch_idx % args.log_interval == 0:
            logging.info(
                'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader),
                    loss.item() / len(data)))

    logging.info('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, train_loss / len(train_loader.dataset)))
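wt_model and iwt_fn are used as black boxes throughout; the commented-out lines in train_iwtvae suggest iwt_fn wraps an IWT helper class. For intuition only, a single-level sketch of a matching forward/inverse pair using the orthonormal Haar wavelet and the top-left patch layout these snippets assume (not the repo's actual WT/IWT implementation):

def wt_haar(x):
    # Single-level 2D Haar transform; the LL (low-frequency) patch lands top-left
    a = x[:, :, 0::2, 0::2]
    b = x[:, :, 0::2, 1::2]
    c = x[:, :, 1::2, 0::2]
    d = x[:, :, 1::2, 1::2]
    h, w = x.shape[2] // 2, x.shape[3] // 2
    out = torch.zeros_like(x)
    out[:, :, :h, :w] = (a + b + c + d) / 2  # LL
    out[:, :, :h, w:] = (a - b + c - d) / 2  # LH
    out[:, :, h:, :w] = (a + b - c - d) / 2  # HL
    out[:, :, h:, w:] = (a - b - c + d) / 2  # HH
    return out


def iwt_haar(y):
    # Exact inverse of wt_haar: unpack the four patches and interleave
    h, w = y.shape[2] // 2, y.shape[3] // 2
    ll, lh = y[:, :, :h, :w], y[:, :, :h, w:]
    hl, hh = y[:, :, h:, :w], y[:, :, h:, w:]
    x = torch.zeros_like(y)
    x[:, :, 0::2, 0::2] = (ll + lh + hl + hh) / 2
    x[:, :, 0::2, 1::2] = (ll - lh + hl - hh) / 2
    x[:, :, 1::2, 0::2] = (ll + lh - hl - hh) / 2
    x[:, :, 1::2, 1::2] = (ll - lh - hl + hh) / 2
    return x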
def eval_iwtvae(epoch, wt_model, iwt_model, optimizer, iwt_fn, sample_loader,
                args, img_output_dir, model_dir, writer):
    with torch.no_grad():
        iwt_model.eval()

        for data in sample_loader:
            data = data.to(wt_model.device)

            # Applying WT to X to get Y
            Y = wt_model(data)
            save_image(
                Y.cpu(),
                img_output_dir + '/sample_y_before_zero{}.png'.format(epoch))
            Y_full = Y.clone()

            # Zero out the rest of the patches, if requested
            if args.zero:
                Y = zero_patches(Y, num_wt=args.num_iwt)

            # Get sample
            z_sample = torch.randn(data.shape[0],
                                   args.z_dim).to(iwt_model.device)

            # Encoder
            mu, var = iwt_model.encode(Y_full - Y)

            # Decoder -- two versions: the posterior mean z and a sampled z
            mask = iwt_model.decode(Y, mu)
            mask = zero_mask(mask, args.num_iwt, 1)
            assert (mask[:, :, :128, :128] == 0).all()

            mask_sample = iwt_model.decode(Y, z_sample)
            mask_sample = zero_mask(mask_sample, args.num_iwt, 1)
            assert (mask_sample[:, :, :128, :128] == 0).all()

            # Construct x_wt_hat and x_wt_hat_sample and apply IWT to get reconstructed and sampled images
            x_wt_hat = Y + mask
            x_wt_hat_sample = Y + mask_sample

            x_hat = iwt_fn(x_wt_hat)
            x_sample = iwt_fn(x_wt_hat_sample)

            # Save images
            save_image(x_hat.cpu(),
                       img_output_dir + '/recon_x{}.png'.format(epoch))
            save_image(x_sample.cpu(),
                       img_output_dir + '/sample_x{}.png'.format(epoch))
            save_image(x_wt_hat.cpu(),
                       img_output_dir + '/recon_x_wt{}.png'.format(epoch))
            save_image(x_wt_hat_sample.cpu(),
                       img_output_dir + '/sample_x_wt{}.png'.format(epoch))
            save_image((Y_full - Y).cpu(),
                       img_output_dir + '/encoder_input{}.png'.format(epoch))
            save_image(Y.cpu(), img_output_dir + '/y{}.png'.format(epoch))
            save_image(data.cpu(),
                       img_output_dir + '/target{}.png'.format(epoch))

    torch.save(
        {
            'epoch': epoch,
            'model_state_dict': iwt_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()
        }, model_dir + '/iwtvae_epoch{}.pth'.format(epoch))
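The torch.save calls above write a dict with 'epoch', 'model_state_dict', and 'optimizer_state_dict' keys; restoring a checkpoint is symmetric. A minimal sketch:

checkpoint = torch.load(model_dir + '/iwtvae_epoch{}.pth'.format(epoch),
                        map_location='cpu')
iwt_model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1  # resume from the following epoch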
Example #5
    # Excerpt from the main training/eval loop
    for epoch in range(1, args.epochs + 1):
        train_iwtvae(epoch, wt_model, iwt_model, optimizer, iwt_fn,
                     train_loader, train_losses, args, writer)

        with torch.no_grad():
            iwt_model.eval()

            for data in sample_loader:
                data0 = data.to(devices[0])
                data1 = data.to(devices[1])

                z_sample = torch.randn(data.shape[0], 100).to(devices[0])

                Y = wt_model(data1)[0]
                if args.zero:
                    Y = zero_patches(Y)
                Y = Y.to(devices[0])

                mu, var = iwt_model.encode(data0, Y)
                x_hat = iwt_model.decode(Y, mu)
                x_sample = iwt_model.decode(Y, z_sample)

                save_image(
                    x_hat.cpu(),
                    img_output_dir + '/sample_recon{}.png'.format(epoch))
                save_image(x_sample.cpu(),
                           img_output_dir + '/sample_z{}.png'.format(epoch))
                save_image(Y.cpu(),
                           img_output_dir + '/sample_y{}.png'.format(epoch))
                save_image(data.cpu(),
                           img_output_dir + '/sample{}.png'.format(epoch))
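Both training functions increment a module-level log_idx counter as the TensorBoard global step, so it must exist before the first batch is processed. A minimal setup sketch, assuming the standard SummaryWriter; the log_dir value is an arbitrary choice:

from torch.utils.tensorboard import SummaryWriter

log_idx = 0  # global TensorBoard step shared by the train functions
writer = SummaryWriter(log_dir='runs/iwtvae')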