Example #1
def test_pyramid(images):
    parser = get_arguments()
    parser.add_argument('--input_dir',
                        help='input image dir',
                        default='Input/Images')
    #parser.add_argument('--input_name', help='input image name', required=True)
    parser.add_argument('--mode', help='task to be done', default='train')
    opt = parser.parse_args("")
    opt.input_name = 'blank'
    opt = functions.post_config(opt)

    real = functions.np2torch(images[0], opt)
    functions.adjust_scales2image(real, opt)

    all_reals = []
    for image in images:
        reals = []
        real_ = functions.np2torch(image, opt)
        real = imresize(real_, opt.scale1, opt)
        reals = functions.creat_reals_pyramid(real, reals, opt)
        all_reals.append(reals)

    return np.array(all_reals, dtype=object).T  # object array transposed to scales x images
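A minimal usage sketch for the snippet above, assuming the inputs are RGB images loaded as float numpy arrays (the file names below are placeholders, not from the original repo):

import matplotlib.image as img
import numpy as np

# Hypothetical input files; any list of HxWx3 float arrays works.
images = [img.imread('Input/Images/sample_%d.png' % i).astype(np.float32) for i in range(2)]
pyramids = test_pyramid(images)  # per-image multi-scale pyramids, axis 0 indexes scales
print(pyramids.shape)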
Example #2
def train_single_scale(netD, netD_mask1, netD_mask2, netG, reals1, reals2, Gs, Zs, in_s1, in_s2, NoiseAmp, opt):

    real1 = reals1[len(Gs)]
    real2 = reals2[len(Gs)]

    if opt.replace_background:
        background_real1 = create_background(functions.convert_image_np(real1))
        real2 = create_img_over_background(functions.convert_image_np(real2), background_real1)

        plt.imsave('%s/background_real_scale1.png' % (opt.outf), background_real1, vmin=0, vmax=1)
        plt.imsave('%s/real_scale2_new.png' % (opt.outf), real2, vmin=0, vmax=1)

        real2 = functions.np2torch(real2, opt)

    # assumption: the images are the same size
    opt.nzx = real1.shape[2]#+(opt.ker_size-1)*(opt.num_layer)
    opt.nzy = real1.shape[3]#+(opt.ker_size-1)*(opt.num_layer)
    opt.receptive_field = opt.ker_size + ((opt.ker_size-1)*(opt.num_layer-1))*opt.stride
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)

    m_noise = nn.ZeroPad2d(int(pad_noise))
    m_image = nn.ZeroPad2d(int(pad_image))

    alpha = opt.alpha

    fixed_noise = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy], device=opt.device)
    z_opt = torch.zeros(fixed_noise.shape, device=opt.device)  # float zeros; torch.full(..., 0) yields a long tensor on newer PyTorch
    z_opt = m_noise(z_opt)

    l1_loss = nn.L1Loss()
    zero_mask_tensor = torch.zeros([1,opt.nzx,opt.nzy], device=opt.device)

    # setup optimizer
    optimizerD = optim.Adam(netD.parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999), weight_decay=opt.weight_decay_d)
    optimizerD_masked1 = optim.Adam(netD_mask1.parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999), weight_decay=opt.weight_decay_d_mask1)
    optimizerD_masked2 = optim.Adam(netD_mask2.parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999), weight_decay=opt.weight_decay_d_mask2)
    optimizerG = optim.Adam(netG.parameters(), lr=opt.lr_g, betas=(opt.beta1, 0.999))

    if opt.cyclic_lr:
        schedulerD = torch.optim.lr_scheduler.CyclicLR(optimizer=optimizerD, base_lr=opt.lr_d*opt.gamma, max_lr=opt.lr_d,
                                                       step_size_up=opt.niter/10, mode="triangular2",
                                                       cycle_momentum=False)
        schedulerD_masked1 = torch.optim.lr_scheduler.CyclicLR(optimizer=optimizerD_masked1, base_lr=opt.lr_d*opt.gamma, max_lr=opt.lr_d,
                                                               step_size_up=opt.niter/10, mode="triangular2",
                                                               cycle_momentum=False)
        schedulerD_masked2 = torch.optim.lr_scheduler.CyclicLR(optimizer=optimizerD_masked2, base_lr=opt.lr_d*opt.gamma, max_lr=opt.lr_d,
                                                                step_size_up=opt.niter/10, mode="triangular2",
                                                                cycle_momentum=False)
        schedulerG = torch.optim.lr_scheduler.CyclicLR(optimizer=optimizerG, base_lr=opt.lr_d*opt.gamma, max_lr=opt.lr_d,
                                                       step_size_up=opt.niter/10, mode="triangular2",
                                                       cycle_momentum=False)
    else:
        schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD,milestones=[1600],gamma=opt.gamma)
        schedulerD_masked1 = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD_masked1, milestones=[1600], gamma=opt.gamma)
        schedulerD_masked2 = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD_masked2, milestones=[1600], gamma=opt.gamma)
        schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG,milestones=[1600],gamma=opt.gamma)

    discriminators = [netD, netD_mask1, netD_mask2]
    discriminators_optimizers = [optimizerD, optimizerD_masked1, optimizerD_masked2]
    discriminators_schedulers = [schedulerD, schedulerD_masked1, schedulerD_masked2]

    err_D_img1_2plot = []
    err_D_img2_2plot = []
    err_D_mask1_2plot = []
    err_D_mask2_2plot = []
    errG_total_loss_2plot = []
    errG_total_loss1_2plot = []
    errG_total_loss2_2plot = []
    errG_fake1_2plot = []
    errG_fake2_2plot = []
    D1_real2plot = []
    D2_real2plot = []
    D1_fake2plot = []
    D2_fake2plot = []
    l1_mask_loss2plot = []
    mask_loss2plot = []
    reconstruction_loss1_2plot = []
    reconstruction_loss2_2plot = []

    for epoch in range(opt.niter):
        """
        We want to ensure that there exists a specific set of input noise maps, which generates the original image x.
        We specifically choose {z*, 0, 0, ..., 0}, where z* is some fixed noise map.
        In the first scale, we create that z* (aka z_opt). On other scales, z_opt is just zeros (initialized above)
        """
        is_first_scale = len(Gs) == 0

        noise1_, z_opt1 = _create_noise_for_iteration(is_first_scale, m_noise, opt, z_opt, NoiseMode.Z1)
        noise2_, z_opt2 = _create_noise_for_iteration(is_first_scale, m_noise, opt, z_opt, NoiseMode.Z2)

        ############################
        # (1) Update D networks:
        # - netD: train with real on 2 input images (real1, real2)
        # - netD: train with fake on 2 fake images from different noise sources (NoiseMode.Z1, NoiseMode.Z2)
        # if opt.enable_mask is ON, then:
        # - netD_mask1: train with real on real1
        # - netD_mask2: train with real on real2
        # - netD_mask1: train with fake on the generated fake image with mask1 applied on it
        # - netD_mask2: train with fake on the generated fake image with mask2 applied on it
        ###########################
        for j in range(opt.Dsteps):
            # train with real
            for discriminator in discriminators:
                discriminator.zero_grad()

            errD_real1, D_x1 = discriminator_train_with_real(netD, opt, real1)
            errD_real2, D_x2 = discriminator_train_with_real(netD, opt, real2)

            # single discriminator for each image
            if opt.enable_mask:
                errD_mask1_real1, _ = discriminator_train_with_real(netD_mask1, opt, real1)
                errD_mask2_real2, _ = discriminator_train_with_real(netD_mask2, opt, real2)

            # train with fake
            in_s1, noise1, prev1, new_z_prev1 = _prepare_discriminator_train_with_fake_input(Gs, NoiseAmp, Zs, epoch,
                                                                                             in_s1, is_first_scale, j,
                                                                                             m_image, m_noise, noise1_,
                                                                                             opt, real1, reals1,
                                                                                             NoiseMode.Z1)
            in_s2, noise2, prev2, new_z_prev2 = _prepare_discriminator_train_with_fake_input(Gs, NoiseAmp, Zs, epoch,
                                                                                             in_s2, is_first_scale, j,
                                                                                             m_image, m_noise, noise2_,
                                                                                             opt, real2, reals2,
                                                                                             NoiseMode.Z2)
            if new_z_prev1 is not None:
                z_prev1 = new_z_prev1
            if new_z_prev2 is not None:
                z_prev2 = new_z_prev2

            # Z1 only:
            mixed_noise1 = functions.merge_noise_vectors(noise1, torch.zeros(noise1.shape, device=opt.device),
                                               opt.noise_vectors_merge_method)

            fake1_output = _generate_fake(netG, mixed_noise1, prev1)

            if opt.enable_mask:
                fake1, fake1_mask1, fake1_mask2 = fake1_output
            else:
                fake1 = fake1_output[0]

            D_G_z_1, errD_fake1, gradient_penalty1 = _train_discriminator_with_fake(netD, fake1, opt, real1)

            # Z2 only:
            mixed_noise2 = functions.merge_noise_vectors(torch.zeros(noise2.shape, device=opt.device), noise2,
                                               opt.noise_vectors_merge_method)

            fake2_output = _generate_fake(netG, mixed_noise2, prev2)
            if opt.enable_mask:
                fake2, fake2_mask1, fake2_mask2 = fake2_output
            else:
                fake2 = fake2_output[0]

            D_G_z_2, errD_fake2, gradient_penalty2 = _train_discriminator_with_fake(netD, fake2, opt, real2)

            if opt.enable_mask:
                _, errD_mask1_fake1, _ = _train_discriminator_with_fake(netD_mask1, fake1_mask1, opt, real1)
                _, errD_mask2_fake2, _ = _train_discriminator_with_fake(netD_mask2, fake2_mask2, opt, real2)

            errD_image1 = errD_real1 + errD_fake1 + gradient_penalty1
            errD_image2 = errD_real2 + errD_fake2 + gradient_penalty2

            for discriminator_optimizer in discriminators_optimizers:
                discriminator_optimizer.step()

        err_D_img1_2plot.append(errD_image1.detach())
        err_D_img2_2plot.append(errD_image2.detach())

        ############################
        # (2) Update G network:
        # - netG: train with fake on 2 fake images from different noise sources (NoiseMode.Z1, NoiseMode.Z2) against netD
        # - netG: reconstruction loss against 2 real images
        # if opt.enable_mask is ON, then:
        # - netD_mask1: train with fake on the generated fake image with mask1 applied on it against netD_mask1
        # - netD_mask2: train with fake on the generated fake image with mask2 applied on it against netD_mask2
        ###########################

        for j in range(opt.Gsteps):
            netG.zero_grad()
            errG_fake1, D_fake1_map = _generator_train_with_fake(fake1, netD)
            errG_fake2, D_fake2_map = _generator_train_with_fake(fake2, netD)
            rec_loss1, Z_opt1 = _reconstruction_loss(alpha, netG, opt, z_opt1, z_prev1, real1, NoiseMode.Z1, opt.noise_amp1)
            rec_loss2, Z_opt2 = _reconstruction_loss(alpha, netG, opt, z_opt2, z_prev2, real2, NoiseMode.Z2, opt.noise_amp2)

            if opt.enable_mask:
                mask_loss_fake1_mask1, D_mask1_fake1_mask1_map = _generator_train_with_fake(fake1_mask1, netD_mask1)
                mask_loss_fake2_mask1, D_mask1_fake2_mask1_map = _generator_train_with_fake(fake2_mask1, netD_mask1)
                mask_loss_fake1_mask2, D_mask2_fake1_mask2_map = _generator_train_with_fake(fake1_mask2, netD_mask2)
                mask_loss_fake2_mask2, D_mask2_fake2_mask2_map = _generator_train_with_fake(fake2_mask2, netD_mask2)

            optimizerG.step()

        errG_total_loss1_2plot.append(errG_fake1.detach()+rec_loss1)
        errG_total_loss2_2plot.append(errG_fake2.detach()+rec_loss2)
        errG_fake1_2plot.append(errG_fake1.detach())
        errG_fake2_2plot.append(errG_fake2.detach())
        G_total_loss = errG_fake1.detach()+rec_loss1 + errG_fake2.detach()+rec_loss2
        errG_total_loss_2plot.append(G_total_loss)
        D1_real2plot.append(D_x1)
        D2_real2plot.append(D_x2)
        D1_fake2plot.append(D_G_z_1)
        D2_fake2plot.append(D_G_z_2)
        reconstruction_loss1_2plot.append(rec_loss1)
        reconstruction_loss2_2plot.append(rec_loss2)

        if epoch % 25 == 0 or epoch == (opt.niter-1):
            logger.info('scale %d:[%d/%d]' % (len(Gs), epoch, opt.niter))

        if epoch % 500 == 0 or epoch == (opt.niter-1):
            plt.imsave('%s/fake_sample1.png' %  (opt.outf), functions.convert_image_np(fake1.detach()), vmin=0, vmax=1)
            plt.imsave('%s/fake_sample2.png' % (opt.outf), functions.convert_image_np(fake2.detach()), vmin=0, vmax=1)
            plt.imsave('%s/G(z_opt1).png'    % (opt.outf),
                       functions.convert_image_np(netG(Z_opt1.detach(), z_prev1)[0].detach()), vmin=0, vmax=1)
            plt.imsave('%s/G(z_opt2).png' % (opt.outf),
                       functions.convert_image_np(netG(Z_opt2.detach(), z_prev2)[0].detach()), vmin=0, vmax=1)

            # torch.save((mixed_noise1, prev1), '%s/fake1_noise_source.pth' % (opt.outf))
            # torch.save((mixed_noise2, prev2), '%s/fake2_noise_source.pth' % (opt.outf))
            # torch.save((Z_opt1, z_prev1), '%s/G(z_opt1)_noise_source.pth' % (opt.outf))
            # torch.save((Z_opt2, z_prev2), '%s/G(z_opt2)_noise_source.pth' % (opt.outf))
            torch.save(z_opt1, '%s/z_opt1.pth' % (opt.outf))
            torch.save(z_opt2, '%s/z_opt2.pth' % (opt.outf))
            if epoch == (opt.niter-1) and opt.enable_mask:
                plt.imsave('%s/fake1_mask1.png' % (opt.outf), functions.convert_image_np(fake1_mask1.detach()), vmin=0,
                           vmax=1)
                plt.imsave('%s/fake2_mask1.png' % (opt.outf), functions.convert_image_np(fake2_mask1.detach()), vmin=0,
                           vmax=1)
                plt.imsave('%s/fake1_mask2.png' % (opt.outf), functions.convert_image_np(fake1_mask2.detach()), vmin=0,
                           vmax=1)
                plt.imsave('%s/fake2_mask2.png' % (opt.outf), functions.convert_image_np(fake2_mask2.detach()), vmin=0,
                           vmax=1)
                _imsave_discriminator_map(D_fake1_map, "D_fake1_map", opt)
                _imsave_discriminator_map(D_fake2_map, "D_fake2_map", opt)
                _imsave_discriminator_map(D_mask1_fake1_mask1_map, "D_mask1_fake1_mask1_map", opt)
                _imsave_discriminator_map(D_mask1_fake2_mask1_map, "D_mask1_fake2_mask1_map", opt)
                _imsave_discriminator_map(D_mask2_fake1_mask2_map, "D_mask2_fake1_mask2_map", opt)
                _imsave_discriminator_map(D_mask2_fake2_mask2_map, "D_mask2_fake2_mask2_map", opt)


        for discriminator_scheduler in discriminators_schedulers:
            discriminator_scheduler.step()
        schedulerG.step()

    functions.save_networks(netG, netD, netD_mask1, netD_mask2, z_opt1, z_opt2, opt)

    functions.plot_learning_curves("G_loss", opt.niter, [errG_total_loss_2plot,
                                                         errG_total_loss1_2plot, errG_total_loss2_2plot,
                                                         errG_fake1_2plot, errG_fake2_2plot,
                                                         reconstruction_loss1_2plot,
                                                         reconstruction_loss2_2plot],
                                   ["G_total_loss", "G_total_loss1", "G_total_loss2",
                                    "G_fake1_loss", "G_fake2_loss",
                                    "G_recon_loss_1", "G_recon_loss_2"], opt.outf)
    d_plots = [err_D_img1_2plot, err_D_img2_2plot]
    d_labels = ["D1_total_loss", "D2_total_loss"]

    functions.plot_learning_curves("D_loss", opt.niter, d_plots, d_labels, opt.outf)
    functions.plot_learning_curves("G_vs_D_loss", opt.niter,
                                   [errG_total_loss_2plot, errG_total_loss1_2plot, errG_total_loss2_2plot,
                                    err_D_img1_2plot, err_D_img2_2plot],
                                   ["G_total_loss", "G_total_loss1", "G_total_loss2", "D1_total_loss", "D2_total_loss"],
                                   opt.outf)
    return (z_opt1, z_opt2), (in_s1, in_s2), netG
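The discriminator helpers called in train_single_scale are not shown in this example; the following is a plausible sketch of what discriminator_train_with_real and _train_discriminator_with_fake might wrap, modeled on the WGAN-GP update of the original SinGAN training loop (the functions.calc_gradient_penalty call and opt.lambda_grad are assumptions, not taken from this snippet):

def discriminator_train_with_real(netD, opt, real):
    # maximize D(real) by minimizing -D(real)
    output = netD(real)
    errD_real = -output.mean()
    errD_real.backward(retain_graph=True)
    D_x = -errD_real.item()
    return errD_real, D_x

def _train_discriminator_with_fake(netD, fake, opt, real):
    # minimize D(fake) and keep D approximately 1-Lipschitz via a gradient penalty
    output = netD(fake.detach())
    errD_fake = output.mean()
    errD_fake.backward(retain_graph=True)
    D_G_z = output.mean().item()
    gradient_penalty = functions.calc_gradient_penalty(netD, real, fake, opt.lambda_grad, opt.device)
    gradient_penalty.backward()
    return D_G_z, errD_fake, gradient_penalty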
Example #3
def SinGAN_anchor_generate(Gs,
                           Zs,
                           reals,
                           NoiseAmp,
                           opt,
                           in_s=None,
                           scale_v=1,
                           scale_h=1,
                           n=0,
                           gen_start_scale=0,
                           num_samples=1,
                           anchor_image=None,
                           direction=None,
                           transfer=None,
                           noise_solutions=None,
                           factor=None,
                           base=None,
                           insert_limit=0):

    #### Loading in Anchor if Needed #####
    anchor = anchor_image
    if anchor is not None:
        anchors = []
        anchor = functions.np2torch(anchor_image, opt)
        anchor_ = imresize(anchor, opt.scale1, opt)
        anchors = functions.creat_reals_pyramid(anchor_, anchors,
                                                opt)  #high key hacky code
    if direction is not None:
        directions = []
        direction = functions.np2torch(direction, opt)
        direction_ = imresize(direction, opt.scale1, opt)
        directions = functions.creat_reals_pyramid(direction_, directions,
                                                   opt)  #high key hacky code
    if base is not None:
        bases = []
        base = functions.np2torch(base, opt)
        base_ = imresize(base, opt.scale1, opt)
        bases = functions.creat_reals_pyramid(base_, bases,
                                              opt)  #high key hacky code
    #### MY CODE ####

    #if torch.is_tensor(in_s) == False:
    if in_s is None:
        in_s = torch.zeros(reals[0].shape, device=opt.device)
    images_cur = []
    for G, Z_opt, noise_amp in zip(Gs, Zs, NoiseAmp):
        pad1 = ((opt.ker_size - 1) * opt.num_layer) / 2
        m = nn.ZeroPad2d(int(pad1))
        nzx = (Z_opt.shape[2] - pad1 * 2) * scale_v
        nzy = (Z_opt.shape[3] - pad1 * 2) * scale_h

        images_prev = images_cur
        images_cur = []

        for i in range(0, num_samples, 1):
            if n == 0:  #COARSEST SCALE
                z_curr = functions.generate_noise([1, nzx, nzy],
                                                  device=opt.device)
                z_curr = z_curr.expand(1, 3, z_curr.shape[2], z_curr.shape[3])
                z_curr = m(z_curr)
            else:
                z_curr = functions.generate_noise([opt.nc_z, nzx, nzy],
                                                  device=opt.device)
                z_curr = m(z_curr)

            z_orig = z_curr

            if images_prev == []:  #FIRST GENERATION IN COARSEST SCALE
                I_prev = m(in_s)

            else:  #NOT FIRST GENERATION, BUT AT COARSEST SCALE
                I_prev = images_prev[i]
                I_prev = imresize(I_prev, 1 / opt.scale_factor, opt)  #upscale
                #print(n)
                if opt.mode != "SR":
                    I_prev = I_prev[:, :, 0:round(scale_v * reals[n].shape[2]),
                                    0:round(scale_h * reals[n].shape[3])]
                    I_prev = m(I_prev)
                    I_prev = I_prev[:, :, 0:z_curr.shape[2], 0:z_curr.shape[3]]
                    I_prev = functions.upsampling(
                        I_prev, z_curr.shape[2],
                        z_curr.shape[3])  #make it fit padded noise
                else:
                    #prev_before = I_prev #MY ADDITION
                    I_prev = m(I_prev)

            if n < gen_start_scale:  #anything less than final
                z_curr = Z_opt  #Z_opt comes from trained pyramid....
            z_in = noise_amp * (z_curr) + I_prev

            if noise_solutions is not None:
                z_curr = noise_solutions[n]

                z_in = ((1 - factor) * noise_amp * z_curr + I_prev
                        + factor * noise_amp * z_orig)  # add the previous image to a blend of the optimized noise and the original draw

            I_curr = G(z_in.detach(), I_prev)
            if base is not None:
                if n == insert_limit:
                    I_curr = bases[n] * factor + I_curr * (1 - factor)

            if anchor is not None and direction is not None:
                anchor_curr = anchors[n]
                I_curr = reinforcement(anchor_curr, I_curr, directions[n])
                #I_curr = reinforcement_sigmoid(anchor_curr, I_curr, direction, n)
            ###### ENFORCE LH = ANCHOR FOR IMAGE #######

            if n == opt.stop_scale:  #hacky code
                if anchor is not None and direction is not None:
                    anchor_curr = anchors[n]
                    I_curr = reinforcement(anchor_curr, I_curr, direction)
                    #I_curr = reinforcement_sigmoid(anchor_curr, I_curr, direction, n)
                array = functions.convert_image_np(I_curr.detach())
            images_cur.append(I_curr)
        n += 1
    return array
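A hedged usage sketch for SinGAN_anchor_generate, assuming opt has been configured and a pyramid trained and saved as in the surrounding examples (the image paths are placeholders):

import matplotlib.image as img
import numpy as np

Gs, Zs, reals, NoiseAmp = functions.load_trained_pyramid(opt)
anchor_np = img.imread('Input/Images/anchor.png').astype(np.float32)        # placeholder path
direction_np = img.imread('Input/Images/direction.png').astype(np.float32)  # placeholder path
sample = SinGAN_anchor_generate(Gs, Zs, reals, NoiseAmp, opt,
                                gen_start_scale=0, num_samples=1,
                                anchor_image=anchor_np, direction=direction_np)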
Example #4
def invert_model(test_image,
                 model_name,
                 scales2invert=None,
                 penalty=1e-3,
                 show=True):
    '''test_image is an array, model_name is a name'''
    Noise_Solutions = []

    parser = get_arguments()
    parser.add_argument('--input_dir',
                        help='input image dir',
                        default='Input/Images')

    parser.add_argument('--mode', default='RandomSamples')
    opt = parser.parse_args("")
    opt.input_name = model_name
    opt.reg = penalty

    if model_name == 'islands2_basis_2.jpg':  #HARDCODED
        opt.scale_factor = 0.6

    opt = functions.post_config(opt)

    ### Loading in Generators
    Gs, Zs, reals, NoiseAmp = functions.load_trained_pyramid(opt)
    for G in Gs:
        G = functions.reset_grads(G, False)
        G.eval()

    ### Loading in Ground Truth Test Images
    reals = []  #deleting old real images
    real = functions.np2torch(test_image, opt)
    functions.adjust_scales2image(real, opt)

    real_ = functions.np2torch(test_image, opt)
    real = imresize(real_, opt.scale1, opt)
    reals = functions.creat_reals_pyramid(real, reals, opt)

    ### General Padding
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    m_noise = nn.ZeroPad2d(int(pad_noise))

    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    m_image = nn.ZeroPad2d(int(pad_image))

    I_prev = None
    REC_ERROR = 0

    if scales2invert is None:
        scales2invert = opt.stop_scale + 1

    for scale in range(scales2invert):
        #for scale in range(3):

        #Get X, G
        X = reals[scale]
        G = Gs[scale]
        noise_amp = NoiseAmp[scale]

        #Defining Dimensions
        opt.nc_z = X.shape[1]
        opt.nzx = X.shape[2]
        opt.nzy = X.shape[3]

        #getting parameters for prior distribution penalty
        pdf = torch.distributions.Normal(0, 1)
        alpha = opt.reg
        #alpha = 1e-2

        #Defining Z
        if scale == 0:
            z_init = functions.generate_noise(
                [1, opt.nzx, opt.nzy], device=opt.device)  #only 1D noise
        else:
            z_init = functions.generate_noise(
                [3, opt.nzx, opt.nzy],
                device=opt.device)  #otherwise move up to 3d noise

        z_init = z_init.detach().requires_grad_(True)  # tensor to optimize (Variable is deprecated; noise is already on opt.device)

        #Building I_prev
        if I_prev is None:  # first scale scenario
            in_s = torch.zeros(reals[0].shape, device=opt.device)  # all zeros
            I_prev = in_s
            I_prev = m_image(I_prev)  #padding

        else:  #otherwise take the output from the previous scale and upsample
            I_prev = imresize(I_prev, 1 / opt.scale_factor, opt)  #upsamples
            I_prev = m_image(I_prev)
            I_prev = I_prev[:, :, 0:X.shape[2] + 10, 0:X.shape[
                3] + 10]  #making sure that precision errors don't mess anything up
            I_prev = functions.upsampling(I_prev, X.shape[2] + 10, X.shape[3] +
                                          10)  #seems to be redundant

        LR = [2e-3, 2e-2, 2e-1, 2e-1, 2e-1, 2e-1, 2e-1, 2e-1, 2e-1, 2e-1, 2e-1]
        Zoptimizer = torch.optim.RMSprop([z_init],
                                         lr=LR[scale])  #Defining Optimizer
        x_loss = []  #for plotting
        epochs = []  #for plotting

        niter = [
            200, 400, 400, 400, 400, 400, 400, 400, 400, 400, 400, 400, 400,
            400, 400
        ]
        for epoch in range(niter[scale]):  #Gradient Descent on Z

            if scale == 0:
                noise_input = m_noise(z_init.expand(1, 3, opt.nzx, opt.nzy))  # expand and pad
            else:
                noise_input = m_noise(z_init)  #padding

            z_in = noise_amp * noise_input + I_prev
            G_z = G(z_in, I_prev)

            x_recLoss = F.mse_loss(G_z, X)  #MSE loss

            logProb = pdf.log_prob(z_init).mean()  #Gaussian loss

            loss = x_recLoss - (alpha * logProb.mean())

            Zoptimizer.zero_grad()
            loss.backward()
            Zoptimizer.step()

            #losses['rec'].append(x_recLoss.data[0])
            #print('Image loss: [%d] loss: %0.5f' % (epoch, x_recLoss.item()))
            #print('Noise loss: [%d] loss: %0.5f' % (epoch, z_recLoss.item()))
            x_loss.append(loss.item())
            epochs.append(epoch)

            REC_ERROR = x_recLoss

        if show:
            plt.plot(epochs, x_loss, label='x_loss')
            plt.legend()
            plt.show()

        I_prev = G_z.detach()  # take the final output of this scale (author note: this line may need revisiting)

        _ = show_image(X, show, 'target')
        reconstructed_image = show_image(I_prev, show, 'output')
        _ = show_image(noise_input.detach().cpu(), show, 'noise')

        Noise_Solutions.append(noise_input.detach())
    return Noise_Solutions, reconstructed_image, REC_ERROR
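A hedged usage sketch for invert_model; the test image path is a placeholder, while 'islands2_basis_2.jpg' is the model name already referenced inside the function:

import matplotlib.image as img
import numpy as np

test_image = img.imread('Input/Images/test.png').astype(np.float32)  # placeholder path
noise_solutions, reconstruction, rec_error = invert_model(test_image,
                                                          'islands2_basis_2.jpg',
                                                          penalty=1e-3,
                                                          show=False)
print('final reconstruction MSE: %0.5f' % rec_error.item())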
Example #5
            print('*** Train SinGAN for SR ***')
            real = functions.read_image(opt)
            opt.min_size = 18
            real = functions.adjust_scales2image_SR(real, opt)
            train(opt, Gs, Zs, reals, NoiseAmp)
            opt.mode = mode
        print('%f' % pow(in_scale, iter_num))
        Zs_sr = []
        reals_sr = []
        NoiseAmp_sr = []
        Gs_sr = []
        clean_img = img.imread(
            '%s/%s' % (opt.input_dir, opt.input_name)).astype(np.float32)
        noisy_img = clean_img + np.random.normal(0, float(opt.noise),
                                                 clean_img.shape)
        noisy_img = functions.np2torch(noisy_img, opt)
        noisy_img = noisy_img[:, 0:3, :, :]
        clean_img = clean_img[:, :, 0:3] / 255.0

        real_ = noisy_img
        opt.scale_factor = 1 / in_scale
        opt.scale_factor_init = 1 / in_scale
        for j in range(1, iter_num + 1, 1):
            real_ = imresize(real_, pow(1 / opt.scale_factor, 1), opt)
            reals_sr.append(real_)
            Gs_sr.append(Gs[-1])
            NoiseAmp_sr.append(NoiseAmp[-1])
            z_opt = torch.zeros(real_.shape, device=opt.device)
            m = nn.ZeroPad2d(5)
            z_opt = m(z_opt)
            Zs_sr.append(z_opt)
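The fragment above only builds the up-sampled pyramid (Gs_sr, Zs_sr, reals_sr, NoiseAmp_sr); in the original SinGAN super-resolution script those lists are then passed to a sampling routine to produce the high-resolution output. A hedged sketch of that final step, written with the SinGAN_generate call the official repo appears to use (its exact signature is an assumption, as it is not shown in this fragment):

out = SinGAN_generate(Gs_sr, Zs_sr, reals_sr, NoiseAmp_sr, opt,
                      in_s=reals_sr[0], num_samples=1)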