コード例 #1
0
def SinGAN_generate(Gs,Zs,reals,NoiseAmp,opt,in_s=None,scale_v=1,scale_h=1,n=0,gen_start_scale=0,num_samples=50):
    """Draw ``num_samples`` random images by running the generators in ``Gs`` in order.

    For every generator, ``num_samples`` images are produced; each one is built
    on top of the image with the same sample index from the previous generator
    (resized by ``1/opt.scale_factor``). At the first generator ``in_s`` is used
    instead.

    Args:
        Gs: generators, each called as ``G(z_in, I_prev)``.
        Zs: one noise map per generator. Its spatial shape sizes the random
            noise; the map itself replaces the random noise while
            ``n < gen_start_scale``.
        reals: real images per scale; ``reals[0].shape`` sizes the default
            ``in_s`` and ``reals[n]`` bounds the crop of the resized image.
        NoiseAmp: one noise amplitude per generator.
        opt: options object; reads ``device``, ``ker_size``, ``num_layer``,
            ``nc_z``, ``scale_factor``, ``mode``, ``out`` and ``input_name``.
        in_s: image fed to the first generator; all zeros when ``None``.
        scale_v: vertical factor applied to the noise size and the crop.
        scale_h: horizontal factor applied to the noise size and the crop.
        n: scale index of the first generator; incremented once per generator.
            Noise is drawn with one channel and expanded to three when it is 0.
        gen_start_scale: scales below this index use ``Z_opt`` as noise.
        num_samples: number of images generated per scale.

    Returns:
        The detached output of the last generator for the last sample.

    Raises:
        ValueError: if no image was generated (no generators, or
            ``num_samples`` is not positive).
    """
    if in_s is None:
        # torch.zeros always uses the default floating dtype. torch.full with
        # an integer fill value is rejected or returns an integer tensor on
        # newer PyTorch releases.
        in_s = torch.zeros(reals[0].shape, device=opt.device)

    # The padding only depends on the options, so it is built once.
    pad1 = ((opt.ker_size-1)*opt.num_layer)/2
    m = nn.ZeroPad2d(int(pad1))
    save_samples = opt.mode not in ("harmonization", "editing", "SR", "paint2image")

    I_curr = None
    images_cur = []
    for G,Z_opt,noise_amp in zip(Gs,Zs,NoiseAmp):
        # Size of the unpadded noise map, stretched by the requested factors.
        nzx = (Z_opt.shape[2]-pad1*2)*scale_v
        nzy = (Z_opt.shape[3]-pad1*2)*scale_h

        images_prev = images_cur
        images_cur = []

        for i in range(num_samples):
            if n == 0:
                # One channel is drawn and shared by the three output channels.
                z_curr = functions.generate_noise([1,nzx,nzy], device=opt.device)
                z_curr = z_curr.expand(1,3,z_curr.shape[2],z_curr.shape[3])
            else:
                z_curr = functions.generate_noise([opt.nc_z,nzx,nzy], device=opt.device)
            z_curr = m(z_curr)

            if not images_prev:
                # No previous scale was generated in this call.
                I_prev = m(in_s)
            else:
                I_prev = images_prev[i]
                I_prev = imresize(I_prev,1/opt.scale_factor, opt)
                if opt.mode != "SR":
                    I_prev = I_prev[:, :, 0:round(scale_v * reals[n].shape[2]), 0:round(scale_h * reals[n].shape[3])]
                    I_prev = m(I_prev)
                    # Make the padded image match the padded noise exactly.
                    I_prev = I_prev[:,:,0:z_curr.shape[2],0:z_curr.shape[3]]
                    I_prev = functions.upsampling(I_prev,z_curr.shape[2],z_curr.shape[3])
                else:
                    I_prev = m(I_prev)

            if n < gen_start_scale:
                z_curr = Z_opt

            z_in = noise_amp*(z_curr)+I_prev
            I_curr = G(z_in.detach(),I_prev)

            if opt.mode == 'train':
                dir2save = '%s/RandomSamples/%s/gen_start_scale=%d' % (opt.out, opt.input_name[:-4], gen_start_scale)
            else:
                dir2save = functions.generate_dir2save(opt)
            # An existing directory is fine; any other failure is reported
            # here instead of being swallowed.
            os.makedirs(dir2save, exist_ok=True)
            if save_samples:
                # The file name only depends on the sample index, so every
                # scale overwrites the file written by the previous one.
                plt.imsave('%s/%d.png' % (dir2save, i), functions.convert_image_np(I_curr.detach()), vmin=0,vmax=1)
            images_cur.append(I_curr)
        n+=1

    if I_curr is None:
        raise ValueError("SinGAN_generate produced no image: Gs, Zs and NoiseAmp "
                         "must be non-empty and num_samples must be positive")
    return I_curr.detach()
コード例 #2
0
def draw_concat(Gs,Zs,reals,NoiseAmp,in_s,mode,m_noise,m_image,opt):
    """Push ``in_s`` through every generator in ``Gs`` and return the result.

    With an empty ``Gs`` the input ``in_s`` is returned unchanged. In mode
    ``'rand'`` each generator receives freshly drawn noise, in mode ``'rec'``
    it receives the stored map from ``Zs``. After each generator the output is
    resized by ``1/opt.scale_factor`` and cropped to the spatial size of the
    next real image. Any other mode leaves ``in_s`` untouched.
    """
    prev_out = in_s
    if len(Gs) <= 0:
        return prev_out

    if mode == 'rand':
        border = int(((opt.ker_size-1)*opt.num_layer)/2)
        if opt.mode == 'animation_train':
            border = 0
        pyramid = zip(Gs,Zs,reals,reals[1:],NoiseAmp)
        for level,(G,Z_opt,real_curr,real_next,noise_amp) in enumerate(pyramid):
            height = Z_opt.shape[2] - 2 * border
            width = Z_opt.shape[3] - 2 * border
            if level == 0:
                # One channel is drawn and expanded to three channels.
                noise = functions.generate_noise([1, height, width], device=opt.device)
                noise = noise.expand(1, 3, noise.shape[2], noise.shape[3])
            else:
                noise = functions.generate_noise([opt.nc_z, height, width], device=opt.device)
            noise = m_noise(noise)
            # Crop to the current real image, then pad for the generator.
            prev_out = m_image(prev_out[:,:,0:real_curr.shape[2],0:real_curr.shape[3]])
            gen_in = noise_amp*noise+prev_out
            prev_out = imresize(G(gen_in.detach(),prev_out),1/opt.scale_factor,opt)
            prev_out = prev_out[:,:,0:real_next.shape[2],0:real_next.shape[3]]

    if mode == 'rec':
        for G,Z_opt,real_curr,real_next,noise_amp in zip(Gs,Zs,reals,reals[1:],NoiseAmp):
            prev_out = m_image(prev_out[:, :, 0:real_curr.shape[2], 0:real_curr.shape[3]])
            # The stored map is used in place of random noise.
            gen_in = noise_amp*Z_opt+prev_out
            prev_out = imresize(G(gen_in.detach(),prev_out),1/opt.scale_factor,opt)
            prev_out = prev_out[:,:,0:real_next.shape[2],0:real_next.shape[3]]

    return prev_out
コード例 #3
0
def draw_concat(Gs,Zs,reals,NoiseAmp,in_s,mode,m_noise,m_image,opt):
    """TensorFlow variant: push ``in_s`` through every generator in ``Gs``.

    Tensors are sliced on axes 1 and 2, with the channel axis last. With an
    empty ``Gs`` the input ``in_s`` is returned unchanged. Mode ``'rand'``
    draws new noise per generator, mode ``'rec'`` uses the stored maps in
    ``Zs``. Any other mode leaves ``in_s`` untouched.
    """
    current = in_s
    if len(Gs) <= 0:
        return current

    if mode == 'rand':
        margin = int(((opt.ker_size-1)*opt.num_layer)/2)
        if opt.mode == 'animation_train':
            margin = 0
        levels = zip(Gs,Zs,reals,reals[1:],NoiseAmp)
        for step,(G,Z_opt,real_curr,real_next,noise_amp) in enumerate(levels):
            rows = Z_opt.shape[1] - 2 * margin
            cols = Z_opt.shape[2] - 2 * margin
            if step == 0:
                # Drawn once, then broadcast to three trailing channels.
                noise = functions.generate_noise([1, rows, cols])
                noise = tf.broadcast_to(noise, [1, noise.shape[1], noise.shape[2], 3])
            else:
                noise = functions.generate_noise([opt.nc_z, rows, cols])

            noise = m_noise(noise)
            current = m_image(current[:,0:real_curr.shape[1],0:real_curr.shape[2],:])
            gen_in = noise_amp*noise+current
            current = imresize(G(gen_in,current, training=True),1/opt.scale_factor,opt)
            current = current[:,0:real_next.shape[1],0:real_next.shape[2], :]

    if mode == 'rec':
        for G,Z_opt,real_curr,real_next,noise_amp in zip(Gs,Zs,reals,reals[1:],NoiseAmp):
            current = m_image(current[:, 0:real_curr.shape[1], 0:real_curr.shape[2], :])
            gen_in = noise_amp*Z_opt+current
            current = imresize(G(gen_in,current, training=True),1/opt.scale_factor,opt)
            current = current[:,0:real_next.shape[1],0:real_next.shape[2], :]

    return current
コード例 #4
0
ファイル: training.py プロジェクト: codeconomics/SinGAN
def draw_concat(Gs, Zs, reals, NoiseAmp, in_s, mode, m_noise, m_image, opt):
    """Chain the generators in ``Gs`` starting from ``in_s``.

    Returns ``in_s`` itself when ``Gs`` is empty or ``mode`` is neither
    ``'rand'`` nor ``'rec'``. Otherwise each generator output is resized by
    ``1 / opt.scale_factor`` and cropped to the next real image before it is
    handed to the following generator.
    """
    image = in_s
    # Nothing to chain at the first scale.
    if len(Gs) <= 0:
        return image

    if mode == 'rand':
        pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
        first_scale = True
        for G, Z_opt, real_curr, real_next, noise_amp in zip(
                Gs, Zs, reals, reals[1:], NoiseAmp):
            spatial = [
                Z_opt.shape[2] - 2 * pad_noise,
                Z_opt.shape[3] - 2 * pad_noise
            ]
            if first_scale:
                # Single-channel draw, expanded to three channels.
                z = functions.generate_noise([1] + spatial,
                                             device=opt.device)
                z = z.expand(1, 3, z.shape[2], z.shape[3])
                first_scale = False
            else:
                z = functions.generate_noise([opt.nc_z] + spatial,
                                             device=opt.device)
            z = m_noise(z)
            # Crop to the current real image, pad, and add the scaled noise.
            image = m_image(
                image[:, :, 0:real_curr.shape[2], 0:real_curr.shape[3]])
            z_in = noise_amp * z + image
            generated = G(z_in.detach(), image)
            # Resize, then crop to the next real image.
            image = imresize(generated, 1 / opt.scale_factor, opt)
            image = image[:, :, 0:real_next.shape[2], 0:real_next.shape[3]]

    if mode == 'rec':
        for G, Z_opt, real_curr, real_next, noise_amp in zip(
                Gs, Zs, reals, reals[1:], NoiseAmp):
            image = m_image(
                image[:, :, 0:real_curr.shape[2], 0:real_curr.shape[3]])
            # Same chain as above, but with the stored Z_opt as noise.
            z_in = noise_amp * Z_opt + image
            generated = G(z_in.detach(), image)
            image = imresize(generated, 1 / opt.scale_factor, opt)
            image = image[:, :, 0:real_next.shape[2], 0:real_next.shape[3]]

    return image
コード例 #5
0
def draw_concat(Gs, Zs, reals, NoiseAmp, in_s, mode, m_noise, m_image, opt):
    """Generate through all generators in ``Gs``, starting from ``in_s``.

    ``in_s`` is returned as is when ``Gs`` is empty (nothing to generate
    from) or when ``mode`` is neither ``'rand'`` nor ``'rec'``.
    """
    result = in_s
    if len(Gs) <= 0:
        return result

    if mode == 'rand':  # new random noise at every scale
        pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
        if opt.mode == 'animation_train':
            pad_noise = 0
        for idx, (G, Z_opt, real_curr, real_next, noise_amp) in enumerate(
                zip(Gs, Zs, reals, reals[1:], NoiseAmp)):
            is_first = idx == 0
            # Z_opt only contributes its size here.
            channels = 1 if is_first else opt.nc_z
            z = functions.generate_noise([
                channels, Z_opt.shape[2] - 2 * pad_noise,
                Z_opt.shape[3] - 2 * pad_noise
            ],
                                         device=opt.device)
            if is_first:
                # same values along the three channels
                z = z.expand(1, 3, z.shape[2], z.shape[3])
            z = m_noise(z)
            cropped = result[:, :, 0:real_curr.shape[2], 0:real_curr.shape[3]]
            padded = m_image(cropped)
            z_in = noise_amp * z + padded
            result = G(z_in.detach(), padded)
            result = imresize(result, 1 / opt.scale_factor, opt)
            result = result[:, :, 0:real_next.shape[2], 0:real_next.shape[3]]

    if mode == 'rec':  # stored reconstruction maps Z_opt
        for G, Z_opt, real_curr, real_next, noise_amp in zip(
                Gs, Zs, reals, reals[1:], NoiseAmp):
            # crop to the size of the current real image
            cropped = result[:, :, 0:real_curr.shape[2], 0:real_curr.shape[3]]
            padded = m_image(cropped)
            z_in = noise_amp * Z_opt + padded
            result = G(z_in.detach(), padded)
            result = imresize(result, 1 / opt.scale_factor, opt)
            # crop to the size of the next real image
            result = result[:, :, 0:real_next.shape[2], 0:real_next.shape[3]]

    return result
コード例 #6
0
def _create_noise_for_draw_concat(opt, count, pad_noise, m_noise, Z_opt, noise_mode):
    """Draw the noise for one scale of ``draw_concat`` and pad it with ``m_noise``.

    When ``count`` is 0 a single channel is drawn and expanded to three
    channels; otherwise ``opt.nc_z`` channels are drawn. The spatial size is
    that of ``Z_opt`` minus ``2 * pad_noise`` on each axis.
    """
    expand_channels = count == 0
    channels = 1 if expand_channels else opt.nc_z
    spatial = [Z_opt.shape[2] - 2 * pad_noise, Z_opt.shape[3] - 2 * pad_noise]
    noise = functions.generate_noise([channels] + spatial,
                                     device=opt.device, noise_mode=noise_mode,
                                     gaussian_noise_z_distance=opt.gaussian_noise_z_distance)
    if expand_channels:
        noise = noise.expand(1, 3, noise.shape[2], noise.shape[3])
    return m_noise(noise)
コード例 #7
0
def compute_z_diff(n, Z_opt, z_prev1, z_prev2, beta, device):
    """Blend the previous noise difference with fresh noise.

    Returns ``beta * (z_prev1 - z_prev2) + (1 - beta) * noise``, where
    ``noise`` has the spatial size of ``Z_opt``. For ``n == 0`` one channel is
    drawn and expanded to three; otherwise three channels are drawn directly.
    """
    height, width = Z_opt.shape[2], Z_opt.shape[3]
    noise_channels = 3
    if n == 0:
        # identical values across the three channels
        fresh = functions.generate_noise([1, height, width], device=device)
        fresh = fresh.expand(1, 3, height, width)
    else:
        fresh = functions.generate_noise([noise_channels, height, width],
                                         device=device)
    carried = beta * (z_prev1 - z_prev2)
    return carried + (1 - beta) * fresh
コード例 #8
0
def _create_noise_for_iteration(is_first_scale, m_noise, opt, default_z_opt, noise_mode):
    """Return ``(noise_, z_opt)`` for one training iteration.

    At the first scale both values are new single-channel draws expanded to
    three channels and padded with ``m_noise``. At other scales ``z_opt`` is
    ``default_z_opt`` and only ``noise_`` is drawn, with ``opt.nc_z`` channels.
    """
    def draw(channels):
        return functions.generate_noise([channels, opt.nzx, opt.nzy], device=opt.device, noise_mode=noise_mode,
                                        gaussian_noise_z_distance=opt.gaussian_noise_z_distance)

    if not is_first_scale:
        return m_noise(draw(opt.nc_z)), default_z_opt

    # z_opt is drawn before noise_ so the random stream is consumed in the
    # same order as before.
    fixed = m_noise(draw(1).expand(1, 3, opt.nzx, opt.nzy))
    random_noise = m_noise(draw(1).expand(1, 3, opt.nzx, opt.nzy))
    return random_noise, fixed
コード例 #9
0
ファイル: manipulate.py プロジェクト: lior1990/SinGAN
def _generate_noise_for_sampling(m, n, nzx, nzy, opt, noise_mode):
    """Draw the noise for scale ``n`` of size ``nzx`` x ``nzy`` and pad it with ``m``.

    Scale 0 draws one channel and expands it to three channels; every other
    scale draws ``opt.nc_z`` channels.
    """
    coarsest = n == 0
    channels = 1 if coarsest else opt.nc_z
    noise = functions.generate_noise(
        [channels, nzx, nzy],
        device=opt.device,
        noise_mode=noise_mode,
        gaussian_noise_z_distance=opt.gaussian_noise_z_distance)
    if coarsest:
        noise = noise.expand(1, 3, noise.shape[2], noise.shape[3])
    return m(noise)
コード例 #10
0
def compute_z_prev(n, Z_opt, device):
    """
    Build the two previous noise maps for scale ``n``.

    :param:
        n -- int, scale level (0 selects the single-channel noise draw)
        Z_opt -- noise map of the n-th scale; also provides the spatial size
        device -- device passed on to the noise generator

    :return:
        (z_prev1, z_prev2) where z_prev1 is ``0.95 * Z_opt + 0.05 * noise``
        and z_prev2 is ``Z_opt`` itself
    """
    height, width = Z_opt.shape[2], Z_opt.shape[3]
    # number of noise channels drawn at scales other than 0
    noise_channels = 3
    if n == 0:
        # one channel, expanded to three
        jitter = functions.generate_noise([1, height, width], device=device)
        jitter = jitter.expand(1, 3, height, width)
    else:
        jitter = functions.generate_noise([noise_channels, height, width],
                                          device=device)
    return 0.95 * Z_opt + 0.05 * jitter, Z_opt
コード例 #11
0
ファイル: anchoring.py プロジェクト: tboen1/MESIGAN
def SinGAN_anchor_generate(Gs,
                           Zs,
                           reals,
                           NoiseAmp,
                           opt,
                           in_s=None,
                           scale_v=1,
                           scale_h=1,
                           n=0,
                           gen_start_scale=0,
                           num_samples=1,
                           anchor_image=None,
                           direction=None,
                           transfer=None,
                           noise_solutions=None,
                           factor=None,
                           base=None,
                           insert_limit=0):
    """Run the generator pyramid with optional anchor, direction and base guidance.

    Args:
        Gs: generators, each called as ``G(z_in, I_prev)``.
        Zs: one noise map per generator; its spatial shape sizes the random
            noise and the map replaces the noise while ``n < gen_start_scale``.
        reals: real images per scale; ``reals[0].shape`` sizes the default
            ``in_s`` and ``reals[n]`` bounds the crop of the resized image.
        NoiseAmp: one noise amplitude per generator.
        opt: options object; reads ``device``, ``ker_size``, ``num_layer``,
            ``nc_z``, ``scale_factor``, ``scale1``, ``mode`` and ``stop_scale``.
        in_s: image fed to the first generator; all zeros when ``None``.
        scale_v: vertical factor applied to the noise size and the crop.
        scale_h: horizontal factor applied to the noise size and the crop.
        n: scale index of the first generator; incremented once per generator.
        gen_start_scale: scales below this index use ``Z_opt`` as noise.
        num_samples: number of images generated per scale.
        anchor_image: optional image; converted with ``functions.np2torch``,
            resized by ``opt.scale1`` and turned into a per-scale pyramid.
        direction: optional map, prepared the same way as ``anchor_image``.
            ``reinforcement`` is only applied when both are given.
        transfer: not used by this function.
        noise_solutions: optional per-scale noise; when given, the generator
            input mixes it with the random noise using ``factor``.
        factor: mixing weight; must be a number whenever ``noise_solutions``
            is given or ``base`` is blended in.
        base: optional image blended into the output at scale ``insert_limit``.
        insert_limit: scale index at which ``base`` is blended in.

    Returns:
        The result of ``functions.convert_image_np`` for the last sample
        generated at scale ``opt.stop_scale``.

    Raises:
        ValueError: if no sample was generated at scale ``opt.stop_scale``.
    """
    #### Loading in Anchor if Needed #####
    anchor = anchor_image
    if anchor is not None:
        anchor = functions.np2torch(anchor_image, opt)
        anchor_ = imresize(anchor, opt.scale1, opt)
        anchors = functions.creat_reals_pyramid(anchor_, [], opt)
    if direction is not None:
        direction = functions.np2torch(direction, opt)
        direction_ = imresize(direction, opt.scale1, opt)
        directions = functions.creat_reals_pyramid(direction_, [], opt)
    if base is not None:
        base = functions.np2torch(base, opt)
        base_ = imresize(base, opt.scale1, opt)
        bases = functions.creat_reals_pyramid(base_, [], opt)
    guided = anchor is not None and direction is not None

    if in_s is None:
        # torch.zeros always uses the default floating dtype. torch.full with
        # an integer fill value is rejected or returns an integer tensor on
        # newer PyTorch releases.
        in_s = torch.zeros(reals[0].shape, device=opt.device)

    # The padding only depends on the options, so it is built once.
    pad1 = ((opt.ker_size - 1) * opt.num_layer) / 2
    m = nn.ZeroPad2d(int(pad1))

    array = None
    images_cur = []
    for G, Z_opt, noise_amp in zip(Gs, Zs, NoiseAmp):
        nzx = (Z_opt.shape[2] - pad1 * 2) * scale_v
        nzy = (Z_opt.shape[3] - pad1 * 2) * scale_h

        images_prev = images_cur
        images_cur = []

        for i in range(num_samples):
            if n == 0:
                # Scale 0: one channel is drawn and expanded to three.
                z_curr = functions.generate_noise([1, nzx, nzy],
                                                  device=opt.device)
                z_curr = z_curr.expand(1, 3, z_curr.shape[2], z_curr.shape[3])
            else:
                z_curr = functions.generate_noise([opt.nc_z, nzx, nzy],
                                                  device=opt.device)
            z_curr = m(z_curr)

            # Kept so the random draw can still be mixed in after z_curr is
            # replaced below.
            z_orig = z_curr

            if not images_prev:
                # No previous scale was generated in this call.
                I_prev = m(in_s)
            else:
                # Continue from the same sample of the previous scale.
                I_prev = images_prev[i]
                I_prev = imresize(I_prev, 1 / opt.scale_factor, opt)
                if opt.mode != "SR":
                    I_prev = I_prev[:, :, 0:round(scale_v * reals[n].shape[2]),
                                    0:round(scale_h * reals[n].shape[3])]
                    I_prev = m(I_prev)
                    # Make the padded image match the padded noise exactly.
                    I_prev = I_prev[:, :, 0:z_curr.shape[2], 0:z_curr.shape[3]]
                    I_prev = functions.upsampling(I_prev, z_curr.shape[2],
                                                  z_curr.shape[3])
                else:
                    I_prev = m(I_prev)

            if n < gen_start_scale:
                z_curr = Z_opt
            z_in = noise_amp * (z_curr) + I_prev

            if noise_solutions is not None:
                z_curr = noise_solutions[n]
                # Weighted mix of the supplied noise and the random draw.
                z_in = (1 - factor) * noise_amp * (
                    z_curr) + I_prev + factor * noise_amp * z_orig

            I_curr = G(z_in.detach(), I_prev)
            if base is not None and n == insert_limit:
                I_curr = bases[n] * factor + I_curr * (1 - factor)

            if guided:
                I_curr = reinforcement(anchors[n], I_curr, directions[n])

            if n == opt.stop_scale:
                if guided:
                    # NOTE(review): this second call passes the full
                    # `direction` map, not `directions[n]` as above. Left
                    # unchanged; confirm which one is intended.
                    I_curr = reinforcement(anchors[n], I_curr, direction)
                array = functions.convert_image_np(I_curr.detach())
            images_cur.append(I_curr)
        n += 1

    if array is None:
        raise ValueError(
            "SinGAN_anchor_generate produced no image: no sample was "
            "generated at scale opt.stop_scale")
    return array
コード例 #12
0
def SinGAN_generate(Gs,
                    Zs,
                    reals,
                    NoiseAmp,
                    opt,
                    in_s=None,
                    scale_v=1,
                    scale_h=1,
                    n=0,
                    gen_start_scale=0,
                    num_samples=50):
    """Generate random samples against a pyramid rebuilt from a re-read image.

    The image returned by ``functions.read_image(opt)`` is resized to the shape
    of ``reals[-1]`` and turned into a pyramid. Each level is resized to the
    shape of the matching entry of ``reals`` and replaces it. The replaced
    pyramid is written to disk as ``real_<i>.png`` before sampling starts.

    Args:
        Gs: generators, each called as ``G(z_in, I_prev)``.
        Zs: one noise map per generator; its spatial shape sizes the random
            noise and the map replaces the noise while ``n < gen_start_scale``.
        reals: real images per scale; only their shapes are kept, the
            contents are replaced as described above.
        NoiseAmp: one noise amplitude per generator.
        opt: options object; reads ``device``, ``ker_size``, ``num_layer``,
            ``nc_z``, ``scale_factor``, ``mode``, ``skip``, ``out`` and
            ``input_name``.
        in_s: image fed to the first generator; all zeros when ``None``.
        scale_v: vertical factor applied to the noise size and the crop.
        scale_h: horizontal factor applied to the noise size and the crop.
        n: scale index of the first generator; incremented once per generator.
        gen_start_scale: scales below this index use ``Z_opt`` as noise.
        num_samples: number of images generated per scale.

    Returns:
        The detached image produced at the last scale for the last sample.

    Raises:
        ValueError: if no image was generated (no generators, or
            ``num_samples`` is not positive).
    """
    real = functions.read_image(opt)

    real = real.numpy()
    real = resize(real, reals[-1].shape)
    real = torch.from_numpy(real)

    new_reals = creat_reals_pyramid(real, [], opt)

    # Replace every given real by the rebuilt level resized to its shape.
    buffer = []
    for new_real, given_real in zip(new_reals, reals):
        ele = resize(new_real.numpy(), given_real.shape)
        buffer.append(torch.from_numpy(ele))
    reals = buffer

    real_dir = '%s/RandomSamples/%s/gen_start_scale=%d' % (
        opt.out, opt.input_name[:-4], gen_start_scale)
    # The directory was previously never created before the first save, so
    # saving failed when it did not exist yet.
    os.makedirs(real_dir, exist_ok=True)
    for i, real_img in enumerate(reals):
        plt.imsave('%s/%s_%d.png' % (real_dir, "real", i),
                   functions.convert_image_np(real_img.detach()),
                   vmin=0,
                   vmax=1)

    if in_s is None:
        # torch.zeros always uses the default floating dtype. torch.full with
        # an integer fill value is rejected or returns an integer tensor on
        # newer PyTorch releases.
        in_s = torch.zeros(reals[0].shape, device=opt.device)

    # The padding only depends on the options, so it is built once.
    pad1 = ((opt.ker_size - 1) * opt.num_layer) / 2
    m = nn.ZeroPad2d(int(pad1))
    save_samples = opt.mode not in ("harmonization", "editing", "SR",
                                    "paint2image")

    I_curr = None
    images_cur = []
    for G, Z_opt, noise_amp in zip(Gs, Zs, NoiseAmp):
        nzx = (Z_opt.shape[2] - pad1 * 2) * scale_v
        nzy = (Z_opt.shape[3] - pad1 * 2) * scale_h

        images_prev = images_cur
        images_cur = []

        for i in range(num_samples):
            if n == 0:
                # One channel is drawn and expanded to three channels.
                z_curr = functions.generate_noise([1, nzx, nzy],
                                                  device=opt.device)
                z_curr = z_curr.expand(1, 3, z_curr.shape[2], z_curr.shape[3])
            else:
                z_curr = functions.generate_noise([opt.nc_z, nzx, nzy],
                                                  device=opt.device)
            z_curr = m(z_curr)

            if not images_prev:
                # No previous scale was generated in this call.
                I_prev = m(in_s)
            else:
                I_prev = images_prev[i]
                I_prev = imresize(I_prev, 1 / opt.scale_factor, opt)
                if opt.mode != "SR":
                    I_prev = I_prev[:, :, 0:round(scale_v * reals[n].shape[2]),
                                    0:round(scale_h * reals[n].shape[3])]
                    I_prev = m(I_prev)
                    # Make the padded image match the padded noise exactly.
                    I_prev = I_prev[:, :, 0:z_curr.shape[2], 0:z_curr.shape[3]]
                    I_prev = functions.upsampling(I_prev, z_curr.shape[2],
                                                  z_curr.shape[3])
                else:
                    I_prev = m(I_prev)

            if n < gen_start_scale:
                z_curr = Z_opt

            z_in = noise_amp * (z_curr) + I_prev
            if opt.skip != '' and int(opt.skip) == n:
                # The generator of the skipped scale is bypassed.
                I_curr = I_prev
            else:
                I_curr = G(z_in.detach(), I_prev)

            # Only the last scale of the pyramid is written to disk.
            if n == len(reals) - 1:
                if opt.mode == 'train':
                    dir2save = '%s/RandomSamples/%s/gen_start_scale=%d' % (
                        opt.out, opt.input_name[:-4], gen_start_scale)
                else:
                    dir2save = functions.generate_dir2save(opt)
                # An existing directory is fine; any other failure is
                # reported here instead of being swallowed.
                os.makedirs(dir2save, exist_ok=True)
                if save_samples:
                    plt.imsave('%s/%d.png' % (dir2save, i),
                               functions.convert_image_np(I_curr.detach()),
                               vmin=0,
                               vmax=1)

            images_cur.append(I_curr)
        n += 1

    if I_curr is None:
        raise ValueError(
            "SinGAN_generate produced no image: Gs, Zs and NoiseAmp must be "
            "non-empty and num_samples must be positive")
    return I_curr.detach()
コード例 #13
0
ファイル: training.py プロジェクト: codeconomics/SinGAN
def train_single_scale(
        netD,  #current discriminator
        netG,  #current generator
        reals,  #the list of all resized data
        Gs,  #generator list
        Zs,  #
        in_s,  #
        NoiseAmp,  #
        opt,  #parameters
        centers=None):

    real = reals[len(Gs)]  # get the current resized real picture

    #get the x and y
    opt.nzx = real.shape[2]
    opt.nzy = real.shape[3]

    #receptive field
    opt.receptive_field = opt.ker_size + ((opt.ker_size - 1) *
                                          (opt.num_layer - 1)) * opt.stride

    #padding width
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)

    #this stuff create a torch.nn class adding 0 pads, tf is slightly harder
    m_noise = nn.ZeroPad2d(int(pad_noise))
    m_image = nn.ZeroPad2d(int(pad_image))

    #get alpha from opt
    alpha = opt.alpha

    #generate a noise in the following size
    fixed_noise = functions.generate_noise(
        [
            opt.nc_z,  #noise # channels
            opt.nzx,
            opt.nzy
        ],
        device=opt.device)

    z_opt = torch.full(fixed_noise.shape, 0, device=opt.device)
    #generate a tensor of size fixed_noise.shape filled with 0.

    z_opt = m_noise(z_opt)
    #give it a zero pad with width int(pad_noise)

    # setup optimizer and learning rate
    optimizerD = optim.Adam(netD.parameters(),
                            lr=opt.lr_d,
                            betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(),
                            lr=opt.lr_g,
                            betas=(opt.beta1, 0.999))
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD,
                                                      milestones=[1600],
                                                      gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG,
                                                      milestones=[1600],
                                                      gamma=opt.gamma)

    #some plot list
    errD2plot = []
    errG2plot = []
    D_real2plot = []
    D_fake2plot = []
    z_opt2plot = []

    #for iteration number' loop
    for epoch in tqdm_notebook(range(opt.niter),
                               desc=f"scale {len(Gs)}",
                               leave=False):
        #if it's the first graph, for G need an additional imput
        if (Gs == []):
            #generate a noise of size [1,opt.nzx,opt.nzy]
            z_opt = functions.generate_noise([1, opt.nzx, opt.nzy],
                                             device=opt.device)
            #give it a zero pad with width int(pad_noise)
            z_opt = m_noise(z_opt.expand(1, 3, opt.nzx, opt.nzy))
            #generate another noise
            noise_ = functions.generate_noise([1, opt.nzx, opt.nzy],
                                              device=opt.device)
            #give it additional dimention with size 3, in all these dimension all 3 layers are the same
            #the padding it in all the dimension
            noise_ = m_noise(noise_.expand(1, 3, opt.nzx, opt.nzy))
        # when it's not the first graph
        else:
            #nc_z is 'noise # channels'
            noise_ = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy],
                                              device=opt.device)
            #the padding it in all the dimension
            noise_ = m_noise(noise_)

        ############################
        # (1) Update D network: maximize D(x) + D(G(z))
        ###########################
        # for Discriminator inner steps' loop
        for j in range(opt.Dsteps):

            # train with real
            #before training reset grad, torch operation
            netD.zero_grad()
            #generate a result
            output = netD(real).to(opt.device)
            #error for D, to minimize -(D(x) + D(G(z))), the mean should be -1
            errD_real = -output.mean()  #-a
            # have all the gradients computed
            errD_real.backward(retain_graph=True)
            #return the list with all dictionary keys with negative values
            D_x = -errD_real.item()

            # train with fake
            # for the first loop in the first epoch
            if (j == 0) & (epoch == 0):
                #if it's the first scale
                if (Gs == []):
                    #set prev to all 0
                    prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy],
                                      0,
                                      device=opt.device)

                    in_s = prev
                    #zero padding with width int(pad_image)
                    prev = m_image(prev)

                    #set z_prev to all 0
                    z_prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy],
                                        0,
                                        device=opt.device)

                    #padding with noise
                    z_prev = m_noise(z_prev)

                    #set amp = 1
                    opt.noise_amp = 1
                else:
                    # generate the prev from rand mode
                    prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand',
                                       m_noise, m_image, opt)
                    #zero padding with width int(pad_image)
                    prev = m_image(prev)
                    # generate the z_prev from rec mode
                    z_prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rec',
                                         m_noise, m_image, opt)

                    #use opt.noise_amp_init*RMSE as the loss
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    # add a padding of width (int(pad_image))
                    z_prev = m_image(z_prev)
            else:
                #generate the prev form rand mode
                prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand',
                                   m_noise, m_image, opt)
                # add a padding of width (int(pad_image))
                prev = m_image(prev)

            #if it's the first scale
            if (Gs == []):
                #a noise added additional dimention with size 3, in all these dimension all 3 layers are the same
                #the padding it in all the dimension
                noise = noise_
            else:
                #amplify the padded noise + prev
                noise = opt.noise_amp * noise_ + prev

            # generate the fake graph with noise
            # detach() detaches the output from the computationnal graph.
            # So no gradient will be backproped along this variable
            # in the very first loop G is RAW now
            fake = netG(noise.detach(), prev)
            # generate the output
            output = netD(fake.detach())

            # generate the error from fake, to minimize -(D(x) + D(G(z))), the mean should be positive
            errD_fake = output.mean()
            # have all the gradients computed
            errD_fake.backward(retain_graph=True)
            #get the discriminator
            D_G_z = output.mean().item()

            #calculate the penalty
            gradient_penalty = functions.calc_gradient_penalty(
                netD, real, fake, opt.lambda_grad, opt.device)
            #calculate gradient
            gradient_penalty.backward()

            #calculate penal D
            errD = errD_real + errD_fake + gradient_penalty

            #updates the parameters.
            optimizerD.step()

        #add the stuff into a record
        errD2plot.append(errD.detach())

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################

        for j in range(opt.Gsteps):
            # init to 0
            netG.zero_grad()
            # generate the output from the discrimator
            output = netD(fake)
            #the the loss of G is negative to result of D for competition
            errG = -output.mean()
            #calculate the backward
            errG.backward(retain_graph=True)

            if alpha != 0:
                #define MSE loss
                loss = nn.MSELoss()
                #amplify the z
                Z_opt = opt.noise_amp * z_opt + z_prev
                #use the result generate from Z_opt.detach(),z_prev, calculate the MSE with real, scale with alpha
                rec_loss = alpha * loss(netG(Z_opt.detach(), z_prev), real)
                #backward
                rec_loss.backward(retain_graph=True)
                #get a number loss
                rec_loss = rec_loss.detach()
            else:  #alpha = 0
                #else get Z as z
                Z_opt = z_opt
                #set the rec_loss = o
                rec_loss = 0

            #update the result
            optimizerG.step()

        errG2plot.append(errG.detach() + rec_loss)
        D_real2plot.append(D_x)
        D_fake2plot.append(D_G_z)
        z_opt2plot.append(rec_loss)

        #if epoch % 25 == 0 or epoch == (opt.niter-1): #replaced by tqdm
        #print(f'scale {len(Gs)}:[{epoch}/{opt.niter}]'

        #if epoch % 500 == 0 or epoch == (opt.niter-1):
        if epoch == (opt.niter - 1):  #only saved once (for small graph)
            #save the fake sample
            plt.imsave(f'{opt.outf}/fake_sample.png',
                       functions.convert_image_np(fake.detach()),
                       vmin=0,
                       vmax=1)
            #save the z_opt
            plt.imsave(f'{opt.outf}/G(z_opt).png',
                       functions.convert_image_np(
                           netG(Z_opt.detach(), z_prev).detach()),
                       vmin=0,
                       vmax=1)
            #save the model
            torch.save(z_opt, f'{opt.outf}/z_opt.pth')
            #plt.imsave('%s/D_fake.png'   % (opt.outf), functions.convert_image_np(D_fake_map))
            #plt.imsave('%s/D_real.png'   % (opt.outf), functions.convert_image_np(D_real_map))
            #plt.imsave('%s/z_opt.png'    % (opt.outf), functions.convert_image_np(z_opt.detach()), vmin=0, vmax=1)
            #plt.imsave('%s/prev.png'     %  (opt.outf), functions.convert_image_np(prev), vmin=0, vmax=1)
            #plt.imsave('%s/noise.png'    %  (opt.outf), functions.convert_image_np(noise), vmin=0, vmax=1)
            #plt.imsave('%s/z_prev.png'   % (opt.outf), functions.convert_image_np(z_prev), vmin=0, vmax=1)
        #update learning rate
        schedulerD.step()
        schedulerG.step()

    # save the model
    functions.save_networks(netG, netD, z_opt, opt)

    # return the z, in_s(what's this), and generator G
    return z_opt, in_s, netG
コード例 #14
0
def train_single_scale(netD,
                       netG,
                       reals,
                       crops,
                       masks,
                       eye_color,
                       Gs,
                       Zs,
                       in_s,
                       NoiseAmp,
                       opt,
                       centers=None):
    """Train the generator/discriminator pair of one pyramid scale.

    The scale being trained is ``len(Gs)``. Each epoch runs ``opt.Dsteps``
    discriminator updates (real score, fake score and gradient penalty)
    followed by ``opt.Gsteps`` generator updates (adversarial term plus an
    ``opt.alpha``-weighted MSE reconstruction term). The fake image shown to
    the discriminator is assembled by ``functions.gen_fake`` from the real
    image, the generated background, the mask and the eye mask.

    Args:
        netD: Discriminator for this scale; called as ``netD(image)``.
        netG: Generator for this scale; called as ``netG(input, prev)``.
        reals: Per-scale real images; ``reals[len(Gs)]`` is trained on.
            Indexed with four dimensions (presumably NCHW - TODO confirm).
        crops: Per-scale crops; only ``crops[len(Gs)].size()[2]`` is read and
            used as the random-crop size.
        masks: Per-scale masks; ``masks[len(Gs)]`` is used for this scale and
            ``masks[-1]`` is handed to ``functions.generate_eye_mask``.
        eye_color: Eye colour forwarded to ``functions.gen_fake``; replaced
            every epoch by ``functions.get_eye_color(real)`` when
            ``opt.random_eye_color`` is set.
        Gs: Generators of the previously trained scales.
        Zs: Noise maps of the previous scales (forwarded to
            ``functions.draw_concat``).
        in_s: Base input of the pyramid; replaced by an all-zero tensor when
            the first scale is trained outside ``'SR_train'`` mode.
        NoiseAmp: Noise amplitudes of the previous scales (forwarded to
            ``functions.draw_concat``).
        opt: Configuration namespace. This function writes ``opt.nzx``,
            ``opt.nzy``, ``opt.receptive_field`` and ``opt.noise_amp``.
        centers: Only used in ``'paint_train'`` mode, where it is passed to
            ``functions.quant2centers``.

    Returns:
        tuple: ``(z_opt, in_s, netG)``. When ``len(Gs) == opt.stop_scale`` the
        returned generator is a deep copy taken during the discriminator
        phase of the last epoch, i.e. before that epoch's generator updates;
        the networks written by ``functions.save_networks`` are the live ones.
    """
    # Function-scope import: resolved once per call instead of once per
    # discriminator step.
    import copy

    real_fullsize = reals[len(Gs)]

    crop_size = crops[len(Gs)].size()[2]
    if opt.random_crop:
        real, _, _ = functions.random_crop(real_fullsize.clone(), crop_size)
    else:
        real = real_fullsize.clone()
    mask = masks[len(Gs)]

    opt.nzx = real.shape[2]  #+(opt.ker_size-1)*(opt.num_layer) width
    opt.nzy = real.shape[3]  #+(opt.ker_size-1)*(opt.num_layer) height
    opt.receptive_field = opt.ker_size + ((opt.ker_size - 1) *
                                          (opt.num_layer - 1)) * opt.stride
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    if opt.mode == 'animation_train':
        # In this mode the noise map itself is enlarged instead of padded.
        opt.nzx = real.shape[2] + (opt.ker_size - 1) * (opt.num_layer)
        opt.nzy = real.shape[3] + (opt.ker_size - 1) * (opt.num_layer)
        pad_noise = 0
    m_noise = nn.ZeroPad2d(int(pad_noise))
    m_image = nn.ZeroPad2d(int(pad_image))

    alpha = opt.alpha

    # True only for the coarsest scale outside super-resolution training.
    first_scale = (not Gs) and (opt.mode != 'SR_train')

    fixed_noise = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy],
                                           device=opt.device)
    # torch.zeros always uses the default floating dtype; torch.full with an
    # integer fill value infers an integer dtype on recent PyTorch releases.
    z_opt = torch.zeros(fixed_noise.shape, device=opt.device)
    z_opt = m_noise(z_opt)

    # setup optimizer
    optimizerD = optim.Adam(netD.parameters(),
                            lr=opt.lr_d,
                            betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(),
                            lr=opt.lr_g,
                            betas=(opt.beta1, 0.999))
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD,
                                                      milestones=[1600],
                                                      gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG,
                                                      milestones=[1600],
                                                      gamma=opt.gamma)

    errD2plot = []
    errG2plot = []
    D_real2plot = []
    D_fake2plot = []
    z_opt2plot = []

    # Generator snapshot returned for the final scale (see Returns).
    netG_copy = None

    eye = functions.generate_eye_mask(opt, masks[-1], opt.stop_scale - len(Gs))

    for epoch in range(opt.niter):

        if opt.resize:
            # Randomly rescale the mask and eye mask to 75%-125% of the mask
            # size, capped by the real image size.
            max_patch_size = int(
                min(real.size()[2],
                    real.size()[3],
                    mask.size()[2] * 1.25))
            min_patch_size = int(max(mask.size()[2] * 0.75, 1))
            patch_size = random.randint(min_patch_size, max_patch_size)
            mask_in = nn.functional.interpolate(mask.clone(), size=patch_size)
            eye_in = nn.functional.interpolate(eye.clone(), size=patch_size)
        else:
            mask_in = mask.clone()
            eye_in = eye.clone()

        eye_colored = eye_in.clone()
        if opt.random_eye_color:
            eye_color = functions.get_eye_color(real)
            # Scale the first three channels by the sampled colour (0-255).
            for channel in range(3):
                eye_colored[:, channel, :, :] *= (eye_color[channel] / 255)

        if first_scale:
            # Single-channel noise broadcast to three channels, then padded.
            z_opt = functions.generate_noise([1, opt.nzx, opt.nzy],
                                             device=opt.device)
            z_opt = m_noise(z_opt.expand(1, 3, opt.nzx, opt.nzy))
            noise_ = functions.generate_noise([1, opt.nzx, opt.nzy],
                                              device=opt.device)
            noise_ = m_noise(noise_.expand(1, 3, opt.nzx, opt.nzy))
        else:
            noise_ = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy],
                                              device=opt.device)
            noise_ = m_noise(noise_)

        ############################
        # (1) Update D network: maximize D(x) + D(G(z))
        ###########################
        for j in range(opt.Dsteps):
            # train with real
            netD.zero_grad()
            output = netD(real).to(opt.device)
            errD_real = -output.mean()
            errD_real.backward(retain_graph=True)
            D_x = -errD_real.item()

            # train with fake
            if j == 0 and epoch == 0:
                # The very first step also fixes z_prev and opt.noise_amp for
                # the rest of the training of this scale.
                if first_scale:
                    prev = torch.zeros([1, opt.nc_z, opt.nzx, opt.nzy],
                                       device=opt.device)
                    in_s = prev
                    prev = m_image(prev)
                    z_prev = torch.zeros([1, opt.nc_z, opt.nzx, opt.nzy],
                                         device=opt.device)
                    z_prev = m_noise(z_prev)
                    opt.noise_amp = 1
                elif opt.mode == 'SR_train':
                    z_prev = in_s
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
                    prev = z_prev
                else:
                    prev = functions.draw_concat(Gs, Zs, reals, crops, masks,
                                                 eye_colored, NoiseAmp, in_s,
                                                 'rand', m_noise, m_image, opt)
                    prev = m_image(prev)
                    z_prev = functions.draw_concat(Gs, Zs, reals, crops, masks,
                                                   eye_colored, NoiseAmp, in_s,
                                                   'rec', m_noise, m_image,
                                                   opt)
                    # Noise amplitude is proportional to the reconstruction
                    # error of the previous scales.
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
            else:
                prev = functions.draw_concat(Gs, Zs, reals, crops, masks,
                                             eye_colored, NoiseAmp, in_s,
                                             'rand', m_noise, m_image, opt)
                prev = m_image(prev)

            if opt.mode == 'paint_train':
                prev = functions.quant2centers(prev, centers)
                plt.imsave('%s/prev.png' % (opt.outf),
                           functions.convert_image_np(prev),
                           vmin=0,
                           vmax=1)

            if first_scale:
                noise = noise_
            else:
                noise = opt.noise_amp * noise_ + prev

            # Stacking masks and noise to make input
            G_input = functions.make_input(noise, mask_in, eye_colored)
            fake_background = netG(G_input.detach(), prev)

            # Only the snapshot from the last epoch is ever returned, and only
            # for the final scale, so the deep copy is skipped everywhere else.
            if epoch == opt.niter - 1 and len(Gs) == opt.stop_scale:
                netG_copy = copy.deepcopy(netG)

            # Cropping mask shape from generated image and putting on top of real image at random location
            fake, fake_ind, eye_ind = functions.gen_fake(
                real, fake_background, mask_in, eye_in, eye_color, opt)

            output = netD(fake.detach())
            errD_fake = output.mean()

            errD_fake.backward(retain_graph=True)
            D_G_z = output.mean().item()

            gradient_penalty = functions.calc_gradient_penalty(
                netD, real, fake, opt.lambda_grad, opt.device)
            gradient_penalty.backward()

            errD = errD_real + errD_fake + gradient_penalty
            optimizerD.step()

        errD2plot.append(errD.detach())

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################

        for j in range(opt.Gsteps):

            netG.zero_grad()
            # `fake` is the (non-detached) sample of the last D step.
            output = netD(fake)
            errG = -output.mean()
            errG.backward(retain_graph=True)
            if alpha != 0:
                loss = nn.MSELoss()
                if opt.mode == 'paint_train':
                    z_prev = functions.quant2centers(z_prev, centers)
                    plt.imsave('%s/z_prev.png' % (opt.outf),
                               functions.convert_image_np(z_prev),
                               vmin=0,
                               vmax=1)
                Z_opt = opt.noise_amp * z_opt + z_prev
                input_opt = functions.make_input(Z_opt, mask_in, eye_in)
                rec_loss = alpha * loss(netG(input_opt.detach(), z_prev), real)
                rec_loss.backward(retain_graph=True)
                rec_loss = rec_loss.detach()
            else:
                Z_opt = z_opt
                rec_loss = 0

            optimizerG.step()

        errG2plot.append(errG.detach() + rec_loss)
        D_real2plot.append(D_x)
        D_fake2plot.append(D_G_z)
        z_opt2plot.append(rec_loss)

        if epoch % 25 == 0 or epoch == (opt.niter - 1):
            print('scale %d:[%d/%d]' % (len(Gs), epoch, opt.niter))

        if epoch % 500 == 0 or epoch == (opt.niter - 1):
            plt.imsave('%s/fake_sample.png' % (opt.outf),
                       functions.convert_image_np(fake.detach()))
            plt.imsave('%s/fake_indicator.png' % (opt.outf),
                       functions.convert_image_np(fake_ind.detach()))
            plt.imsave('%s/eye_indicator.png' % (opt.outf),
                       functions.convert_image_np(eye_ind.detach()))
            plt.imsave('%s/background.png' % (opt.outf),
                       functions.convert_image_np(fake_background.detach()))

            torch.save(z_opt, '%s/z_opt.pth' % (opt.outf))

        schedulerD.step()
        schedulerG.step()

        # Draw fresh training data for the next epoch.
        if opt.random_crop:
            real, _, _ = functions.random_crop(
                real_fullsize, crop_size)  #randomly find crop in image
        if opt.random_eye:
            eye = functions.generate_eye_mask(opt, masks[-1],
                                              opt.stop_scale - len(Gs))

    functions.save_networks(netG, netD, z_opt, opt)

    # netG_copy stays None when no discriminator step ran in the last epoch.
    if len(Gs) == (opt.stop_scale) and netG_copy is not None:
        netG = netG_copy

    return z_opt, in_s, netG
コード例 #15
0
ファイル: manipulate.py プロジェクト: sieunhanbom04/SinGAN
def SinGAN_denoise(Gs,
                   Zs,
                   reals,
                   NoiseAmp,
                   opt,
                   in_s=None,
                   scale_v=1,
                   scale_h=1,
                   n=0,
                   gen_start_scale=0,
                   num_samples=1):
    """Run the generator pyramid using the stored noise map of every scale.

    Every scale feeds its generator with ``noise_amp * Z_opt + I_prev``,
    where ``Z_opt`` is that scale's entry of ``Zs`` and ``I_prev`` is the
    resized output of the previous scale (or ``in_s`` for the first one).

    Args:
        Gs: Generators, one per scale; each is called as ``G(z_in, I_prev)``.
        Zs: Noise maps, one per scale, aligned with ``Gs``.
        reals: Per-scale real images; ``reals[0].shape`` sizes the default
            ``in_s`` and ``reals[n]`` bounds the crop of the upsampled image.
        NoiseAmp: Noise amplitudes, one per scale, aligned with ``Gs``.
        opt: Configuration namespace; reads ``device``, ``ker_size``,
            ``num_layer``, ``nc_z``, ``scale_factor`` and ``mode``.
        in_s: Input of the first scale. Defaults to zeros shaped like
            ``reals[0]``.
        scale_v: Vertical scale factor applied to the noise size and crop.
        scale_h: Horizontal scale factor applied to the noise size and crop.
        n: Index of the first scale processed; incremented once per scale.
        gen_start_scale: Accepted for signature compatibility; it has no
            effect because the stored noise map is used at every scale.
        num_samples: Number of images pushed through each scale.

    Returns:
        The detached output of the last generator for the last sample.

    Raises:
        ValueError: If no generator was run (empty ``Gs``/``Zs``/``NoiseAmp``
            or ``num_samples <= 0``), so there is nothing to return.
    """
    if in_s is None:
        # torch.zeros always uses the default floating dtype; torch.full with
        # an integer fill value infers an integer dtype on recent PyTorch.
        in_s = torch.zeros(reals[0].shape, device=opt.device)

    I_curr = None
    images_cur = []
    for G, Z_opt, noise_amp in zip(Gs, Zs, NoiseAmp):
        pad1 = ((opt.ker_size - 1) * opt.num_layer) / 2
        m = nn.ZeroPad2d(int(pad1))
        # Unpadded noise size of this scale, stretched by the scale factors.
        nzx = (Z_opt.shape[2] - pad1 * 2) * scale_v
        nzy = (Z_opt.shape[3] - pad1 * 2) * scale_h

        images_prev = images_cur
        images_cur = []

        for i in range(num_samples):
            # This random draw is never fed to the generator; only its
            # spatial size is used below to crop/upsample I_prev. It is kept
            # so that the size computation and RNG consumption are unchanged.
            if n == 0:
                z_curr = functions.generate_noise([1, nzx, nzy],
                                                  device=opt.device)
                z_curr = z_curr.expand(1, 3, z_curr.shape[2], z_curr.shape[3])
                z_curr = m(z_curr)
            else:
                z_curr = functions.generate_noise([opt.nc_z, nzx, nzy],
                                                  device=opt.device)
                z_curr = m(z_curr)

            if not images_prev:
                I_prev = m(in_s)
            else:
                I_prev = images_prev[i]
                I_prev = imresize(I_prev, 1 / opt.scale_factor, opt)
                if opt.mode != "SR":
                    I_prev = I_prev[:, :, 0:round(scale_v * reals[n].shape[2]),
                                    0:round(scale_h * reals[n].shape[3])]
                    I_prev = m(I_prev)
                    I_prev = I_prev[:, :, 0:z_curr.shape[2], 0:z_curr.shape[3]]
                    I_prev = functions.upsampling(I_prev, z_curr.shape[2],
                                                  z_curr.shape[3])
                else:
                    I_prev = m(I_prev)

            # The stored noise map of the scale is always used.
            z_in = noise_amp * Z_opt + I_prev
            I_curr = G(z_in.detach(), I_prev)

            images_cur.append(I_curr)
        n += 1

    if I_curr is None:
        raise ValueError(
            "SinGAN_denoise produced no image: Gs, Zs and NoiseAmp must be "
            "non-empty and num_samples must be positive")
    return I_curr.detach()
コード例 #16
0
ファイル: training.py プロジェクト: SherlockHolmes9102/sinGAN
def train_single_scale(netD,
                       netG,
                       reals,
                       Gs,
                       Zs,
                       in_s,
                       NoiseAmp,
                       opt,
                       centers=None):
    """Train the generator/discriminator pair of one pyramid scale.

    The scale being trained is ``len(Gs)``. Each epoch runs ``opt.Dsteps``
    discriminator updates (real score, fake score and gradient penalty)
    followed by ``opt.Gsteps`` generator updates (adversarial term plus an
    ``opt.alpha``-weighted MSE reconstruction term), then advances both
    learning-rate schedulers. After training, learning curves are plotted for
    scales 0, 1 and 8 and the networks are saved.

    Args:
        netD: Discriminator for this scale; called as ``netD(image)``.
        netG: Generator for this scale; called as ``netG(noise, prev)``.
        reals: Per-scale real images; ``reals[len(Gs)]`` is trained on.
            Indexed with four dimensions (presumably NCHW - TODO confirm).
        Gs: Generators of the previously trained scales.
        Zs: Noise maps of the previous scales (forwarded to ``draw_concat``).
        in_s: Base input of the pyramid; replaced by an all-zero tensor when
            the first scale is trained outside ``'SR_train'`` mode.
        NoiseAmp: Noise amplitudes of the previous scales (forwarded to
            ``draw_concat``).
        opt: Configuration namespace. This function writes ``opt.nzx``,
            ``opt.nzy``, ``opt.receptive_field`` and ``opt.noise_amp``.
        centers: Only used in ``'paint_train'`` mode, where it is passed to
            ``functions.quant2centers``.

    Returns:
        tuple: ``(z_opt, in_s, netG)``.
    """
    # Gs holds the generators of the scales trained so far.
    print("len(Gs):", len(Gs))

    # Real image of the scale being trained.
    real = reals[len(Gs)]

    opt.nzx = real.shape[2]  # +(opt.ker_size-1)*(opt.num_layer)
    opt.nzy = real.shape[3]  # +(opt.ker_size-1)*(opt.num_layer)
    # Receptive field of the network.
    opt.receptive_field = opt.ker_size + ((opt.ker_size - 1) *
                                          (opt.num_layer - 1)) * opt.stride
    print(opt.receptive_field)  # e.g. 3 + 2*4*1 = 11
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    print(pad_noise)  # e.g. 5
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    # ZeroPad2d surrounds its input with a border of zeros.
    m_noise = nn.ZeroPad2d(int(pad_noise))
    print('m_noise', m_noise)  # e.g. ZeroPad2d(padding=(5, 5, 5, 5), value=0.0)
    m_image = nn.ZeroPad2d(int(pad_image))

    alpha = opt.alpha
    print(alpha)

    # True only for the coarsest scale outside super-resolution training.
    first_scale = (not Gs) and (opt.mode != 'SR_train')

    fixed_noise = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy])
    # torch.zeros always uses the default floating dtype; torch.full with an
    # integer fill value infers an integer dtype on recent PyTorch releases.
    z_opt = torch.zeros(fixed_noise.shape, device=opt.device)
    z_opt = m_noise(z_opt)

    # Set up the optimizers and their learning-rate schedules.
    optimizerD = optim.Adam(netD.parameters(),
                            lr=opt.lr_d,
                            betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(),
                            lr=opt.lr_g,
                            betas=(opt.beta1, 0.999))
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD,
                                                      milestones=[1600],
                                                      gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG,
                                                      milestones=[1600],
                                                      gamma=opt.gamma)

    # Per-epoch loss histories.
    errD2plot = []
    errG2plot = []
    D_real2plot = []
    D_fake2plot = []
    z_opt2plot = []

    for epoch in range(opt.niter):
        if first_scale:
            # opt.nzx / opt.nzy are the spatial size of the current scale.
            z_opt = functions.generate_noise([1, opt.nzx, opt.nzy])
            # Broadcast the single channel to (1, 3, opt.nzx, opt.nzy).
            z_opt = m_noise(z_opt.expand(1, 3, opt.nzx, opt.nzy))
            noise_ = functions.generate_noise([1, opt.nzx, opt.nzy])
            noise_ = m_noise(noise_.expand(1, 3, opt.nzx, opt.nzy))
        else:
            noise_ = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy])
            noise_ = m_noise(noise_)

        ############################
        # (1) Update D network: maximize D(x) + D(G(z))
        ###########################
        for j in range(opt.Dsteps):
            # train with real
            netD.zero_grad()

            output = netD(real).to(opt.device)

            errD_real = -output.mean()
            errD_real.backward(retain_graph=True)
            D_x = -errD_real.item()

            # train with fake
            # z_prev and opt.noise_amp are only computed on the very first
            # step and then reused for the rest of this scale's training.
            if j == 0 and epoch == 0:
                if first_scale:
                    prev = torch.zeros([1, opt.nc_z, opt.nzx, opt.nzy],
                                       device=opt.device)
                    in_s = prev
                    prev = m_image(prev)

                    z_prev = torch.zeros([1, opt.nc_z, opt.nzx, opt.nzy],
                                         device=opt.device)
                    z_prev = m_noise(z_prev)
                    opt.noise_amp = 1
                elif opt.mode == 'SR_train':
                    z_prev = in_s
                    # Root-mean-square error between the real image and the
                    # input sets the noise amplitude.
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
                    prev = z_prev
                else:
                    prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand',
                                       m_noise, m_image, opt)
                    prev = m_image(prev)
                    z_prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rec',
                                         m_noise, m_image, opt)
                    # Root-mean-square reconstruction error of the previous
                    # scales sets the noise amplitude.
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
            else:
                prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand',
                                   m_noise, m_image, opt)
                prev = m_image(prev)

            if first_scale:
                noise = noise_
            else:
                noise = opt.noise_amp * noise_ + prev

            # .detach() stops gradients from flowing into the noise input.
            fake = netG(noise.detach(), prev)

            output = netD(fake.detach())
            errD_fake = output.mean()
            errD_fake.backward(retain_graph=True)
            D_G_z = output.mean().item()

            # Gradient penalty term.
            gradient_penalty = functions.calc_gradient_penalty(
                netD, real, fake, opt.lambda_grad)
            gradient_penalty.backward()

            # Discriminator loss: adversarial terms plus gradient penalty.
            errD = errD_real + errD_fake + gradient_penalty

            optimizerD.step()

        errD2plot.append(errD.detach())

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################

        # NOTE(review): `fake` was produced before any optimizerG.step(); with
        # opt.Gsteps > 1 the later iterations backpropagate through a graph
        # built with pre-step weights - confirm the installed PyTorch version
        # accepts this.
        for j in range(opt.Gsteps):
            netG.zero_grad()
            output = netD(fake)

            errG = -output.mean()
            errG.backward(retain_graph=True)
            if alpha != 0:
                loss = nn.MSELoss()
                if opt.mode == 'paint_train':
                    z_prev = functions.quant2centers(z_prev, centers)
                    plt.imsave('%s/z_prev.png' % (opt.outf),
                               functions.convert_image_np(z_prev),
                               vmin=0,
                               vmax=1)
                Z_opt = opt.noise_amp * z_opt + z_prev
                rec_loss = alpha * loss(netG(Z_opt.detach(), z_prev), real)

                rec_loss.backward(retain_graph=True)
                rec_loss = rec_loss.detach()
            else:
                Z_opt = z_opt
                rec_loss = 0

            optimizerG.step()

        errG2plot.append(errG.detach() + rec_loss)
        D_real2plot.append(D_x)
        D_fake2plot.append(D_G_z)
        z_opt2plot.append(rec_loss)

        if epoch % 100 == 0 or epoch == (opt.niter - 1):
            print('scale %d:[%d/%d]' % (len(Gs), epoch, opt.niter))

        if epoch % 500 == 0 or epoch == (opt.niter - 1):
            plt.imsave('%s/fake_sample.png' % (opt.outf),
                       functions.convert_image_np(fake.detach()),
                       vmin=0,
                       vmax=1)
            plt.imsave('%s/G(z_opt).png' % (opt.outf),
                       functions.convert_image_np(
                           netG(Z_opt.detach(), z_prev).detach()),
                       vmin=0,
                       vmax=1)
            plt.imsave('%s/noise.png' % (opt.outf),
                       functions.convert_image_np(noise),
                       vmin=0,
                       vmax=1)

            torch.save(z_opt, '%s/z_opt.pth' % (opt.outf))

        # Schedulers advance after this epoch's optimizer steps, which is the
        # ordering the torch.optim learning-rate schedulers expect.
        schedulerD.step()
        schedulerG.step()

    # Learning curves are only plotted for scales 0, 1 and 8. The histories
    # are converted to NumPy once here rather than on every epoch.
    scale = len(Gs)
    if scale in (0, 1, 8):
        errGint = [err.cpu().numpy() for err in errG2plot]
        errDint = [err.cpu().numpy() for err in errD2plot]
        if scale == 0:
            name = 'G_loos & G_loos'
            functions.plot_learning_curves(errGint, errDint, opt.niter,
                                           'Generator', 'Discriminator', name)
        elif scale == 1:
            name = 'G_loos & G_loos_1'
            functions.plot_learning_curves_1(errGint, errDint, opt.niter,
                                             'Generator', 'Discriminator',
                                             name)
        else:
            name = 'G_loos & G_loos_8'
            functions.plot_learning_curves_8(errGint, errDint, opt.niter,
                                             'Generator', 'Discriminator',
                                             name)

    functions.save_networks(netG, netD, z_opt, opt)

    return z_opt, in_s, netG
コード例 #17
0
def train_single_scale(netD,netG,reals,Gs,Zs,in_s,NoiseAmp,opt,scale_num, netG_optimizer, netD_optimizer):
    """Train the generator/discriminator pair of one scale (TensorFlow variant).

    Runs ``opt.niter`` iterations. Each iteration records everything on two
    persistent gradient tapes, performs ``opt.Dsteps`` discriminator updates
    (real score, fake score, gradient penalty) and then ``opt.Gsteps``
    generator updates (adversarial loss plus, when ``opt.alpha != 0``, an
    MSE reconstruction loss weighted by ``opt.alpha``).

    Tensors are built channels-last: noise is broadcast to ``[1, H, W, 3]``
    and the all-zero inputs are created as ``[1, nzx, nzy, nc_z]``.

    Side effects: sets ``opt.nzx``, ``opt.nzy``, ``opt.receptive_field`` and
    ``opt.noise_amp``; prints the losses on every step; writes
    ``fake_sample.png`` and ``G(z_opt).png`` into ``opt.outf`` every 500
    iterations and on the last one; saves the networks through
    ``functions.save_networks`` when training ends.

    Args:
        netD: discriminator model for this scale.
        netG: generator model for this scale, called as ``netG(noise, prev)``.
        reals: per-scale real images; ``reals[len(Gs)]`` is the current one.
        Gs: generators already trained for the previous scales.
        Zs: reconstruction noise maps of the previous scales.
        in_s: input handed to ``draw_concat``; replaced by zeros at the
            first scale (when ``opt.mode != 'SR_train'``).
        NoiseAmp: noise amplitudes of the previous scales.
        opt: options namespace (read and mutated, see above).
        scale_num: forwarded to ``functions.save_networks``.
        netG_optimizer: optimizer applied to ``netG.trainable_variables``.
        netD_optimizer: optimizer applied to ``netD.trainable_variables``.

    Returns:
        Tuple ``(z_opt, in_s, netG)``.
    """
    real = reals[len(Gs)]
    opt.nzx = real.shape[1]
    opt.nzy = real.shape[2]
    opt.receptive_field = opt.ker_size + ((opt.ker_size-1)*(opt.num_layer-1))*opt.stride
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    m_noise = tf.keras.layers.ZeroPadding2D(padding=(pad_noise, pad_noise))
    m_image = tf.keras.layers.ZeroPadding2D(padding=(pad_image, pad_image))
    alpha = opt.alpha

    # Neither Gs nor opt.mode changes during training, so decide once whether
    # this is the coarsest (first) scale of a non-SR run.
    first_scale = (not Gs) and (opt.mode != 'SR_train')

    # generate_noise is still called so the random stream is consumed exactly
    # as before; only its shape is used (z_opt starts as zeros).
    fixed_noise = functions.generate_noise([opt.nc_z,opt.nzx,opt.nzy])
    z_opt = tf.zeros_like(fixed_noise)
    z_opt = m_noise(z_opt)

    # Per-iteration traces kept for plotting/debugging.
    errD2plot = []
    errG2plot = []
    D_real2plot = []
    D_fake2plot = []
    z_opt2plot = []

    for epoch in range(opt.niter):

        with tf.GradientTape(persistent=True) as netD_tape, tf.GradientTape(persistent=True) as netG_tape:

            if first_scale:
                # One noise channel replicated over three channels.
                z_opt = functions.generate_noise([1,opt.nzx,opt.nzy])
                z_opt = tf.broadcast_to(z_opt, [1, z_opt.shape[1], z_opt.shape[2], 3])
                z_opt = m_noise(z_opt)
                noise_ = functions.generate_noise([1,opt.nzx,opt.nzy])
                noise_ = tf.broadcast_to(noise_, [1, noise_.shape[1], noise_.shape[2], 3])
                noise_ = m_noise(noise_)
            else:
                noise_ = functions.generate_noise([opt.nc_z,opt.nzx,opt.nzy])
                noise_ = m_noise(noise_)

            ############################
            # (1) Update D network: maximize D(x) + D(G(z))
            ###########################
            for j in range(opt.Dsteps):
                real = reals[len(Gs)]
                output_real = netD(real)
                errD_real = -tf.reduce_mean(output_real)
                D_x = float(-errD_real.numpy())  # conversion to numpy required

                # train with fake
                if (j == 0) and (epoch == 0):
                    if first_scale:
                        prev = tf.zeros([1,opt.nzx,opt.nzy,opt.nc_z])
                        in_s = prev
                        prev = m_image(prev)
                        z_prev = tf.zeros([1, opt.nzx, opt.nzy, opt.nc_z])
                        z_prev = m_noise(z_prev)
                        opt.noise_amp = 1
                    else:
                        prev = draw_concat(Gs,Zs,reals,NoiseAmp,in_s,'rand',m_noise,m_image,opt)
                        prev = m_image(prev)
                        z_prev = draw_concat(Gs,Zs,reals,NoiseAmp,in_s,'rec',m_noise,m_image,opt)
                        # Noise amplitude follows the RMSE between the real
                        # image and the 'rec' output of the previous scales.
                        RMSE = tf.sqrt(tf.reduce_mean(tf.square(tf.subtract(real, z_prev))))
                        opt.noise_amp = opt.noise_amp_init*RMSE
                        z_prev = m_image(z_prev)
                else:
                    prev = draw_concat(Gs,Zs,reals,NoiseAmp,in_s,'rand',m_noise,m_image,opt)
                    prev = m_image(prev)

                if first_scale:
                    noise = noise_
                else:
                    noise = opt.noise_amp*noise_+prev

                fake = netG(noise, prev, training=True)
                output_fake = netD(fake)
                errD_fake = tf.reduce_mean(output_fake)
                D_G_z = float(output_fake.numpy().mean())
                fake_gp = functions.fake_gp_generator(real, fake)

                # Gradient penalty: push the norm (over axis 3) of dD/dx at
                # the fake_gp samples towards 1.
                with tf.GradientTape() as gp_tape:
                    gp_tape.watch(fake_gp)
                    gp_D_src = netD(fake_gp)
                gp_D_grad = gp_tape.gradient(gp_D_src, fake_gp)
                gp = opt.lambda_grad*tf.reduce_mean(((tf.norm(gp_D_grad, ord=2, axis=3)-1.0)**2))

                errD = errD_real + errD_fake + gp
                print('errD_real:', errD_real)
                print('errD_fake:', errD_fake)
                netD_gradients = netD_tape.gradient(errD, netD.trainable_variables)
                netD_optimizer.apply_gradients(zip(netD_gradients, netD.trainable_variables))

            ############################
            # (2) Update G network: maximize D(G(z))
            ###########################
            for j in range(opt.Gsteps):
                # Reuses output_fake from the last discriminator step.
                errG_fake = -tf.reduce_mean(output_fake)
                if alpha!=0:
                    Z_opt = opt.noise_amp*z_opt+z_prev
                    rec_loss = alpha * tf.reduce_mean(tf.square(tf.subtract(netG(Z_opt, z_prev, training=True), real)))
                else:
                    Z_opt = z_opt
                    rec_loss = 0
                errG = errG_fake + rec_loss
                print('errG_fake:', errG_fake)
                print('rec_loss:', rec_loss)
                netG_gradients = netG_tape.gradient(errG, netG.trainable_variables)  # only netG is updated here
                netG_optimizer.apply_gradients(zip(netG_gradients, netG.trainable_variables))
        # Persistent tapes hold their resources until dropped explicitly.
        del netG_tape, netD_tape

        errD2plot.append(errD)
        # errG already contains rec_loss; adding it again would count it twice.
        errG2plot.append(errG)
        D_real2plot.append(D_x)
        D_fake2plot.append(D_G_z)
        z_opt2plot.append(rec_loss)

        print('scale %d:[%d/%d]' % (len(Gs), epoch, opt.niter))

        if epoch % 500 == 0 or epoch == (opt.niter-1):
            plt.imsave('%s/fake_sample.png' %  (opt.outf), functions.convert_image_np(fake))
            plt.imsave('%s/G(z_opt).png'    % (opt.outf),  functions.convert_image_np(netG(Z_opt, z_prev, training=False)))
    functions.save_networks(netD,netG,z_opt,opt,scale_num)  # saved each time single-scale training finishes

    return z_opt,in_s,netG
コード例 #18
0
ファイル: backend.py プロジェクト: riven314/Streamlit-SinGAN
# 3. Generate GIFs (varying beta && start_scale)
# Walks the scales in order; `count` selects the first-scale branch, where a
# single noise channel is expanded to three channels.
# NOTE(review): on recent PyTorch, torch.full with an integer fill value
# infers an integer dtype - confirm in_s is meant to be a float tensor.
in_s = torch.full(Zs[0].shape, 0, device=device)
images_cur = []
count = 0

for G, Z_opt, noise_amp, real in zip(Gs, Zs, NoiseAmp, reals):
    # Half of (ker_size - 1) * num_layer, used below as the zero-padding
    # width - presumably what the unpadded convolutions remove per side;
    # TODO confirm against the generator definition.
    pad_image = int(((ker_size - 1) * num_layer) / 2)
    # Noise maps for this scale take their spatial size from Z_opt.
    nzx = Z_opt.shape[2]
    nzy = Z_opt.shape[3]
    m_image = nn.ZeroPad2d(int(pad_image))
    # Keep the previous scale's images and start a fresh list for this scale.
    images_prev = images_cur
    images_cur = []
    # Two starting noise states: z_prev1 is Z_opt mixed with 5% fresh noise,
    # z_prev2 is Z_opt itself.
    if count == 0:
        # z_rand comes from functions.generate_noise (one channel)
        z_rand = functions.generate_noise([1, nzx, nzy], device=device)
        z_rand = z_rand.expand(1, 3, Z_opt.shape[2], Z_opt.shape[3])
        z_prev1 = 0.95 * Z_opt + 0.05 * z_rand
        z_prev2 = Z_opt
    else:
        z_prev1 = 0.95 * Z_opt + 0.05 * functions.generate_noise(
            [nc_z, nzx, nzy], device=device)
        z_prev2 = Z_opt

    # 100 steps per scale.
    for i in range(0, 100, 1):
        if count == 0:
            z_rand = functions.generate_noise([1, nzx, nzy], device=device)
            # make z_rand same across channels
            z_rand = z_rand.expand(1, 3, Z_opt.shape[2], Z_opt.shape[3])
            # Blend the previous difference with fresh noise, weighted by beta.
            diff_curr = beta * (z_prev1 - z_prev2) + (1 - beta) * z_rand
        else:  # NOTE(review): the body of this branch is missing from this excerpt
コード例 #19
0
ファイル: training.py プロジェクト: tboen1/MESIGAN
def train_single_scale(netD,
                       netG,
                       reals,
                       Gs,
                       Zs,
                       in_s,
                       NoiseAmp,
                       opt,
                       centers=None):
    """Train one scale's generator/discriminator and report GPU memory use.

    Runs ``opt.niter`` iterations; each one performs ``opt.Dsteps``
    discriminator updates (real score, fake score, gradient penalty) and
    ``opt.Gsteps`` generator updates (adversarial loss plus, when
    ``opt.alpha != 0``, an MSE reconstruction loss weighted by alpha), then
    steps both learning-rate schedulers.

    Side effects: sets ``opt.nzx``, ``opt.nzy``, ``opt.receptive_field`` and
    ``opt.noise_amp``; appends to the module-level ``timestamp``,
    ``full_memory`` and ``full_time`` lists every 25 iterations; writes
    sample images and ``z_opt.pth`` into ``opt.outf`` every 500 iterations
    and on the last one; saves the networks via ``functions.save_networks``;
    queries NVML for the memory of GPU index 0.

    Args:
        netD: discriminator for this scale.
        netG: generator for this scale, called as ``netG(noise, prev)``.
        reals: per-scale real images; ``reals[len(Gs)]`` is the current one.
        Gs: generators of the already trained scales.
        Zs: reconstruction noise maps of the already trained scales.
        in_s: input handed to ``draw_concat``; replaced by zeros at the
            first scale of a non-``SR_train`` run.
        NoiseAmp: noise amplitudes of the already trained scales.
        opt: options namespace (read and mutated, see above).
        centers: passed to ``functions.quant2centers`` in ``paint_train``
            mode; unused otherwise.

    Returns:
        Tuple ``(z_opt, in_s, netG, mbs, percent)`` where ``mbs`` is the used
        GPU memory divided by ``1024**2`` and ``percent`` is used/total as a
        fraction.
    """
    real = reals[len(Gs)]
    opt.nzx = real.shape[2]
    opt.nzy = real.shape[3]
    opt.receptive_field = opt.ker_size + ((opt.ker_size - 1) *
                                          (opt.num_layer - 1)) * opt.stride
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    if opt.mode == 'animation_train':
        # Noise is generated at the enlarged size directly, so it gets no
        # extra zero padding.
        opt.nzx = real.shape[2] + (opt.ker_size - 1) * (opt.num_layer)
        opt.nzy = real.shape[3] + (opt.ker_size - 1) * (opt.num_layer)
        pad_noise = 0
    m_noise = nn.ZeroPad2d(int(pad_noise))
    m_image = nn.ZeroPad2d(int(pad_image))

    alpha = opt.alpha
    # Neither Gs nor opt.mode changes during training.
    first_scale = (not Gs) and (opt.mode != 'SR_train')
    mse = nn.MSELoss()

    # generate_noise is still called so the random stream is consumed as
    # before; only its shape is used.
    fixed_noise = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy],
                                           device=opt.device)
    # torch.zeros instead of torch.full(..., 0): with an integer fill value,
    # recent PyTorch infers an integer dtype, which is not wanted here.
    z_opt = torch.zeros(fixed_noise.shape, device=opt.device)
    z_opt = m_noise(z_opt)

    # setup optimizer
    optimizerD = optim.Adam(netD.parameters(),
                            lr=opt.lr_d,
                            betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(),
                            lr=opt.lr_g,
                            betas=(opt.beta1, 0.999))
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD,
                                                      milestones=[1600],
                                                      gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG,
                                                      milestones=[1600],
                                                      gamma=opt.gamma)

    # Per-iteration traces kept for plotting/debugging.
    errD2plot = []
    errG2plot = []
    D_real2plot = []
    D_fake2plot = []
    z_opt2plot = []

    for epoch in range(opt.niter):

        if first_scale:
            # One noise channel expanded to three channels.
            z_opt = functions.generate_noise([1, opt.nzx, opt.nzy],
                                             device=opt.device)
            z_opt = m_noise(z_opt.expand(1, 3, opt.nzx, opt.nzy))
            noise_ = functions.generate_noise([1, opt.nzx, opt.nzy],
                                              device=opt.device)
            noise_ = m_noise(noise_.expand(1, 3, opt.nzx, opt.nzy))
        else:
            noise_ = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy],
                                              device=opt.device)
            noise_ = m_noise(noise_)

        ############################
        # (1) Update D network: maximize D(x) + D(G(z))
        ###########################
        for j in range(opt.Dsteps):
            # train with real
            netD.zero_grad()

            output = netD(real).to(opt.device)
            errD_real = -output.mean()
            errD_real.backward(retain_graph=True)
            D_x = -errD_real.item()

            # train with fake
            if (j == 0) and (epoch == 0):
                # z_prev and opt.noise_amp are fixed once, on the very first
                # discriminator step.
                if first_scale:
                    prev = torch.zeros([1, opt.nc_z, opt.nzx, opt.nzy],
                                       device=opt.device)
                    in_s = prev
                    prev = m_image(prev)
                    z_prev = torch.zeros([1, opt.nc_z, opt.nzx, opt.nzy],
                                         device=opt.device)
                    z_prev = m_noise(z_prev)
                    opt.noise_amp = 1
                elif opt.mode == 'SR_train':
                    z_prev = in_s
                    RMSE = torch.sqrt(mse(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
                    prev = z_prev
                else:
                    prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand',
                                       m_noise, m_image, opt)
                    prev = m_image(prev)
                    z_prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rec',
                                         m_noise, m_image, opt)
                    # Noise amplitude follows the RMSE between the real image
                    # and the 'rec' output of the previous scales.
                    RMSE = torch.sqrt(mse(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
            else:
                prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand',
                                   m_noise, m_image, opt)
                prev = m_image(prev)

            if opt.mode == 'paint_train':
                prev = functions.quant2centers(prev, centers)
                plt.imsave('%s/prev.png' % (opt.outf),
                           functions.convert_image_np(prev),
                           vmin=0,
                           vmax=1)

            if first_scale:
                noise = noise_
            else:
                noise = opt.noise_amp * noise_ + prev

            fake = netG(noise.detach(), prev)
            # fake is detached so this backward pass only reaches netD.
            output = netD(fake.detach())
            errD_fake = output.mean()
            errD_fake.backward(retain_graph=True)
            D_G_z = output.mean().item()

            gradient_penalty = functions.calc_gradient_penalty(
                netD, real, fake, opt.lambda_grad, opt.device)
            gradient_penalty.backward()

            errD = errD_real + errD_fake + gradient_penalty
            optimizerD.step()

        errD2plot.append(errD.detach())

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################

        # NOTE(review): `fake` was produced before any optimizerG.step(); with
        # opt.Gsteps > 1 later iterations backpropagate through that same
        # graph - confirm this is accepted by the PyTorch version in use.
        for j in range(opt.Gsteps):
            netG.zero_grad()
            output = netD(fake)
            errG = -output.mean()
            errG.backward(retain_graph=True)
            if alpha != 0:
                if opt.mode == 'paint_train':
                    z_prev = functions.quant2centers(z_prev, centers)
                    plt.imsave('%s/z_prev.png' % (opt.outf),
                               functions.convert_image_np(z_prev),
                               vmin=0,
                               vmax=1)
                Z_opt = opt.noise_amp * z_opt + z_prev
                rec_loss = alpha * mse(netG(Z_opt.detach(), z_prev), real)
                rec_loss.backward(retain_graph=True)
                rec_loss = rec_loss.detach()
            else:
                Z_opt = z_opt
                rec_loss = 0

            optimizerG.step()

        errG2plot.append(errG.detach() + rec_loss)
        D_real2plot.append(D_x)
        D_fake2plot.append(D_G_z)
        z_opt2plot.append(rec_loss)

        if epoch % 25 == 0 or epoch == (opt.niter - 1):

            stamp = datetime.datetime.now()
            timestamp.append(stamp)

            # total_seconds() keeps the fractional part and does not wrap at
            # one day, unlike timedelta.seconds. With a single recorded stamp
            # there is no previous one to compare against.
            if len(timestamp) > 1:
                delta = (timestamp[-1] - timestamp[-2]).total_seconds()
            else:
                delta = 0.0
            mbs, percent = memory_check()
            print('scale %d:[%d/%d] | Mb: %.3f | Percent %.3f | secs %.3f ' %
                  (len(Gs), epoch, opt.niter, mbs, 100 * percent, delta))
            full_memory.append([mbs, percent])
            full_time.append(delta)

        if epoch % 500 == 0 or epoch == (opt.niter - 1):
            plt.imsave('%s/fake_sample.png' % (opt.outf),
                       functions.convert_image_np(fake.detach()),
                       vmin=0,
                       vmax=1)
            plt.imsave('%s/G(z_opt).png' % (opt.outf),
                       functions.convert_image_np(
                           netG(Z_opt.detach(), z_prev).detach()),
                       vmin=0,
                       vmax=1)

            torch.save(z_opt, '%s/z_opt.pth' % (opt.outf))

        # Schedulers step after the optimizers have stepped for this epoch.
        schedulerD.step()
        schedulerG.step()

    functions.save_networks(netG, netD, z_opt, opt)

    # Card index 0 is hard-coded; NVML can also enumerate all devices if
    # several cards ever need to be reported.
    nvidia_smi.nvmlInit()
    try:
        handle = nvidia_smi.nvmlDeviceGetHandleByIndex(0)
        mem_res = nvidia_smi.nvmlDeviceGetMemoryInfo(handle)
    finally:
        # Balance nvmlInit so the NVML session is not leaked.
        nvidia_smi.nvmlShutdown()

    mbs = mem_res.used / (1024**2)
    percent = mem_res.used / mem_res.total
    # bytes / 1024**2 is MiB (the old label said GiB).
    print(f'mem: {mbs} (MiB)')
    print(f'mem: {100 * percent:.6f}%')  # percentage
    return z_opt, in_s, netG, mbs, percent
コード例 #20
0
def train_single_scale(netD,
                       netG,
                       reals,
                       Gs,
                       Zs,
                       in_s,
                       NoiseAmp,
                       opt,
                       centers=None):
    """Train the generator/discriminator pair of one scale.

    Runs ``opt.niter`` iterations; each one performs ``opt.Dsteps``
    discriminator updates (real score, fake score, gradient penalty) and
    ``opt.Gsteps`` generator updates (adversarial loss plus, when
    ``opt.alpha != 0``, an MSE reconstruction loss weighted by alpha), then
    steps both learning-rate schedulers.

    Side effects: sets ``opt.nzx``, ``opt.nzy``, ``opt.receptive_field`` and
    ``opt.noise_amp``; prints progress every 25 iterations; writes sample
    images and ``z_opt.pth`` into ``opt.outf`` every 500 iterations and on
    the last one; saves the networks via ``functions.save_networks``.

    Args:
        netD: discriminator for this scale.
        netG: generator for this scale, called as ``netG(noise, prev)``.
        reals: per-scale real images; ``reals[len(Gs)]`` is the current one.
        Gs: generators of the already trained scales.
        Zs: reconstruction noise maps of the already trained scales.
        in_s: input handed to ``draw_concat``; replaced by zeros at the
            first scale of a non-``SR_train`` run.
        NoiseAmp: noise amplitudes of the already trained scales.
        opt: options namespace (read and mutated, see above).
        centers: passed to ``functions.quant2centers`` in ``paint_train``
            mode; unused otherwise.

    Returns:
        Tuple ``(z_opt, in_s, netG)``.
    """
    real = reals[len(Gs)]
    opt.nzx = real.shape[2]
    opt.nzy = real.shape[3]
    opt.receptive_field = opt.ker_size + ((opt.ker_size - 1) *
                                          (opt.num_layer - 1)) * opt.stride
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    if opt.mode == 'animation_train':
        # Noise is generated at the enlarged size directly, so it gets no
        # extra zero padding.
        opt.nzx = real.shape[2] + (opt.ker_size - 1) * (opt.num_layer)
        opt.nzy = real.shape[3] + (opt.ker_size - 1) * (opt.num_layer)
        pad_noise = 0
    m_noise = nn.ZeroPad2d(int(pad_noise))
    m_image = nn.ZeroPad2d(int(pad_image))

    alpha = opt.alpha
    # Neither Gs nor opt.mode changes during training.
    first_scale = (not Gs) and (opt.mode != 'SR_train')
    mse = nn.MSELoss()

    # generate_noise is still called so the random stream is consumed as
    # before; only its shape is used.
    fixed_noise = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy])
    # torch.zeros instead of torch.full(..., 0): with an integer fill value,
    # recent PyTorch infers an integer dtype, which is not wanted here.
    z_opt = torch.zeros(fixed_noise.shape, device=opt.device)
    z_opt = m_noise(z_opt)

    # setup optimizer
    optimizerD = optim.Adam(netD.parameters(),
                            lr=opt.lr_d,
                            betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(),
                            lr=opt.lr_g,
                            betas=(opt.beta1, 0.999))
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD,
                                                      milestones=[1600],
                                                      gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG,
                                                      milestones=[1600],
                                                      gamma=opt.gamma)

    # Per-iteration traces kept for plotting/debugging.
    errD2plot = []
    errG2plot = []
    D_real2plot = []
    D_fake2plot = []
    z_opt2plot = []

    for epoch in range(opt.niter):
        if first_scale:
            # One noise channel expanded to three channels.
            z_opt = functions.generate_noise([1, opt.nzx, opt.nzy])
            z_opt = m_noise(z_opt.expand(1, 3, opt.nzx, opt.nzy))
            noise_ = functions.generate_noise([1, opt.nzx, opt.nzy])
            noise_ = m_noise(noise_.expand(1, 3, opt.nzx, opt.nzy))
        else:
            noise_ = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy])
            noise_ = m_noise(noise_)

        ############################
        # (1) Update D network: maximize D(x) + D(G(z))
        ###########################
        for j in range(opt.Dsteps):
            # train with real
            netD.zero_grad()

            output = netD(real).to(opt.device)
            errD_real = -output.mean()
            errD_real.backward(retain_graph=True)
            D_x = -errD_real.item()

            # train with fake
            if (j == 0) and (epoch == 0):
                # z_prev and opt.noise_amp are fixed once, on the very first
                # discriminator step.
                if first_scale:
                    prev = torch.zeros([1, opt.nc_z, opt.nzx, opt.nzy],
                                       device=opt.device)
                    in_s = prev
                    prev = m_image(prev)
                    z_prev = torch.zeros([1, opt.nc_z, opt.nzx, opt.nzy],
                                         device=opt.device)
                    z_prev = m_noise(z_prev)
                    opt.noise_amp = 1
                elif opt.mode == 'SR_train':
                    z_prev = in_s
                    RMSE = torch.sqrt(mse(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
                    prev = z_prev
                else:
                    prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand',
                                       m_noise, m_image, opt)
                    prev = m_image(prev)
                    z_prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rec',
                                         m_noise, m_image, opt)
                    # Noise amplitude follows the RMSE between the real image
                    # and the 'rec' output of the previous scales.
                    RMSE = torch.sqrt(mse(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
            else:
                prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand',
                                   m_noise, m_image, opt)
                prev = m_image(prev)

            if opt.mode == 'paint_train':
                prev = functions.quant2centers(prev, centers)
                plt.imsave('%s/prev.png' % (opt.outf),
                           functions.convert_image_np(prev),
                           vmin=0,
                           vmax=1)

            if first_scale:
                noise = noise_
            else:
                noise = opt.noise_amp * noise_ + prev

            fake = netG(noise.detach(), prev)
            # fake is detached so this backward pass only reaches netD.
            output = netD(fake.detach())
            errD_fake = output.mean()
            errD_fake.backward(retain_graph=True)
            D_G_z = output.mean().item()

            gradient_penalty = functions.calc_gradient_penalty(
                netD, real, fake, opt.lambda_grad)
            gradient_penalty.backward()

            errD = errD_real + errD_fake + gradient_penalty
            optimizerD.step()

        errD2plot.append(errD.detach())

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################

        # NOTE(review): `fake` was produced before any optimizerG.step(); with
        # opt.Gsteps > 1 later iterations backpropagate through that same
        # graph - confirm this is accepted by the PyTorch version in use.
        for j in range(opt.Gsteps):
            netG.zero_grad()
            output = netD(fake)
            errG = -output.mean()
            errG.backward(retain_graph=True)
            if alpha != 0:
                if opt.mode == 'paint_train':
                    z_prev = functions.quant2centers(z_prev, centers)
                    plt.imsave('%s/z_prev.png' % (opt.outf),
                               functions.convert_image_np(z_prev),
                               vmin=0,
                               vmax=1)
                Z_opt = opt.noise_amp * z_opt + z_prev
                rec_loss = alpha * mse(netG(Z_opt.detach(), z_prev), real)
                rec_loss.backward(retain_graph=True)
                rec_loss = rec_loss.detach()
            else:
                Z_opt = z_opt
                rec_loss = 0

            optimizerG.step()

        errG2plot.append(errG.detach() + rec_loss)
        D_real2plot.append(D_x)
        D_fake2plot.append(D_G_z)
        z_opt2plot.append(rec_loss)

        if epoch % 25 == 0 or epoch == (opt.niter - 1):
            print('scale %d:[%d/%d]' % (len(Gs), epoch, opt.niter))

        if epoch % 500 == 0 or epoch == (opt.niter - 1):
            plt.imsave('%s/fake_sample.png' % (opt.outf),
                       functions.convert_image_np(fake.detach()),
                       vmin=0,
                       vmax=1)
            plt.imsave('%s/G(z_opt).png' % (opt.outf),
                       functions.convert_image_np(
                           netG(Z_opt.detach(), z_prev).detach()),
                       vmin=0,
                       vmax=1)

            torch.save(z_opt, '%s/z_opt.pth' % (opt.outf))

        # Step the schedulers only after the optimizers have stepped for this
        # epoch; stepping them first skips the first value of the schedule.
        schedulerD.step()
        schedulerG.step()

    functions.save_networks(netG, netD, z_opt, opt)
    return z_opt, in_s, netG
コード例 #21
0
ファイル: manipulate.py プロジェクト: codeconomics/SinGAN
def SinGAN_generate(Gs,
                    Zs,
                    reals,
                    NoiseAmp,
                    opt,
                    in_s=None,
                    scale_v=1,
                    scale_h=1,
                    n=0,
                    gen_start_scale=0,
                    num_samples=50,
                    output_image=False):
    """Generate ``num_samples`` images by running the generators scale by scale.

    For every scale, fresh noise is drawn per sample, added (scaled by that
    scale's noise amplitude) to the resized output of the previous scale and
    fed to the scale's generator. At the scale where ``n == len(reals) - 1``
    each sample is converted with ``functions.convert_image_np`` and
    collected; with ``output_image=True`` it is also written as
    ``<dir2save>/<i>.png``.

    Args:
        Gs: generators, one per scale.
        Zs: per-scale noise maps; only their spatial shape is used here.
        reals: per-scale real images; their shapes bound the crop of the
            upscaled previous output.
        NoiseAmp: per-scale noise amplitudes.
        opt: options namespace (``ker_size``, ``num_layer``, ``nc_z``,
            ``scale_factor``, ``device`` are read here).
        in_s: input for the first processed scale; zeros shaped like
            ``reals[0]`` when ``None``.
        scale_v: vertical scale factor applied to the noise size.
        scale_h: horizontal scale factor applied to the noise size.
        n: index of the first scale; incremented after each scale.
        gen_start_scale: accepted for interface compatibility; not used in
            this implementation.
        num_samples: number of samples generated per scale.
        output_image: when true, save the final-scale samples as PNG files.

    Returns:
        Tuple ``(I_curr.detach(), output_list)``: the last generated tensor
        and the list of converted images collected during the last processed
        scale (empty unless that scale satisfied ``n == len(reals) - 1``).

    Raises:
        ValueError: if nothing was generated (no scales or
            ``num_samples <= 0``).
    """
    if in_s is None:
        # torch.zeros instead of torch.full(..., 0): with an integer fill
        # value, recent PyTorch infers an integer dtype.
        in_s = torch.zeros(reals[0].shape, device=opt.device)

    images_cur = []
    output_list = []
    I_curr = None

    for G, Z_opt, noise_amp in zip(Gs, Zs, NoiseAmp):
        # Padding width ((ker_size-1)*num_layer)/2, applied on every side.
        pad1 = ((opt.ker_size - 1) * opt.num_layer) / 2
        m = nn.ZeroPad2d(int(pad1))

        # Noise size: Z_opt's spatial size without its padding, rescaled.
        nzx = (Z_opt.shape[2] - pad1 * 2) * scale_v
        nzy = (Z_opt.shape[3] - pad1 * 2) * scale_h

        # Outputs of the previous scale feed this one.
        images_prev = images_cur
        images_cur = []
        output_list = []

        is_last_scale = n == len(reals) - 1
        dir2save = None  # resolved lazily, once per scale

        for i in range(0, num_samples, 1):
            if n == 0:
                # One noise channel expanded to three channels, then padded.
                z_curr = functions.generate_noise([1, nzx, nzy],
                                                  device=opt.device)
                z_curr = z_curr.expand(1, 3, z_curr.shape[2], z_curr.shape[3])
                z_curr = m(z_curr)
            else:
                z_curr = functions.generate_noise([opt.nc_z, nzx, nzy],
                                                  device=opt.device)
                z_curr = m(z_curr)

            if not images_prev:
                # No previous scale yet: start from the padded in_s.
                I_prev = m(in_s)
            else:
                # Upscale the previous output, crop it to the (scaled) size
                # of the current real image, pad it, then match the noise
                # tensor's spatial size.
                I_prev = images_prev[i]
                I_prev = imresize(I_prev, 1 / opt.scale_factor, opt)
                I_prev = I_prev[:, :, 0:round(scale_v * reals[n].shape[2]),
                                0:round(scale_h * reals[n].shape[3])]
                I_prev = m(I_prev)
                I_prev = I_prev[:, :, 0:z_curr.shape[2], 0:z_curr.shape[3]]
                I_prev = functions.upsampling(I_prev, z_curr.shape[2],
                                              z_curr.shape[3])

            # Generator input: amplified noise on top of the previous image.
            z_in = noise_amp * (z_curr) + I_prev
            I_curr = G(z_in.detach(), I_prev)

            if is_last_scale:
                if dir2save is None:
                    dir2save = functions.generate_dir2save(opt)
                    # exist_ok tolerates an existing directory without hiding
                    # other OS errors (e.g. missing permissions).
                    os.makedirs(dir2save, exist_ok=True)
                image_np = functions.convert_image_np(I_curr.detach())
                if output_image:
                    plt.imsave(f'{dir2save}/{i}.png',
                               image_np,
                               vmin=0,
                               vmax=1)
                output_list.append(image_np)
            images_cur.append(I_curr)
        n += 1

    if I_curr is None:
        raise ValueError(
            'SinGAN_generate produced no image: Gs/Zs/NoiseAmp are empty or '
            'num_samples <= 0')
    return I_curr.detach(), output_list
コード例 #22
0
def _load_noisy_reference(image_path):
    """Read an image file and return it as a CUDA float tensor in [-1, 1].

    The array from ``img.imread`` is rearranged to batch-first,
    channel-second layout, divided by 255, mapped from [0, 1] to [-1, 1]
    and cut to at most its first three channels.
    """
    real = img.imread(image_path)
    real = real[:,:,:,None]
    real = real.transpose((3,2,0,1))/255
    real = torch.from_numpy(real)
    real = move_to_gpu(real)
    real = real.type(torch.cuda.FloatTensor)
    real = ((real - 0.5)*2).clamp(-1,1)
    real = real[:,0:3,:,:]
    return real


def SinGAN_generate(Gs,Zs,reals,NoiseAmp,opt,in_s=None,scale_v=1,scale_h=1,n=0,gen_start_scale=0,num_samples=50,
                    noisy_image_path=r"D:\MVA\CompVision\Project\SinGAN-master\Input\Images/Noisy_Golden_Bridge_by_night.jpg"):
    """Run the generators scale by scale with the noise term switched off.

    Experimental variant: the generator input is ``0*z_curr + I_prev``, so
    the drawn noise contributes nothing, and for every scale below
    ``gen_start_scale`` the previous image is replaced by the image at
    ``noisy_image_path`` (padded and upsampled to the noise tensor's size).
    Shape information is printed for every sample.

    At the scale where ``n == len(reals) - 1`` the output directory is
    created and, unless ``opt.mode`` is one of ``harmonization``,
    ``editing``, ``SR`` or ``paint2image``, each sample is saved as
    ``<dir2save>/<i>.png``.

    Args:
        Gs: generators, one per scale.
        Zs: per-scale noise maps.
        reals: per-scale real images.
        NoiseAmp: per-scale noise amplitudes (iterated but not applied, see
            above).
        opt: options namespace.
        in_s: input for the first processed scale; zeros shaped like
            ``reals[0]`` when ``None``.
        scale_v: vertical scale factor applied to the noise size.
        scale_h: horizontal scale factor applied to the noise size.
        n: index of the first scale; incremented after each scale.
        gen_start_scale: scales with ``n`` below this value use ``Z_opt`` as
            noise and the reference image as previous image.
        num_samples: number of samples generated per scale.
        noisy_image_path: image injected below ``gen_start_scale``; defaults
            to the path that used to be hard-coded in the body.

    Returns:
        The last generated tensor, detached.

    Raises:
        ValueError: if nothing was generated (no scales or
            ``num_samples <= 0``).
    """
    if in_s is None:
        # torch.zeros instead of torch.full(..., 0): with an integer fill
        # value, recent PyTorch infers an integer dtype.
        in_s = torch.zeros(reals[0].shape, device=opt.device)

    images_cur = []
    I_curr = None
    noisy_ref = None  # reference image, read from disk at most once

    for G,Z_opt,noise_amp in zip(Gs,Zs,NoiseAmp):
        pad1 = ((opt.ker_size-1)*opt.num_layer)/2
        m = nn.ZeroPad2d(int(pad1))
        nzx = (Z_opt.shape[2]-pad1*2)*scale_v
        nzy = (Z_opt.shape[3]-pad1*2)*scale_h

        images_prev = images_cur
        images_cur = []

        for i in range(0,num_samples,1):
            if n == 0:
                # One noise channel expanded to three channels, then padded.
                z_curr = functions.generate_noise([1,nzx,nzy], device=opt.device)
                z_curr = z_curr.expand(1,3,z_curr.shape[2],z_curr.shape[3])
                z_curr = m(z_curr)
            else:
                z_curr = functions.generate_noise([opt.nc_z,nzx,nzy], device=opt.device)
                z_curr = m(z_curr)

            if not images_prev:
                print("in_s shape before padding with m", in_s.shape)
                I_prev = m(in_s)
                # Report the padded tensor (this used to print in_s.shape again).
                print("in_s shape after padding with m now I_prev", I_prev.shape)
                I_prev = functions.upsampling(I_prev,z_curr.shape[2],z_curr.shape[3])
                print("I_prev shape after upsampling using noise shape", I_prev.shape)
            else:
                I_prev = images_prev[i]
                I_prev = imresize(I_prev,1/opt.scale_factor, opt)
                if opt.mode != "SR":
                    I_prev = I_prev[:, :, 0:round(scale_v * reals[n].shape[2]), 0:round(scale_h * reals[n].shape[3])]
                    I_prev = m(I_prev)
                    I_prev = I_prev[:,:,0:z_curr.shape[2],0:z_curr.shape[3]]
                    I_prev = functions.upsampling(I_prev,z_curr.shape[2],z_curr.shape[3])
                else:
                    I_prev = m(I_prev)

            if n < gen_start_scale:
                z_curr = Z_opt
                # Inject the reference image in place of the previous output.
                if noisy_ref is None:
                    noisy_ref = _load_noisy_reference(noisy_image_path)
                real = m(noisy_ref)
                I_prev = functions.upsampling(real,z_curr.shape[2],z_curr.shape[3])

            print("I_prev",I_prev.shape)
            print("z_curr",z_curr.shape)
            print('---')
            # The regular input would be noise_amp*z_curr + I_prev; this
            # variant deliberately zeroes the noise term.
            z_in = 0*(z_curr)+I_prev
            I_curr = G(z_in.detach(),I_prev)

            if n == len(reals)-1:
                if opt.mode == 'train':
                    dir2save = '%s/RandomSamples/%s/gen_start_scale=%d' % (opt.out, opt.input_name[:-4], gen_start_scale)
                else:
                    dir2save = functions.generate_dir2save(opt)
                # exist_ok tolerates an existing directory without hiding
                # other OS errors (e.g. missing permissions).
                os.makedirs(dir2save, exist_ok=True)
                if opt.mode not in ("harmonization", "editing", "SR", "paint2image"):
                    plt.imsave('%s/%d.png' % (dir2save, i), functions.convert_image_np(I_curr.detach()), vmin=0,vmax=1)
            images_cur.append(I_curr)
        n+=1

    if I_curr is None:
        raise ValueError(
            'SinGAN_generate produced no image: Gs/Zs/NoiseAmp are empty or '
            'num_samples <= 0')
    return I_curr.detach()
コード例 #23
0
def train_single_scale(netD,
                       netG,
                       reals,
                       Gs,
                       Zs,
                       in_s,
                       NoiseAmp,
                       opt,
                       centers=None):
    """Train the generator/discriminator pair of a single pyramid scale.

    The scale being trained is ``len(Gs)``: ``reals[len(Gs)]`` is the target
    image and the output of the already trained scales is obtained through
    ``draw_concat``.

    Args:
        netD: discriminator (critic) for this scale.
        netG: generator for this scale, called as ``netG(noise, prev)``.
        reals: real images, indexed by scale.
        Gs: generators of the scales trained so far (``[]`` at the first scale).
        Zs: forwarded to ``draw_concat`` together with ``Gs``.
        in_s: input forwarded to ``draw_concat``; replaced by a zero tensor
            at the first scale (when the mode is not ``'SR_train'``).
        NoiseAmp: forwarded to ``draw_concat``.
        opt: options object.  ``nzx``, ``nzy``, ``receptive_field`` and
            ``noise_amp`` are written on it by this function.
        centers: forwarded to ``functions.quant2centers``; only used when
            ``opt.mode == 'paint_train'``.

    Returns:
        The tuple ``(z_opt, in_s, netG)``.
    """
    real = reals[len(Gs)]  # real image at the current scale
    opt.nzx = real.shape[2]
    opt.nzy = real.shape[3]
    opt.receptive_field = opt.ker_size + ((opt.ker_size - 1) *
                                          (opt.num_layer - 1)) * opt.stride
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    if opt.mode == 'animation_train':
        # Supplementary says they generate noise on the border in animation
        # mode: the noise map is enlarged instead of being zero padded.
        opt.nzx = real.shape[2] + (opt.ker_size - 1) * (opt.num_layer)
        opt.nzy = real.shape[3] + (opt.ker_size - 1) * (opt.num_layer)
        pad_noise = 0
    m_noise = nn.ZeroPad2d(int(pad_noise))
    m_image = nn.ZeroPad2d(int(pad_image))

    alpha = opt.alpha

    # The first scale (outside of SR training) has no previous output to
    # build on; neither Gs nor opt.mode is modified in this function.
    is_first_scale = (Gs == []) and (opt.mode != 'SR_train')

    # Only the shape of this sample is used.  The device is passed explicitly,
    # like in every other generate_noise call of this function, instead of
    # relying on the helper's default.
    fixed_noise = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy],
                                           device=opt.device)
    z_opt = torch.full(fixed_noise.shape, 0, device=opt.device)
    z_opt = m_noise(z_opt)  # zero padding an all-zero tensor: still zeros

    # setup optimizers and learning-rate schedules
    optimizerD = optim.Adam(netD.parameters(),
                            lr=opt.lr_d,
                            betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(),
                            lr=opt.lr_g,
                            betas=(opt.beta1, 0.999))
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD,
                                                      milestones=[1600],
                                                      gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG,
                                                      milestones=[1600],
                                                      gamma=opt.gamma)

    errD2plot = []  # total discriminator loss per epoch
    errG2plot = []  # adversarial + reconstruction generator loss per epoch
    D_real2plot = []  # mean discriminator output on the real image
    D_fake2plot = []  # mean discriminator output on the fake image
    z_opt2plot = []  # reconstruction loss per epoch

    for epoch in range(opt.niter):
        if is_first_scale:
            # Single-channel noise expanded to 3 channels, so all channels
            # share the same values.  z_opt is re-drawn at every epoch.
            z_opt = functions.generate_noise([1, opt.nzx, opt.nzy],
                                             device=opt.device)
            z_opt = m_noise(z_opt.expand(1, 3, opt.nzx, opt.nzy))
            noise_ = functions.generate_noise([1, opt.nzx, opt.nzy],
                                              device=opt.device)
            noise_ = m_noise(noise_.expand(1, 3, opt.nzx, opt.nzy))
        else:
            # One noise map per epoch, reused by every D step of the epoch.
            noise_ = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy],
                                              device=opt.device)
            noise_ = m_noise(noise_)

        ############################
        # (1) Update D network: minimize -D(x) + D(G(z)) + gradient penalty
        ###########################
        for j in range(opt.Dsteps):
            # train with real
            netD.zero_grad()

            output = netD(real).to(opt.device)
            errD_real = -output.mean()  # raise the D output on the real image
            errD_real.backward(retain_graph=True)
            D_x = -errD_real.item()  # mean D output on the real image

            # train with fake, generated on top of the previous scales
            if j == 0 and epoch == 0:
                # Very first D step of this scale: also set up z_prev and the
                # noise amplitude, which are reused for the rest of training.
                if is_first_scale:
                    prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy],
                                      0,
                                      device=opt.device)
                    in_s = prev  # in_s keeps the unpadded tensor
                    prev = m_image(prev)
                    z_prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy],
                                        0,
                                        device=opt.device)
                    z_prev = m_noise(z_prev)
                    opt.noise_amp = 1
                elif opt.mode == 'SR_train':
                    z_prev = in_s
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
                    prev = z_prev
                else:
                    prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand',
                                       m_noise, m_image, opt)
                    prev = m_image(prev)
                    z_prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rec',
                                         m_noise, m_image, opt)
                    criterion = nn.MSELoss()
                    # The noise amplitude is scaled by the RMSE between the
                    # real image and the 'rec' output of the previous scales.
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
            else:
                prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand',
                                   m_noise, m_image, opt)
                prev = m_image(prev)

            if opt.mode == 'paint_train':
                prev = functions.quant2centers(prev, centers)
                plt.imsave('%s/prev.png' % (opt.outf),
                           functions.convert_image_np(prev),
                           vmin=0,
                           vmax=1)

            if is_first_scale:
                noise = noise_  # the generator input is the noise alone
            else:
                noise = opt.noise_amp * noise_ + prev
            # generate a single fake image through G and score it with D
            fake = netG(noise.detach(), prev)
            # fake is detached so that this loss does not back-propagate to G
            output = netD(fake.detach())
            errD_fake = output.mean()  # lower the D output on the fake image
            errD_fake.backward(retain_graph=True)
            D_G_z = output.mean().item()

            gradient_penalty = functions.calc_gradient_penalty(
                netD, real, fake, opt.lambda_grad, opt.device)
            gradient_penalty.backward()

            errD = errD_real + errD_fake + gradient_penalty
            optimizerD.step()

        errD2plot.append(errD.detach().item())

        ############################
        # (2) Update G network: maximize D(G(z)) (+ reconstruction loss)
        ###########################

        for j in range(opt.Gsteps):
            netG.zero_grad()
            # the fake image of the last D step is reused for every G step
            output = netD(fake)
            errG = -output.mean()  # adversarial part of the generator loss
            errG.backward(retain_graph=True)
            if alpha != 0:  # the reconstruction loss is only used if alpha is non-zero
                loss = nn.MSELoss()
                if opt.mode == 'paint_train':
                    # z_prev is inherited from the D training part above
                    z_prev = functions.quant2centers(z_prev, centers)
                    plt.imsave('%s/z_prev.png' % (opt.outf),
                               functions.convert_image_np(z_prev),
                               vmin=0,
                               vmax=1)
                Z_opt = opt.noise_amp * z_opt + z_prev
                rec_loss = alpha * loss(netG(Z_opt.detach(), z_prev), real)
                rec_loss.backward(retain_graph=True)
                rec_loss = rec_loss.detach()
            else:
                Z_opt = z_opt
                rec_loss = 0.0

            optimizerG.step()

        # float() handles both the detached tensor and the plain 0.0 used
        # when alpha == 0 (a plain number has no .item()).
        rec_loss_value = float(rec_loss)
        errG2plot.append(errG.detach().item() + rec_loss_value)
        D_real2plot.append(D_x)
        D_fake2plot.append(D_G_z)
        z_opt2plot.append(rec_loss_value)

        if epoch % 25 == 0 or epoch == (opt.niter - 1):
            print('scale %d:[%d/%d]' % (len(Gs), epoch, opt.niter))

        if epoch % 500 == 0 or epoch == (opt.niter - 1):
            plt.imsave('%s/fake_sample.png' % (opt.outf),
                       functions.convert_image_np(fake.detach()),
                       vmin=0,
                       vmax=1)
            plt.imsave('%s/G(z_opt).png' % (opt.outf),
                       functions.convert_image_np(
                           netG(Z_opt.detach(), z_prev).detach()),
                       vmin=0,
                       vmax=1)
            # the tensor that was fed to the generator in the last D step
            plt.imsave('%s/noise.png' % (opt.outf),
                       functions.convert_image_np(noise),
                       vmin=0,
                       vmax=1)
            plt.imsave('%s/Z_opt.png' % (opt.outf),
                       functions.convert_image_np(Z_opt),
                       vmin=0,
                       vmax=1)
            plt.imsave('%s/z_prev.png' % (opt.outf),
                       functions.convert_image_np(z_prev),
                       vmin=0,
                       vmax=1)

            torch.save(
                {
                    "errD": errD2plot,
                    "errG": errG2plot,
                    "D_real": D_real2plot,
                    "D_fake": D_fake2plot,
                    "recons": z_opt2plot
                }, '%s/loss_trace.pth' % (opt.outf))
            plt.figure()
            plt.plot(errD2plot, label="errD")
            plt.plot(errG2plot, label="errG")
            plt.plot(D_real2plot, label="Dreal")
            plt.plot(D_fake2plot, label="Dfake")
            plt.plot(z_opt2plot, label="recons")
            plt.legend()
            plt.savefig('%s/loss_trace.png' % (opt.outf))
            plt.close()
            torch.save(z_opt, '%s/z_opt.pth' % (opt.outf))

        schedulerD.step()
        schedulerG.step()

    functions.save_networks(netG, netD, z_opt, opt)
    return z_opt, in_s, netG
コード例 #24
0
def SinGAN_generate(Gs, Zs, reals, NoiseAmp, opt, modification=None, in_s=None, scale_v=1, scale_h=1, n=0,
                    gen_start_scale=0, num_samples=10):
    """Generate samples scale by scale, optionally modifying the generator input.

    For every ``(G, Z_opt, noise_amp)`` triple, ``num_samples`` images are
    generated on top of the images of the previous triple.  From scale
    ``gen_start_scale`` on, and only when ``modification`` is not ``None``,
    the generator input is passed through ``modify_input_to_generator``.
    To apply the modification at a single scale only, the ``>=`` test below
    would have to become ``==``.  Blending options and opacity are chosen
    inside ``modify_input_to_generator``.

    Saved files: the generator input before/after modification (only at
    ``n == gen_start_scale``) and the generated image of the last scale
    (unless the mode is harmonization, editing, SR or paint2image).

    Returns:
        The last generated image, detached.
    """

    def ensure_sample_dir():
        # Format the output folder and try to create it; any OSError
        # (for instance "already exists") is ignored.
        path = '%s/RandomSamples/%s/%s/gen_start_scale=%d' % (
            opt.out, opt.input_name[:-4], modification, gen_start_scale)
        try:
            os.makedirs(path)
        except OSError:
            pass
        return path

    # modes for which the final image is not written by this function
    modes_without_final_save = ("harmonization", "editing", "SR", "paint2image")

    if in_s is None:
        in_s = torch.full(reals[0].shape, 0, device=opt.device)

    current_images = []
    for generator, z_rec, amp in zip(Gs, Zs, NoiseAmp):
        half_pad = ((opt.ker_size - 1) * opt.num_layer) / 2
        pad = nn.ZeroPad2d(int(half_pad))
        # unpadded noise size, scaled by the requested aspect factors
        noise_rows = (z_rec.shape[2] - half_pad * 2) * scale_v
        noise_cols = (z_rec.shape[3] - half_pad * 2) * scale_h

        previous_images, current_images = current_images, []

        for i in range(num_samples):
            # --- noise for this sample ---
            if n == 0:
                z_curr = functions.generate_noise([1, noise_rows, noise_cols], device=opt.device)
                z_curr = z_curr.expand(1, 3, z_curr.shape[2], z_curr.shape[3])
            else:
                z_curr = functions.generate_noise([opt.nc_z, noise_rows, noise_cols], device=opt.device)
            z_curr = pad(z_curr)

            # --- image coming from the previous scale ---
            if not previous_images:
                prev_image = pad(in_s)
            else:
                prev_image = imresize(previous_images[i], 1 / opt.scale_factor, opt)
                if opt.mode == "SR":
                    prev_image = pad(prev_image)
                else:
                    prev_image = prev_image[:, :,
                                            0:round(scale_v * reals[n].shape[2]),
                                            0:round(scale_h * reals[n].shape[3])]
                    prev_image = pad(prev_image)
                    prev_image = prev_image[:, :, 0:z_curr.shape[2], 0:z_curr.shape[3]]
                    prev_image = functions.upsampling(prev_image, z_curr.shape[2], z_curr.shape[3])

            # below the start scale the stored Z_opt replaces the random noise
            if n < gen_start_scale:
                z_curr = z_rec

            gen_input = amp * (z_curr) + prev_image

            out_dir = ensure_sample_dir()
            at_start_scale = n == gen_start_scale

            if at_start_scale:
                plt.imsave('%s/%d_before_modification.png' % (out_dir, i),
                           functions.convert_image_np(gen_input.detach()), vmin=0, vmax=1)

            # --- optional modification of the generator input ---
            if n >= gen_start_scale and modification is not None:
                shape_before = gen_input.shape
                content = preprocess_content_image(opt, reals, n)
                gen_input = modify_input_to_generator(gen_input, content, modification, opacity=1)
                assert shape_before == gen_input.shape
                if at_start_scale:
                    plt.imsave('%s/%d_after_modification.png' % (out_dir, i),
                               functions.convert_image_np(gen_input.detach()), vmin=0, vmax=1)

            sample = generator(gen_input.detach(), prev_image)

            # --- last scale: write the generated image ---
            if n == len(reals) - 1:
                out_dir = ensure_sample_dir()
                if opt.mode not in modes_without_final_save:
                    plt.imsave('%s/%d.png' % (out_dir, i),
                               functions.convert_image_np(sample.detach()), vmin=0, vmax=1)

            current_images.append(sample)
        n += 1
    return sample.detach()
コード例 #25
0
def train_single_scale(netD, netD_mask1, netD_mask2,netG,reals1, reals2, Gs,Zs,in_s1, in_s2,NoiseAmp,opt):
    """Train one pyramid scale of a generator shared by two input images.

    The scale being trained is ``len(Gs)``.  ``netD`` is trained on both real
    images and on the fakes generated from the two noise sources
    (``NoiseMode.Z1`` / ``NoiseMode.Z2``).  When ``opt.enable_mask`` is set,
    ``netD_mask1`` / ``netD_mask2`` are additionally trained on the masked
    outputs of the generator.

    Args:
        netD: discriminator applied to both images.
        netD_mask1, netD_mask2: discriminators for the masked outputs
            (only trained when ``opt.enable_mask`` is set, but always
            stepped and saved).
        netG: generator of this scale.
        reals1, reals2: real images of the two inputs, indexed by scale.
            The two images are assumed to have the same size.
        Gs, Zs, NoiseAmp: state of the scales trained so far, forwarded to
            ``_prepare_discriminator_train_with_fake_input``.
        in_s1, in_s2: forwarded to (and updated by)
            ``_prepare_discriminator_train_with_fake_input``.
        opt: options object; ``nzx``, ``nzy`` and ``receptive_field`` are
            written on it here.

    Returns:
        ``((z_opt1, z_opt2), (in_s1, in_s2), netG)``.
    """
    real1 = reals1[len(Gs)]
    real2 = reals2[len(Gs)]

    if opt.replace_background:
        # Put image 2 over a background created from image 1 and keep
        # snapshots of both intermediate results.
        background_real1 = create_background(functions.convert_image_np(real1))
        real2 = create_img_over_background(functions.convert_image_np(real2), background_real1)

        plt.imsave('%s/background_real_scale1.png' % (opt.outf), background_real1, vmin=0, vmax=1)
        plt.imsave('%s/real_scale2_new.png' % (opt.outf), real2, vmin=0, vmax=1)

        real2 = functions.np2torch(real2, opt)

    # assumption: the images are the same size
    opt.nzx = real1.shape[2]
    opt.nzy = real1.shape[3]
    opt.receptive_field = opt.ker_size + ((opt.ker_size-1)*(opt.num_layer-1))*opt.stride
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)

    m_noise = nn.ZeroPad2d(int(pad_noise))
    m_image = nn.ZeroPad2d(int(pad_image))

    alpha = opt.alpha

    # Only the shape of fixed_noise is used: z_opt starts as padded zeros.
    fixed_noise = functions.generate_noise([opt.nc_z,opt.nzx,opt.nzy],device=opt.device)
    z_opt = torch.full(fixed_noise.shape, 0, device=opt.device)
    z_opt = m_noise(z_opt)

    # setup optimizers
    optimizerD = optim.Adam(netD.parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999), weight_decay=opt.weight_decay_d)
    optimizerD_masked1 = optim.Adam(netD_mask1.parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999), weight_decay=opt.weight_decay_d_mask1)
    optimizerD_masked2 = optim.Adam(netD_mask2.parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999), weight_decay=opt.weight_decay_d_mask2)
    optimizerG = optim.Adam(netG.parameters(), lr=opt.lr_g, betas=(opt.beta1, 0.999))

    # setup learning-rate schedulers (same kind for every optimizer)
    if opt.cyclic_lr:
        def make_scheduler(optimizer, peak_lr):
            # Cycle between peak_lr * gamma and peak_lr.
            return torch.optim.lr_scheduler.CyclicLR(optimizer=optimizer, base_lr=peak_lr*opt.gamma, max_lr=peak_lr,
                                                     step_size_up=opt.niter/10, mode="triangular2",
                                                     cycle_momentum=False)
    else:
        def make_scheduler(optimizer, peak_lr):
            # peak_lr is unused here: the decay starts from the optimizer's own lr.
            return torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=[1600], gamma=opt.gamma)

    schedulerD = make_scheduler(optimizerD, opt.lr_d)
    schedulerD_masked1 = make_scheduler(optimizerD_masked1, opt.lr_d)
    schedulerD_masked2 = make_scheduler(optimizerD_masked2, opt.lr_d)
    # The generator cycles around its own learning rate (opt.lr_g), matching
    # the rate optimizerG is created with, rather than the discriminator's.
    schedulerG = make_scheduler(optimizerG, opt.lr_g)

    discriminators = [netD, netD_mask1, netD_mask2]
    discriminators_optimizers = [optimizerD, optimizerD_masked1, optimizerD_masked2]
    discriminators_schedulers = [schedulerD, schedulerD_masked1, schedulerD_masked2]

    # per-epoch traces
    err_D_img1_2plot = []
    err_D_img2_2plot = []
    errG_total_loss_2plot = []
    errG_total_loss1_2plot = []
    errG_total_loss2_2plot = []
    errG_fake1_2plot = []
    errG_fake2_2plot = []
    D1_real2plot = []
    D2_real2plot = []
    D1_fake2plot = []
    D2_fake2plot = []
    reconstruction_loss1_2plot = []
    reconstruction_loss2_2plot = []

    for epoch in range(opt.niter):
        # We want to ensure that there exists a specific set of input noise maps, which generates the original image x.
        # We specifically choose {z*, 0, 0, ..., 0}, where z* is some fixed noise map.
        # In the first scale, we create that z* (aka z_opt). On other scales, z_opt is just zeros (initialized above)
        is_first_scale = len(Gs) == 0

        noise1_, z_opt1 = _create_noise_for_iteration(is_first_scale, m_noise, opt, z_opt, NoiseMode.Z1)
        noise2_, z_opt2 = _create_noise_for_iteration(is_first_scale, m_noise, opt, z_opt, NoiseMode.Z2)

        ############################
        # (1) Update D networks:
        # - netD: train with real on 2 input images (real1, real2)
        # - netD: train with fake on 2 fake images from different noise source (NoiseMode.Z1, NoiseMode.Z2)
        # if opt.enable_mask is ON, then:
        # - netD_mask1: train with real on real1
        # - netD_mask2: train with real on real2
        # - netD_mask1: train with fake on the generated fake image with mask1 applied on it
        # - netD_mask2: train with fake on the generated fake image with mask2 applied on it
        ###########################
        for j in range(opt.Dsteps):
            # train with real
            for discriminator in discriminators:
                discriminator.zero_grad()

            errD_real1, D_x1 = discriminator_train_with_real(netD, opt, real1)
            errD_real2, D_x2 = discriminator_train_with_real(netD, opt, real2)

            # single discriminator for each image; the returned values are
            # not used here (presumably the helper back-propagates - confirm)
            if opt.enable_mask:
                errD_mask1_real1, _ = discriminator_train_with_real(netD_mask1, opt, real1)
                errD_mask2_real2, _ = discriminator_train_with_real(netD_mask2, opt, real2)

            # train with fake
            in_s1, noise1, prev1, new_z_prev1 = _prepare_discriminator_train_with_fake_input(Gs, NoiseAmp, Zs, epoch,
                                                                                             in_s1, is_first_scale, j,
                                                                                             m_image, m_noise, noise1_,
                                                                                             opt, real1, reals1,
                                                                                             NoiseMode.Z1)
            in_s2, noise2, prev2, new_z_prev2 = _prepare_discriminator_train_with_fake_input(Gs, NoiseAmp, Zs, epoch,
                                                                                             in_s2, is_first_scale, j,
                                                                                             m_image, m_noise, noise2_,
                                                                                             opt, real2, reals2,
                                                                                             NoiseMode.Z2)
            # z_prev is only replaced when the helper returns a new one
            if new_z_prev1 is not None:
                z_prev1 = new_z_prev1
            if new_z_prev2 is not None:
                z_prev2 = new_z_prev2

            # Z1 only: the second noise slot is zeros
            mixed_noise1 = functions.merge_noise_vectors(noise1, torch.zeros(noise1.shape, device=opt.device),
                                               opt.noise_vectors_merge_method)

            fake1_output = _generate_fake(netG, mixed_noise1, prev1)

            if opt.enable_mask:
                fake1, fake1_mask1, fake1_mask2 = fake1_output
            else:
                fake1 = fake1_output[0]

            D_G_z_1, errD_fake1, gradient_penalty1 = _train_discriminator_with_fake(netD, fake1, opt, real1)

            # Z2 only: the first noise slot is zeros
            mixed_noise2 = functions.merge_noise_vectors(torch.zeros(noise2.shape, device=opt.device), noise2,
                                               opt.noise_vectors_merge_method)

            fake2_output = _generate_fake(netG, mixed_noise2, prev2)
            if opt.enable_mask:
                fake2, fake2_mask1, fake2_mask2 = fake2_output
            else:
                fake2 = fake2_output[0]

            D_G_z_2, errD_fake2, gradient_penalty2 = _train_discriminator_with_fake(netD, fake2, opt, real2)

            if opt.enable_mask:
                _, errD_mask1_fake1, _ = _train_discriminator_with_fake(netD_mask1, fake1_mask1, opt, real1)
                _, errD_mask2_fake2, _ = _train_discriminator_with_fake(netD_mask2, fake2_mask2, opt, real2)

            errD_image1 = errD_real1 + errD_fake1 + gradient_penalty1
            errD_image2 = errD_real2 + errD_fake2 + gradient_penalty2

            for discriminator_optimizer in discriminators_optimizers:
                discriminator_optimizer.step()

        err_D_img1_2plot.append(errD_image1.detach())
        err_D_img2_2plot.append(errD_image2.detach())

        ############################
        # (2) Update G network:
        # - netG: train with fake on 2 fake images from different noise source (NoiseMode.Z1, NoiseMode.Z2) against netD
        # - netG: reconstruction loss against 2 real images
        # if opt.enable_mask is ON, then:
        # - netD_mask1: train with fake on the generated fake image with mask1 applied on it against netD_mask1
        # - netD_mask2: train with fake on the generated fake image with mask2 applied on it against netD_mask2
        ###########################

        for j in range(opt.Gsteps):
            netG.zero_grad()
            errG_fake1, D_fake1_map = _generator_train_with_fake(fake1, netD)
            errG_fake2, D_fake2_map = _generator_train_with_fake(fake2, netD)
            rec_loss1, Z_opt1 = _reconstruction_loss(alpha, netG, opt, z_opt1, z_prev1, real1, NoiseMode.Z1, opt.noise_amp1)
            rec_loss2, Z_opt2 = _reconstruction_loss(alpha, netG, opt, z_opt2, z_prev2, real2, NoiseMode.Z2, opt.noise_amp2)

            if opt.enable_mask:
                mask_loss_fake1_mask1, D_mask1_fake1_mask1_map = _generator_train_with_fake(fake1_mask1, netD_mask1)
                mask_loss_fake2_mask1, D_mask1_fake2_mask1_map = _generator_train_with_fake(fake2_mask1, netD_mask1)
                mask_loss_fake1_mask2, D_mask2_fake1_mask2_map = _generator_train_with_fake(fake1_mask2, netD_mask2)
                mask_loss_fake2_mask2, D_mask2_fake2_mask2_map = _generator_train_with_fake(fake2_mask2, netD_mask2)

            optimizerG.step()

        errG_total_loss1_2plot.append(errG_fake1.detach()+rec_loss1)
        errG_total_loss2_2plot.append(errG_fake2.detach()+rec_loss2)
        errG_fake1_2plot.append(errG_fake1.detach())
        errG_fake2_2plot.append(errG_fake2.detach())
        G_total_loss = errG_fake1.detach()+rec_loss1 + errG_fake2.detach()+rec_loss2
        errG_total_loss_2plot.append(G_total_loss)
        D1_real2plot.append(D_x1)
        D2_real2plot.append(D_x2)
        D1_fake2plot.append(D_G_z_1)
        D2_fake2plot.append(D_G_z_2)
        reconstruction_loss1_2plot.append(rec_loss1)
        reconstruction_loss2_2plot.append(rec_loss2)

        if epoch % 25 == 0 or epoch == (opt.niter-1):
            logger.info('scale %d:[%d/%d]' % (len(Gs), epoch, opt.niter))

        # periodic snapshots of the generated images and of z_opt
        if epoch % 500 == 0 or epoch == (opt.niter-1):
            plt.imsave('%s/fake_sample1.png' %  (opt.outf), functions.convert_image_np(fake1.detach()), vmin=0, vmax=1)
            plt.imsave('%s/fake_sample2.png' % (opt.outf), functions.convert_image_np(fake2.detach()), vmin=0, vmax=1)
            plt.imsave('%s/G(z_opt1).png'    % (opt.outf),
                       functions.convert_image_np(netG(Z_opt1.detach(), z_prev1)[0].detach()), vmin=0, vmax=1)
            plt.imsave('%s/G(z_opt2).png' % (opt.outf),
                       functions.convert_image_np(netG(Z_opt2.detach(), z_prev2)[0].detach()), vmin=0, vmax=1)

            torch.save(z_opt1, '%s/z_opt1.pth' % (opt.outf))
            torch.save(z_opt2, '%s/z_opt2.pth' % (opt.outf))
            # masked outputs and discriminator maps are only saved at the last epoch
            if epoch == (opt.niter-1) and opt.enable_mask:
                plt.imsave('%s/fake1_mask1.png' % (opt.outf), functions.convert_image_np(fake1_mask1.detach()), vmin=0,
                           vmax=1)
                plt.imsave('%s/fake2_mask1.png' % (opt.outf), functions.convert_image_np(fake2_mask1.detach()), vmin=0,
                           vmax=1)
                plt.imsave('%s/fake1_mask2.png' % (opt.outf), functions.convert_image_np(fake1_mask2.detach()), vmin=0,
                           vmax=1)
                plt.imsave('%s/fake2_mask2.png' % (opt.outf), functions.convert_image_np(fake2_mask2.detach()), vmin=0,
                           vmax=1)
                _imsave_discriminator_map(D_fake1_map, "D_fake1_map", opt)
                _imsave_discriminator_map(D_fake2_map, "D_fake2_map", opt)
                _imsave_discriminator_map(D_mask1_fake1_mask1_map, "D_mask1_fake1_mask1_map", opt)
                _imsave_discriminator_map(D_mask1_fake2_mask1_map, "D_mask1_fake2_mask1_map", opt)
                _imsave_discriminator_map(D_mask2_fake1_mask2_map, "D_mask2_fake1_mask2_map", opt)
                _imsave_discriminator_map(D_mask2_fake2_mask2_map, "D_mask2_fake2_mask2_map", opt)


        for discriminator_scheduler in discriminators_schedulers:
            discriminator_scheduler.step()
        schedulerG.step()

    functions.save_networks(netG,netD, netD_mask1, netD_mask2,z_opt1, z_opt2,opt)

    functions.plot_learning_curves("G_loss", opt.niter, [errG_total_loss_2plot,
                                                         errG_total_loss1_2plot, errG_total_loss2_2plot,
                                                         errG_fake1_2plot, errG_fake2_2plot,
                                                         reconstruction_loss1_2plot,
                                                         reconstruction_loss2_2plot],
                                   ["G_total_loss", "G_total_loss1", "G_total_loss2",
                                    "G_fake1_loss", "G_fake2_loss",
                                    "G_recon_loss_1", "G_recon_loss_2"], opt.outf)
    d_plots = [err_D_img1_2plot, err_D_img2_2plot]
    d_labels = ["D1_total_loss", "D2_total_loss"]

    functions.plot_learning_curves("D_loss", opt.niter, d_plots, d_labels, opt.outf)
    functions.plot_learning_curves("G_vs_D_loss", opt.niter,
                                   [errG_total_loss_2plot, errG_total_loss1_2plot, errG_total_loss2_2plot,
                                    err_D_img1_2plot, err_D_img2_2plot],
                                   ["G_total_loss", "G_total_loss1", "G_total_loss2", "D1_total_loss", "D2_total_loss"],
                                   opt.outf)
    return (z_opt1, z_opt2), (in_s1, in_s2), netG
コード例 #26
0
ファイル: anchoring.py プロジェクト: tboen1/MESIGAN
def invert_model(test_image,
                 model_name,
                 scales2invert=None,
                 penalty=1e-3,
                 show=True):
    """Find, scale by scale, noise maps that make a trained pyramid
    reproduce ``test_image``.

    For every inverted scale a noise tensor is optimised with RMSprop so that
    the generator output matches the corresponding level of the test-image
    pyramid (MSE), minus ``penalty`` times the mean standard-normal
    log-density of the noise.

    Args:
        test_image: image array, converted with ``functions.np2torch``.
        model_name: stored in ``opt.input_name`` before the trained pyramid
            is loaded.
        scales2invert: number of scales to invert, starting at scale 0;
            defaults to ``opt.stop_scale + 1``.
        penalty: weight of the Gaussian prior term (stored as ``opt.reg``).
        show: when true, plot the loss curve of each scale; the flag is also
            forwarded to ``show_image``.

    Returns:
        ``(Noise_Solutions, reconstructed_image, REC_ERROR)``: the padded
        noise input found for each scale, the value ``show_image`` returned
        for the last scale's output (``None`` if no scale was inverted) and
        the reconstruction MSE of the last optimisation step (``0`` if no
        step ran).
    """
    Noise_Solutions = []

    parser = get_arguments()
    parser.add_argument('--input_dir',
                        help='input image dir',
                        default='Input/Images')

    parser.add_argument('--mode', default='RandomSamples')
    opt = parser.parse_args("")
    opt.input_name = model_name
    opt.reg = penalty

    if model_name == 'islands2_basis_2.jpg':  #HARDCODED
        opt.scale_factor = 0.6

    opt = functions.post_config(opt)

    ### Loading in Generators (frozen, evaluation mode)
    Gs, Zs, reals, NoiseAmp = functions.load_trained_pyramid(opt)
    for G in Gs:
        G = functions.reset_grads(G, False)
        G.eval()

    ### Loading in Ground Truth Test Images
    reals = []  # discard the training pyramid, rebuild it from test_image
    real = functions.np2torch(test_image, opt)
    functions.adjust_scales2image(real, opt)

    real_ = functions.np2torch(test_image, opt)
    real = imresize(real_, opt.scale1, opt)
    reals = functions.creat_reals_pyramid(real, reals, opt)

    ### General Padding
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    m_noise = nn.ZeroPad2d(int(pad_noise))

    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    m_image = nn.ZeroPad2d(int(pad_image))

    I_prev = None
    REC_ERROR = 0
    reconstructed_image = None  # stays None when no scale is inverted

    if scales2invert is None:
        scales2invert = opt.stop_scale + 1

    # Per-scale schedules; the last entry is reused for every deeper scale so
    # that pyramids of any depth can be inverted.
    learning_rates = (2e-3, 2e-2, 2e-1)
    iterations = (200, 400)

    # Prior used for the noise penalty.
    pdf = torch.distributions.Normal(0, 1)
    alpha = opt.reg

    for scale in range(scales2invert):
        # Target image, generator and noise amplitude of this scale.
        X = reals[scale]
        G = Gs[scale]
        noise_amp = NoiseAmp[scale]

        # Defining Dimensions
        opt.nc_z = X.shape[1]
        opt.nzx = X.shape[2]
        opt.nzy = X.shape[3]

        # Scale 0 uses a single noise channel (expanded to 3 below),
        # every other scale uses 3 channels.
        noise_channels = 1 if scale == 0 else 3
        z_init = functions.generate_noise(
            [noise_channels, opt.nzx, opt.nzy], device=opt.device)

        # Leaf tensor to optimise; kept on opt.device instead of forcing the
        # current CUDA device.
        z_init = z_init.detach().to(opt.device).requires_grad_(True)

        # Building I_prev
        if I_prev is None:  # first scale: all-zero image, padded
            in_s = torch.zeros(reals[0].shape, device=opt.device)
            I_prev = m_image(in_s)
        else:  # otherwise upsample the output of the previous scale
            padded_h = X.shape[2] + 2 * pad_image
            padded_w = X.shape[3] + 2 * pad_image
            I_prev = imresize(I_prev, 1 / opt.scale_factor, opt)
            I_prev = m_image(I_prev)
            # Crop, then resize, to exactly the padded size of the noise so
            # that rounding in imresize cannot cause a shape mismatch.
            I_prev = I_prev[:, :, 0:padded_h, 0:padded_w]
            I_prev = functions.upsampling(I_prev, padded_h, padded_w)

        lr = learning_rates[min(scale, len(learning_rates) - 1)]
        n_steps = iterations[min(scale, len(iterations) - 1)]
        Zoptimizer = torch.optim.RMSprop([z_init], lr=lr)
        x_loss = []  # for plotting
        epochs = []  # for plotting

        for epoch in range(n_steps):  # gradient descent on z_init
            if scale == 0:
                noise_input = m_noise(
                    z_init.expand(1, 3, opt.nzx, opt.nzy))  # expand and pad
            else:
                noise_input = m_noise(z_init)  # padding

            z_in = noise_amp * noise_input + I_prev
            G_z = G(z_in, I_prev)

            x_recLoss = F.mse_loss(G_z, X)  # reconstruction term
            logProb = pdf.log_prob(z_init).mean()  # Gaussian prior term

            loss = x_recLoss - (alpha * logProb)

            Zoptimizer.zero_grad()
            loss.backward()
            Zoptimizer.step()

            x_loss.append(loss.item())
            epochs.append(epoch)

            REC_ERROR = x_recLoss

        if show:
            plt.plot(epochs, x_loss, label='x_loss')
            plt.legend()
            plt.show()

        # The output of the last step becomes the base image of the next scale.
        I_prev = G_z.detach()

        _ = show_image(X, show, 'target')
        reconstructed_image = show_image(I_prev, show, 'output')
        _ = show_image(noise_input.detach().cpu(), show, 'noise')

        Noise_Solutions.append(noise_input.detach())
    return Noise_Solutions, reconstructed_image, REC_ERROR
コード例 #27
0
def generate_gif(Gs,Zs,reals,NoiseAmp,opt,alpha=0.1,beta=0.9,start_scale=2,fps=10):
    """Render a 100-frame noise random walk through the pyramid and save it
    as a GIF.

    At every scale the noise of frame ``i`` is
    ``alpha*Z_opt + (1-alpha)*(z_prev1 + diff)`` with
    ``diff = beta*(z_prev1 - z_prev2) + (1-beta)*fresh_noise``. Scales below
    ``start_scale`` feed ``Z_opt`` to the generator instead of the walked
    noise. Frames of the last scale are converted to ``uint8`` arrays and
    written to
    ``<dir2save>/start_scale=<start_scale>/alpha=<alpha>_beta=<beta>.gif``.

    Args:
        Gs, Zs, reals, NoiseAmp: per-scale generators, reconstruction noise
            maps, real images and noise amplitudes (iterated in lockstep).
        opt: options object; ``ker_size``, ``num_layer``, ``nc_z``,
            ``scale_factor`` and ``device`` are read here.
        alpha: pull of each frame's noise towards ``Z_opt``.
        beta: momentum of the random walk.
        start_scale: first scale at which the walked noise is used.
        fps: frame rate passed to ``imageio.mimsave``.
    """
    num_frames = 100

    # Float zeros; an integer fill value would not give a float tensor on
    # newer PyTorch releases.
    in_s = torch.zeros(Zs[0].shape, device=opt.device)
    images_cur = []

    for count, (G, Z_opt, noise_amp, real) in enumerate(zip(Gs, Zs, NoiseAmp, reals)):
        pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
        nzx = Z_opt.shape[2]
        nzy = Z_opt.shape[3]
        m_image = nn.ZeroPad2d(int(pad_image))
        images_prev = images_cur
        images_cur = []

        def draw_noise():
            # First scale: one noise channel expanded to 3; other scales:
            # opt.nc_z channels.
            if count == 0:
                z_rand = functions.generate_noise([1, nzx, nzy], device=opt.device)
                return z_rand.expand(1, 3, nzx, nzy)
            return functions.generate_noise([opt.nc_z, nzx, nzy], device=opt.device)

        # Start the walk slightly off the reconstruction noise.
        z_prev1 = 0.95*Z_opt + 0.05*draw_noise()
        z_prev2 = Z_opt

        for i in range(num_frames):
            diff_curr = beta*(z_prev1 - z_prev2) + (1 - beta)*draw_noise()

            z_curr = alpha*Z_opt + (1 - alpha)*(z_prev1 + diff_curr)
            z_prev2 = z_prev1
            z_prev1 = z_curr

            if not images_prev:
                I_prev = in_s
            else:
                # Upsample frame i of the previous scale, crop to this
                # scale's image size, then pad.
                I_prev = images_prev[i]
                I_prev = imresize(I_prev, 1 / opt.scale_factor, opt)
                I_prev = I_prev[:, :, 0:real.shape[2], 0:real.shape[3]]
                I_prev = m_image(I_prev)
            if count < start_scale:
                # Below start_scale the generator input is the fixed noise;
                # the walk state (z_prev1/z_prev2) still advances.
                z_curr = Z_opt

            z_in = noise_amp*z_curr + I_prev
            I_curr = G(z_in.detach(), I_prev)

            if count == len(Gs) - 1:
                # Last scale: convert to an HxWxC uint8 frame for imageio.
                I_curr = functions.denorm(I_curr).detach()
                I_curr = I_curr[0, :, :, :].cpu().numpy()
                I_curr = I_curr.transpose(1, 2, 0)*255
                I_curr = I_curr.astype(np.uint8)

            images_cur.append(I_curr)

    dir2save = functions.generate_dir2save(opt)
    # exist_ok tolerates an existing directory but still reports real
    # failures (e.g. permissions) here rather than later in mimsave.
    os.makedirs('%s/start_scale=%d' % (dir2save, start_scale), exist_ok=True)
    imageio.mimsave('%s/start_scale=%d/alpha=%f_beta=%f.gif' % (dir2save,start_scale,alpha,beta),images_cur,fps=fps)
コード例 #28
0
def train_single_scale(netD,
                       netG,
                       reals,
                       masks,
                       Gs,
                       Zs,
                       in_s,
                       NoiseAmp,
                       opt,
                       centers=None):
    """Train the generator/discriminator pair of one pyramid scale, with the
    discriminator's real-image output and the reconstruction loss weighted
    by a mask.

    The scale being trained is ``len(Gs)``; its real image and mask are
    ``reals[len(Gs)]`` and ``masks[len(Gs)]``.

    Args:
        netD, netG: discriminator and generator of the current scale.
        reals, masks: per-scale real images and masks.
        Gs, Zs, NoiseAmp: already trained generators, their reconstruction
            noise maps and noise amplitudes (passed on to ``draw_concat``).
        in_s: base input of the pyramid; replaced by an all-zero tensor when
            this is the first scale and the mode is not ``'SR_train'``.
        opt: options object. ``nzx``, ``nzy``, ``receptive_field`` and
            ``noise_amp`` are written here.
        centers: colour centres used by ``functions.quant2centers`` in
            ``'paint_train'`` mode.

    Returns:
        ``(z_opt, in_s, netG)``.

    Side effects: writes images, ``z_opt.pth`` and pickled loss histories
    under ``opt.outf`` and calls ``functions.save_networks``.
    """
    real = reals[len(Gs)]
    mask = masks[len(Gs)]

    opt.nzx = real.shape[2]  #+(opt.ker_size-1)*(opt.num_layer)
    opt.nzy = real.shape[3]  #+(opt.ker_size-1)*(opt.num_layer)
    opt.receptive_field = opt.ker_size + ((opt.ker_size - 1) *
                                          (opt.num_layer - 1)) * opt.stride
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    if opt.mode == 'animation_train':
        # Noise is generated at the padded size, so it is not padded again.
        opt.nzx = real.shape[2] + (opt.ker_size - 1) * (opt.num_layer)
        opt.nzy = real.shape[3] + (opt.ker_size - 1) * (opt.num_layer)
        pad_noise = 0
    m_noise = nn.ZeroPad2d(int(pad_noise))
    m_image = nn.ZeroPad2d(int(pad_image))

    alpha = opt.alpha

    # Here we calculate the decrease in output size due to conv layers.
    # required for setting the size of discriminators map.
    # TODO: currently, calculation doesn't consider opt.stride
    if (opt.ker_size % 2 == 0):
        r = (opt.num_layer) * (opt.ker_size - 1)
    else:  #(opt.ker_size % 2 != 0):
        r = (opt.num_layer) * (opt.ker_size - 1) / 2
    r = int(r)

    # Crop r pixels from every border of the mask and keep its first channel.
    _, _, h, w = mask.size()
    discriminators_mask = mask.detach()[:, :, r:h - r,
                                        r:w - r][:, 0, :, :].unsqueeze(0)
    _, _, h, w = discriminators_mask.size()

    # fixed_noise is only used for its shape; the call is kept so the random
    # number stream is unchanged.
    fixed_noise = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy],
                                           device=opt.device)
    # Float zeros; an integer fill value passed to torch.full does not give a
    # float tensor on newer PyTorch releases.
    z_opt = torch.zeros(fixed_noise.shape, device=opt.device)
    z_opt = m_noise(z_opt)

    # setup optimizer
    optimizerD = optim.Adam(netD.parameters(),
                            lr=opt.lr_d,
                            betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(),
                            lr=opt.lr_g,
                            betas=(opt.beta1, 0.999))
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD,
                                                      milestones=[1600],
                                                      gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG,
                                                      milestones=[1600],
                                                      gamma=opt.gamma)

    errD2plot = []
    errG2plot = []
    D_real2plot = []
    D_fake2plot = []
    z_opt2plot = []

    # Normalisation factors selected by opt.norm: index 0 is no scaling,
    # index 1 rescales by (map area / sum of the mask).
    norm = []
    norm.append(1)
    norm.append((h * w) / discriminators_mask.sum().item())

    plt.imsave('%s/mask.png' % (opt.outf),
               functions.convert_image_np(real * mask))

    for epoch in range(opt.niter):
        if (Gs == []) and (opt.mode != 'SR_train'):
            if (epoch == 0):
                # First scale: draw the reconstruction noise once.
                z_opt = functions.generate_noise([1, opt.nzx, opt.nzy],
                                                 device=opt.device)
                z_opt = m_noise(z_opt.expand(1, 3, opt.nzx, opt.nzy))
            noise_ = functions.generate_noise([1, opt.nzx, opt.nzy],
                                              device=opt.device)
            noise_ = m_noise(noise_.expand(1, 3, opt.nzx, opt.nzy))
        else:
            noise_ = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy],
                                              device=opt.device)
            noise_ = m_noise(noise_)

        ############################
        # (1) Update D network: maximize D(x) + D(G(z))
        ###########################
        for j in range(opt.Dsteps):
            # train with real
            netD.zero_grad()
            output = netD(real).to(opt.device)
            output = output * discriminators_mask
            D_real_map = output.detach()
            errD_real = -(output.mean()) * norm[opt.norm]  #-a
            errD_real.backward(retain_graph=True)
            D_x = -errD_real.item()

            # train with fake
            if (j == 0) and (epoch == 0):
                # Very first step: initialise prev / z_prev and the noise
                # amplitude of this scale.
                if (Gs == []) and (opt.mode != 'SR_train'):
                    prev = torch.zeros([1, opt.nc_z, opt.nzx, opt.nzy],
                                       device=opt.device)
                    in_s = prev
                    prev = m_image(prev)
                    z_prev = torch.zeros([1, opt.nc_z, opt.nzx, opt.nzy],
                                         device=opt.device)
                    z_prev = m_noise(z_prev)
                    opt.noise_amp = 1
                elif opt.mode == 'SR_train':
                    z_prev = in_s
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
                    prev = z_prev
                else:
                    prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand',
                                       m_noise, m_image, opt)
                    prev = m_image(prev)
                    z_prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rec',
                                         m_noise, m_image, opt)
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
            else:
                prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand',
                                   m_noise, m_image, opt)
                prev = m_image(prev)

            if opt.mode == 'paint_train':
                prev = functions.quant2centers(prev, centers)
                plt.imsave('%s/prev.png' % (opt.outf),
                           functions.convert_image_np(prev),
                           vmin=0,
                           vmax=1)

            if (Gs == []) and (opt.mode != 'SR_train'):
                noise = noise_
            else:
                noise = opt.noise_amp * noise_ + prev

            fake = netG(noise.detach(), prev)
            output = netD(fake.detach())
            errD_fake = output.mean()
            errD_fake.backward(retain_graph=True)
            D_G_z = output.mean().item()

            gradient_penalty = functions.calc_gradient_penalty(
                netD, real, fake, opt.lambda_grad, opt.device,
                discriminators_mask * norm[opt.norm])
            gradient_penalty.backward()

            errD = errD_real + errD_fake + gradient_penalty
            optimizerD.step()

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################

        for j in range(opt.Gsteps):
            netG.zero_grad()
            output = netD(fake)
            D_fake_map = output.detach()
            errG = -output.mean()
            errG.backward(retain_graph=True)
            if alpha != 0:
                loss = nn.MSELoss()
                if opt.mode == 'paint_train':
                    z_prev = functions.quant2centers(z_prev, centers)
                    plt.imsave('%s/z_prev.png' % (opt.outf),
                               functions.convert_image_np(z_prev),
                               vmin=0,
                               vmax=1)
                Z_opt = opt.noise_amp * z_opt + z_prev
                netG_out = netG(Z_opt.detach(), z_prev)
                netG_out = netG_out * mask
                # NOTE(review): this rebinds `real`, so every later
                # discriminator / gradient-penalty step sees the masked
                # image -- confirm that this is intended.
                real = real * mask
                rec_loss = alpha * loss(netG_out, real)
                rec_loss.backward(retain_graph=True)
                rec_loss = rec_loss.detach()
            else:
                Z_opt = z_opt
                rec_loss = 0

            optimizerG.step()

        errD2plot.append(errD.detach())
        errG2plot.append(errG.detach())
        D_real2plot.append(D_x)
        D_fake2plot.append(D_G_z)
        z_opt2plot.append(rec_loss)

        if epoch % 25 == 0 or epoch == (opt.niter - 1):
            print('scale %d:[%d/%d]' % (len(Gs), epoch, opt.niter))

        if epoch % 500 == 0 or epoch == (opt.niter - 1):
            # Periodic snapshot of samples, discriminator maps and histories.
            plt.imsave('%s/fake_sample.png' % (opt.outf),
                       functions.convert_image_np(fake.detach()),
                       vmin=0,
                       vmax=1)
            plt.imsave('%s/G(z_opt).png' % (opt.outf),
                       functions.convert_image_np(
                           netG(Z_opt.detach(), z_prev).detach()),
                       vmin=0,
                       vmax=1)
            plt.imsave('%s/D_fake.png' % (opt.outf),
                       functions.convert_image_np(D_fake_map))
            plt.imsave('%s/D_real.png' % (opt.outf),
                       functions.convert_image_np(D_real_map))

            torch.save(z_opt, '%s/z_opt.pth' % (opt.outf))

            with open('%s/D_real2plot' % (opt.outf), "wb") as fp:  # Pickling
                pickle.dump(D_real2plot, fp)
            with open('%s/D_fake2plot' % (opt.outf), "wb") as fp:  # Pickling
                pickle.dump(D_fake2plot, fp)

            with open('%s/errD2plot' % (opt.outf), "wb") as fp:  # Pickling
                pickle.dump(errD2plot, fp)
            with open('%s/errG2plot' % (opt.outf), "wb") as fp:  # Pickling
                pickle.dump(errG2plot, fp)
            with open('%s/z_opt2plot' % (opt.outf), "wb") as fp:  # Pickling
                pickle.dump(z_opt2plot, fp)

        schedulerD.step()
        schedulerG.step()

    functions.save_networks(netG, netD, z_opt, opt)
    return z_opt, in_s, netG
コード例 #29
0
def train_single_scale(netD,
                       netG,
                       reals,
                       Gs,
                       Zs,
                       in_s,
                       NoiseAmp,
                       opt,
                       centers=None):
    """Train the generator/discriminator pair of one pyramid scale.

    The scale being trained is ``len(Gs)``; its real image is
    ``reals[len(Gs)]``.

    Args:
        netD, netG: discriminator and generator of the current scale.
        reals: per-scale real images.
        Gs, Zs, NoiseAmp: already trained generators, their reconstruction
            noise maps and noise amplitudes (passed on to ``draw_concat``).
        in_s: base input of the pyramid; replaced by an all-zero tensor when
            this is the first scale and the mode is not ``'SR_train'``.
        opt: options object. ``nzx``, ``nzy``, ``receptive_field`` and
            ``noise_amp`` are written here.
        centers: colour centres used by ``functions.quant2centers`` in
            ``'paint_train'`` mode.

    Returns:
        ``(z_opt, in_s, netG)``.

    Side effects: prints progress every 100 epochs, writes images and
    ``z_opt.pth`` under ``opt.outf`` at the last epoch and calls
    ``functions.save_networks``.
    """
    real = reals[len(Gs)]
    opt.nzx = real.shape[2]  #+(opt.ker_size-1)*(opt.num_layer)
    opt.nzy = real.shape[3]  #+(opt.ker_size-1)*(opt.num_layer)
    opt.receptive_field = opt.ker_size + ((opt.ker_size - 1) *
                                          (opt.num_layer - 1)) * opt.stride
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    if opt.mode == 'animation_train':
        # Noise is generated at the padded size, so it is not padded again.
        opt.nzx = real.shape[2] + (opt.ker_size - 1) * (opt.num_layer)
        opt.nzy = real.shape[3] + (opt.ker_size - 1) * (opt.num_layer)
        pad_noise = 0
    m_noise = nn.ZeroPad2d(int(pad_noise))
    m_image = nn.ZeroPad2d(int(pad_image))

    alpha = opt.alpha

    # z_opt: input noise for calculating reconstruction loss in Generator.
    # fixed_noise is only used for its shape; the call is kept so the random
    # number stream is unchanged.
    fixed_noise = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy],
                                           device=opt.device)
    # Float zeros; an integer fill value passed to torch.full does not give a
    # float tensor on newer PyTorch releases.
    z_opt = torch.zeros(fixed_noise.shape, device=opt.device)
    z_opt = m_noise(z_opt)

    # setup optimizer
    optimizerD = optim.Adam(netD.parameters(),
                            lr=opt.lr_d,
                            betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(),
                            lr=opt.lr_g,
                            betas=(opt.beta1, 0.999))
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD,
                                                      milestones=[1600],
                                                      gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG,
                                                      milestones=[1600],
                                                      gamma=opt.gamma)

    # TODO: these plots are not visualized
    errD2plot = []
    errG2plot = []
    D_real2plot = []
    D_fake2plot = []
    z_opt2plot = []

    # Started once and reset only when a progress line is printed, so the
    # reported time covers the whole interval since the previous report.
    start_time = time.time()

    for epoch in range(opt.niter):
        if (Gs == []) and (opt.mode != 'SR_train'):
            # Bottom generator, here without zero init
            # NOTE(review): z_opt is redrawn on every epoch, so the
            # reconstruction noise is not fixed during training -- confirm
            # that this is intended.
            z_opt = functions.generate_noise([1, opt.nzx, opt.nzy],
                                             device=opt.device)
            z_opt = m_noise(z_opt.expand(1, 3, opt.nzx, opt.nzy))
            # noise_: input noise for the discriminator
            noise_ = functions.generate_noise([1, opt.nzx, opt.nzy],
                                              device=opt.device)
            noise_ = m_noise(noise_.expand(1, 3, opt.nzx, opt.nzy))
        else:
            noise_ = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy],
                                              device=opt.device)
            noise_ = m_noise(noise_)

        ############################
        # (1) Update D network: maximize D(x) + D(G(z))
        ###########################
        for j in range(opt.Dsteps):
            # train with real
            netD.zero_grad()

            output = netD(real).to(opt.device)
            errD_real = -output.mean()  #-a
            errD_real.backward(retain_graph=True)
            D_x = -errD_real.item()

            # train with fake
            if (j == 0) and (epoch == 0):
                # Initialize prev and z_prev
                # prev: image outputs from previous level
                # z_prev: image outputs from previous level of fixed noise z_opt
                if (Gs == []) and (opt.mode != 'SR_train'):
                    # First scale: in_s, prev and z_prev are all zeros.
                    prev = torch.zeros([1, opt.nc_z, opt.nzx, opt.nzy],
                                       device=opt.device)
                    in_s = prev
                    prev = m_image(prev)
                    z_prev = torch.zeros([1, opt.nc_z, opt.nzx, opt.nzy],
                                         device=opt.device)
                    z_prev = m_noise(z_prev)
                    opt.noise_amp = 1
                elif opt.mode == 'SR_train':
                    z_prev = in_s
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
                    prev = z_prev
                else:
                    prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand',
                                       m_noise, m_image, opt)
                    prev = m_image(prev)
                    z_prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rec',
                                         m_noise, m_image, opt)
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
            else:
                prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand',
                                   m_noise, m_image, opt)
                prev = m_image(prev)

            if opt.mode == 'paint_train':
                prev = functions.quant2centers(prev, centers)
                plt.imsave('%s/prev.png' % (opt.outf),
                           functions.convert_image_np(prev),
                           vmin=0,
                           vmax=1)

            if (Gs == []) and (opt.mode != 'SR_train'):
                noise = noise_
            else:
                noise = opt.noise_amp * noise_ + prev

            fake = netG(noise.detach(), prev)
            output = netD(fake.detach())
            errD_fake = output.mean()
            errD_fake.backward(retain_graph=True)
            D_G_z = output.mean().item()

            gradient_penalty = functions.calc_gradient_penalty(
                netD, real, fake, opt.lambda_grad, opt.device)
            gradient_penalty.backward()

            errD = errD_real + errD_fake + gradient_penalty
            optimizerD.step()

        errD2plot.append(errD.detach())

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################

        for j in range(opt.Gsteps):
            netG.zero_grad()
            output = netD(fake)
            errG = -output.mean()
            errG.backward(retain_graph=True)
            if alpha != 0:
                loss = nn.MSELoss()
                if opt.mode == 'paint_train':
                    z_prev = functions.quant2centers(z_prev, centers)
                    plt.imsave('%s/z_prev.png' % (opt.outf),
                               functions.convert_image_np(z_prev),
                               vmin=0,
                               vmax=1)
                Z_opt = opt.noise_amp * z_opt + z_prev
                rec_loss = alpha * loss(netG(Z_opt.detach(), z_prev), real)
                rec_loss.backward(retain_graph=True)
                rec_loss = rec_loss.detach()
            else:
                Z_opt = z_opt
                rec_loss = 0

            optimizerG.step()

        errG2plot.append(errG.detach() + rec_loss)
        D_real2plot.append(D_x)
        D_fake2plot.append(D_G_z)
        z_opt2plot.append(rec_loss)

        if epoch > 1 and (epoch % 100 == 0 or epoch == (opt.niter - 1)):
            total_time = time.time() - start_time
            start_time = time.time()
            print('scale %d:[%d/%d], total time: %f' %
                  (len(Gs), epoch, opt.niter, total_time))
            memory = torch.cuda.max_memory_allocated()
            print('allocated memory: %.03f GB' % (memory /
                                                  (1024 * 1024 * 1024 * 1.0)))

        if epoch == (opt.niter - 1):
            # Final snapshot of samples and intermediate inputs.
            plt.imsave('%s/fake_sample.png' % (opt.outf),
                       functions.convert_image_np(fake.detach()),
                       vmin=0,
                       vmax=1)
            plt.imsave('%s/G(z_opt).png' % (opt.outf),
                       functions.convert_image_np(
                           netG(Z_opt.detach(), z_prev).detach()),
                       vmin=0,
                       vmax=1)
            plt.imsave('%s/prev.png' % (opt.outf),
                       functions.convert_image_np(prev),
                       vmin=0,
                       vmax=1)
            plt.imsave('%s/prev_plus_noise.png' % (opt.outf),
                       functions.convert_image_np(noise),
                       vmin=0,
                       vmax=1)
            plt.imsave('%s/z_prev.png' % (opt.outf),
                       functions.convert_image_np(z_prev),
                       vmin=0,
                       vmax=1)
            torch.save(z_opt, '%s/z_opt.pth' % (opt.outf))

        schedulerD.step()
        schedulerG.step()

    functions.save_networks(netG, netD, z_opt, opt)
    return z_opt, in_s, netG
コード例 #30
0
ファイル: manipulate.py プロジェクト: oidelima/SinGAN
def SinGAN_generate(Gs,
                    Zs,
                    reals,
                    crops,
                    masks,
                    NoiseAmp,
                    opt,
                    in_s=None,
                    scale_v=1,
                    scale_h=1,
                    n=0,
                    gen_start_scale=0,
                    num_samples=20,
                    mask_locs=None):
    """Draw ``num_samples`` random samples from the last generator and save them.

    For every sample the code below:
      1. builds an eye mask with ``functions.generate_eye_mask`` and, when
         ``opt.random_eye_color`` is set, tints it with a colour obtained from
         ``functions.get_eye_color(reals[-1])`` (also stored in
         ``opt.eye_color``);
      2. generates zero-padded noise, obtains the previous-scale image from
         ``functions.draw_concat`` in ``'rand'`` mode and runs ``Gs[-1]``;
      3. composites the result with ``functions.gen_fake`` onto either a
         random crop of ``reals[-1]`` (``opt.random_crop``) or ``reals[-1]``
         itself;
      4. writes the images as ``<dir2save>/<kind>/<i>.png`` unless
         ``opt.mode`` is one of harmonization / editing / SR / paint2image.

    Args:
        Gs: sequence of generators; only ``Gs[-1]`` is called directly here
            (it is switched to train mode), the full list goes to
            ``draw_concat``.
        Zs, NoiseAmp: forwarded unchanged to ``functions.draw_concat``.
        reals: sequence of real images; ``reals[-1]`` is used as the
            compositing target and colour source.
        crops: sequence of crops; ``crops[-1].size()[2]`` gives the crop size
            when ``opt.random_crop`` is set.
        masks: sequence of masks; ``masks[-1]`` is used for the eye mask, the
            generator input and compositing.
        opt: options object; fields read here include ``nc_z``, ``nzx``,
            ``nzy``, ``device``, ``ker_size``, ``num_layer``, ``noise_amp``,
            ``random_eye_color``, ``random_crop``, ``eye_color``, ``mode``,
            ``out``, ``input_name`` and ``run_name``.
        in_s: initial image passed to ``draw_concat``; a zero tensor of shape
            ``[1, opt.nc_z, opt.nzx, opt.nzy]`` is created when omitted.
        scale_v, scale_h, n, gen_start_scale: not read by the code shown
            here; kept for interface compatibility.
        num_samples: number of samples to generate.
        mask_locs: optional per-sample mask locations; ``mask_locs[i]`` is
            forwarded to ``gen_fake`` as ``mask_loc``.
    """
    #if torch.is_tensor(in_s) == False:
    Gs[-1].train()

    if in_s is None:
        # Float fill value: keeps the tensor floating-point on PyTorch
        # versions that infer the dtype from the fill value.
        in_s = torch.full([1, opt.nc_z, opt.nzx, opt.nzy],
                          0.0,
                          device=opt.device)

    # Padding depends only on the network configuration, so build it once.
    pad = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    m_noise = nn.ZeroPad2d(pad)
    m_image = nn.ZeroPad2d(pad)

    for i in range(num_samples):

        eye = functions.generate_eye_mask(opt, masks[-1],
                                          0)  #generate eye in random location
        eye_colored = eye.clone()
        if opt.random_eye_color:
            eye_color = functions.get_eye_color(reals[-1])
            opt.eye_color = eye_color
            # Scale each of the three channels by colour / 255
            # (presumably an RGB triple in 0-255; confirm in get_eye_color).
            for channel in range(3):
                eye_colored[:, channel, :, :] *= (eye_color[channel] / 255)

        noise_ = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy],
                                          device=opt.device)
        noise_ = m_noise(noise_)

        prev = functions.draw_concat(Gs, Zs, reals, crops, masks, eye_colored,
                                     NoiseAmp, in_s, 'rand', m_noise, m_image,
                                     opt)
        prev = m_image(prev)

        noise = opt.noise_amp * noise_ + prev

        G_input = functions.make_input(noise, masks[-1], eye_colored)
        fake_background = Gs[-1](G_input.detach(), prev)

        border = False  #TODO

        # mask_locs defaults to None; indexing it unconditionally raised
        # TypeError. NOTE(review): assumes gen_fake accepts mask_loc=None --
        # confirm against its definition.
        mask_loc = None if mask_locs is None else mask_locs[i]

        if opt.random_crop:
            crop_size = crops[-1].size()[2]
            crop, h_idx, w_idx = functions.random_crop(reals[-1], crop_size)
            I_curr, fake_ind, eye_ind = functions.gen_fake(
                crop,
                fake_background,
                masks[-1],
                eye,
                opt.eye_color,
                opt,
                border,
                mask_loc=mask_loc)
            # Paste the generated crop (and its mask) back at the position
            # the crop was taken from.
            full_fake = reals[-1].clone()
            full_fake[:, :, h_idx:h_idx + crop_size,
                      w_idx:w_idx + crop_size] = I_curr
            full_mask = torch.zeros_like(full_fake)
            full_mask[:, :, h_idx:h_idx + crop_size,
                      w_idx:w_idx + crop_size] = fake_ind
        else:
            I_curr, fake_ind, eye_ind = functions.gen_fake(
                reals[-1],
                fake_background,
                masks[-1],
                eye,
                opt.eye_color,
                opt,
                border,
                mask_loc=mask_loc)

        if opt.mode == 'train':
            dir2save = '%s/RandomSamples/%s/SinGAN/%s' % (
                opt.out, opt.input_name[:-4], opt.run_name)
        else:
            dir2save = functions.generate_dir2save(opt)

        # Create each output directory independently. A single try/except
        # around all makedirs calls stopped at the first directory that
        # already existed and left the remaining ones uncreated.
        subdirs = ["fake", "background", "mask", "eye"]
        if opt.random_crop:
            subdirs += ["full_fake", "full_mask"]
        for subdir in subdirs:
            os.makedirs(dir2save + "/" + subdir, exist_ok=True)

        if opt.mode not in ("harmonization", "editing", "SR", "paint2image"):
            plt.imsave('%s/%s/%d.png' % (dir2save, "fake", i),
                       functions.convert_image_np(I_curr.detach()))
            plt.imsave('%s/%s/%d.png' % (dir2save, "background", i),
                       functions.convert_image_np(fake_background.detach()))
            plt.imsave('%s/%s/%d.png' % (dir2save, "mask", i),
                       functions.convert_image_np(fake_ind.detach()))
            plt.imsave('%s/%s/%d.png' % (dir2save, "eye", i),
                       functions.convert_image_np(eye_ind.detach()))
            if opt.random_crop:
                plt.imsave('%s/%s/%d.png' % (dir2save, "full_fake", i),
                           functions.convert_image_np(full_fake.detach()))
                plt.imsave('%s/%s/%d.png' % (dir2save, "full_mask", i),
                           functions.convert_image_np(full_mask.detach()))