Пример #1
0
def train(opt, Gs, Zs, reals1, reals2, NoiseAmp):
    """Train a two-image SinGAN-style pyramid, one scale at a time.

    Both images returned by ``functions.read_image(opt)`` are rescaled by
    ``opt.scale1`` and turned into pyramids. For every scale index
    ``0..opt.stop_scale`` a fresh discriminator/generator pair is built by
    ``init_models``, trained by ``train_single_scale`` and switched to eval
    mode. Per-scale results are appended to ``Gs``, ``Zs`` and ``NoiseAmp``
    (mutated in place) and everything is checkpointed under ``opt.out_``
    after each scale.

    Args:
        opt: configuration namespace. This function overwrites ``opt.nfc``,
            ``opt.min_nfc``, ``opt.out_`` and ``opt.outf``.
        Gs: list of trained generators, extended in place.
        Zs: list of per-scale noise maps, extended in place.
        reals1: seed list handed to ``creat_reals_pyramid`` for image 1.
        reals2: seed list handed to ``creat_reals_pyramid`` for image 2.
        NoiseAmp: list of per-scale ``opt.noise_amp`` values, extended in place.

    Returns:
        None
    """
    real1_, real2_ = functions.read_image(opt)
    in_s = 0
    scale_num = 0
    real1 = imresize(real1_, opt.scale1, opt)
    real2 = imresize(real2_, opt.scale1, opt)
    reals1 = functions.creat_reals_pyramid(real1, reals1, opt)
    reals2 = functions.creat_reals_pyramid(real2, reals2, opt)
    nfc_prev = 0
    while scale_num < opt.stop_scale + 1:
        # Channel width doubles every 4 scales, capped at 128.
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(scale_num / 4)), 128)

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale_num)
        # exist_ok tolerates only an already-existing directory; genuine
        # failures (e.g. permissions) surface here instead of being swallowed.
        os.makedirs(opt.outf, exist_ok=True)

        plt.imsave('%s/real1_scale.png' % (opt.outf), functions.convert_image_np(reals1[scale_num]), vmin=0, vmax=1)
        plt.imsave('%s/real2_scale.png' % (opt.outf), functions.convert_image_np(reals2[scale_num]), vmin=0, vmax=1)

        D_curr, G_curr = init_models(opt)
        # Same width as the previous scale: warm-start from its checkpoint
        # (presumably written by train_single_scale).
        if nfc_prev == opt.nfc:
            G_curr.load_state_dict(torch.load('%s/%d/netG.pth' % (opt.out_, scale_num - 1)))
            D_curr.load_state_dict(torch.load('%s/%d/netD.pth' % (opt.out_, scale_num - 1)))

        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals1, reals2, Gs, Zs, in_s, NoiseAmp, opt)

        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        # Checkpoint the whole pyramid after every scale.
        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals1, '%s/reals1.pth' % (opt.out_))
        torch.save(reals2, '%s/reals2.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

        scale_num += 1
        nfc_prev = opt.nfc
        del D_curr, G_curr
    return
Пример #2
0
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train a SinGAN-style pyramid coarsest-to-finest, one scale at a time.

    Unlike most variants, the pyramid is built directly from the image
    returned by ``functions.read_image`` (no ``imresize`` by ``opt.scale1``).

    Args:
        opt: configuration namespace. Overwrites ``opt.nfc``, ``opt.min_nfc``,
            ``opt.out_`` and ``opt.outf``.
        Gs, Zs, NoiseAmp: lists extended in place with one entry per scale.
        reals: seed list handed to ``functions.creat_reals_pyramid``.

    Returns:
        None
    """
    real_ = functions.read_image(opt)
    in_s = 0
    # cur_scale_level: current level, from coarsest (0) to finest.
    cur_scale_level = 0
    reals = functions.creat_reals_pyramid(real_, reals, opt)
    nfc_prev = 0

    # Train levels 0..opt.stop_scale inclusive.
    while cur_scale_level < opt.stop_scale + 1:
        # nfc: number of output channels per conv block; doubles every 4
        # levels, capped at 128.
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(cur_scale_level / 4)), 128)
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(cur_scale_level / 4)), 128)

        # out_: output root directory; outf: per-level output folder.
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, cur_scale_level)
        # Only an already-existing directory is tolerated; real OS errors raise.
        os.makedirs(opt.outf, exist_ok=True)

        plt.imsave('%s/real_scale.png' % (opt.outf), functions.convert_image_np(reals[cur_scale_level]), vmin=0, vmax=1)

        D_curr, G_curr = init_models(opt)
        # The architecture changes whenever nfc changes (every 4 levels); only
        # when it is unchanged can the previous level's weights be reused.
        if nfc_prev == opt.nfc:
            G_curr.load_state_dict(torch.load('%s/%d/netG.pth' % (opt.out_, cur_scale_level - 1)))
            D_curr.load_state_dict(torch.load('%s/%d/netD.pth' % (opt.out_, cur_scale_level - 1)))

        # in_s starts as 0 and is threaded through train_single_scale, which
        # returns its updated value (NOTE(review): exact meaning not visible here).
        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt)

        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

        cur_scale_level += 1
        nfc_prev = opt.nfc
        del D_curr, G_curr
        torch.cuda.empty_cache()
    return
Пример #3
0
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train a SinGAN-style pyramid with TensorFlow/Keras models.

    For each scale ``0..opt.stop_scale`` a model pair is built by
    ``init_models``, optionally warm-started from the previous scale's saved
    weights, and trained by ``train_single_scale``. ``Zs``, ``reals`` and
    ``NoiseAmp`` are pickled under ``opt.out_`` after each scale.

    Args:
        opt: configuration namespace. Overwrites ``opt.nfc``, ``opt.min_nfc``,
            ``opt.out_`` and ``opt.outf``.
        Gs, Zs, NoiseAmp: lists extended in place with one entry per scale.
        reals: seed list handed to ``functions.creat_reals_pyramid``.

    Returns:
        None
    """
    real_ = functions.read_image(opt)
    in_s = 0
    scale_num = 0
    real = imresize(real_, opt.scale1, opt)
    reals = functions.creat_reals_pyramid(real, reals, opt)
    nfc_prev = 0
    # NOTE(review): one pair of Adam optimizers is shared by every scale, even
    # though each scale trains freshly created models. Newer Keras optimizers
    # are bound to the variables they were first built with, and Adam's step
    # count carries across scales. Confirm train_single_scale handles this.
    netD_optimizer = tf.keras.optimizers.Adam(learning_rate=opt.lr_d,
                                              beta_1=opt.beta1,
                                              beta_2=0.999)
    netG_optimizer = tf.keras.optimizers.Adam(learning_rate=opt.lr_g,
                                              beta_1=opt.beta1,
                                              beta_2=0.999)

    while scale_num < opt.stop_scale + 1:
        # Channel width doubles every 4 scales, capped at 128.
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(scale_num / 4)),
                          128)
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale_num)
        # Only an already-existing directory is tolerated; real OS errors raise.
        os.makedirs(opt.outf, exist_ok=True)

        plt.imsave('%s/real_scale.png' % (opt.outf),
                   functions.convert_image_np(reals[scale_num]),
                   vmin=0,
                   vmax=1)
        D_curr, G_curr = init_models(opt)
        # Same width as the previous scale: warm-start from its saved weights.
        if nfc_prev == opt.nfc:
            D_curr.load_weights('%s/%d/netD' % (opt.out_, scale_num - 1))
            G_curr.load_weights('%s/%d/netG' % (opt.out_, scale_num - 1))

        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs,
                                                  Zs, in_s, NoiseAmp, opt,
                                                  scale_num, netG_optimizer,
                                                  netD_optimizer)

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)
        # NOTE(review): Gs itself is not persisted here; only the per-scale
        # weight files loaded above survive the run.
        with open('%s/Zs.pkl' % (opt.out_), 'wb') as f:
            pickle.dump(Zs, f)
        with open('%s/reals.pkl' % (opt.out_), 'wb') as f:
            pickle.dump(reals, f)
        with open('%s/NoiseAmp.pkl' % (opt.out_), 'wb') as f:
            pickle.dump(NoiseAmp, f)
        scale_num += 1
        nfc_prev = opt.nfc
        del D_curr, G_curr
    return None
Пример #4
0
def test_pyramid(images):
    """Build a scale pyramid for every image and group the results by scale.

    The scale parameters are derived from ``images[0]`` only (via
    ``functions.adjust_scales2image``) and then applied to every image.

    Args:
        images: non-empty sequence of images accepted by ``functions.np2torch``.

    Returns:
        A 2-D ``numpy`` object array of shape ``(num_scales, len(images))``.
        Entry ``[s, i]`` is the scale-``s`` tensor of image ``i``. If pyramids
        differ in length, missing cells are ``None``.
    """
    parser = get_arguments()
    parser.add_argument('--input_dir',
                        help='input image dir',
                        default='Input/Images')
    parser.add_argument('--mode', help='task to be done', default='train')
    # Parse an empty argument list so only defaults apply (sys.argv is ignored).
    opt = parser.parse_args([])
    opt.input_name = 'blank'
    opt = functions.post_config(opt)

    real = functions.np2torch(images[0], opt)
    functions.adjust_scales2image(real, opt)

    all_reals = []
    for image in images:
        real_ = functions.np2torch(image, opt)
        real = imresize(real_, opt.scale1, opt)
        all_reals.append(functions.creat_reals_pyramid(real, [], opt))

    # np.array(all_reals) would try to stack the per-scale tensors, which have
    # different sizes. That is a ragged array (ValueError on NumPy >= 1.24), it
    # fails on device tensors, and it builds a 6-D numeric array when every
    # pyramid has one scale. Fill an object array cell by cell instead, so
    # each cell holds the tensor itself.
    num_scales = max((len(pyramid) for pyramid in all_reals), default=0)
    grid = np.empty((len(all_reals), num_scales), dtype=object)
    for i, pyramid in enumerate(all_reals):
        for s, tensor in enumerate(pyramid):
            grid[i, s] = tensor
    return grid.T
Пример #5
0
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train a SinGAN-style pyramid, printing debug information per scale.

    Args:
        opt: configuration namespace. Overwrites ``opt.nfc``, ``opt.min_nfc``,
            ``opt.out_`` and ``opt.outf``. When ``opt.fast_training`` is set,
            ``opt.niter`` is halved at every 4th scale after scale 0.
        Gs, Zs, NoiseAmp: lists extended in place with one entry per scale.
        reals: seed list handed to ``functions.creat_reals_pyramid``.

    Returns:
        None
    """
    real_ = functions.read_image(opt)
    in_s = 0
    scale_num = 0
    real = imresize(real_, opt.scale1, opt)

    # Pyramid of the image at every scale.
    reals = functions.creat_reals_pyramid(real, reals, opt)

    nfc_prev = 0

    # Runs opt.stop_scale + 1 times (scales 0..opt.stop_scale inclusive).
    while scale_num < opt.stop_scale + 1:
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        print('opt.nfc', opt.nfc)
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(scale_num / 4)),
                          128)
        print('opt.min_nfc', opt.min_nfc)
        if opt.fast_training and scale_num > 0 and scale_num % 4 == 0:
            opt.niter = opt.niter // 2

        # out_ is the output root; outf is the per-scale folder.
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale_num)
        # Only an already-existing directory is tolerated; real OS errors raise.
        os.makedirs(opt.outf, exist_ok=True)

        # Save the original (unresized) image; rewritten on every scale.
        plt.imsave('%s/original.png' % (opt.out_),
                   functions.convert_image_np(real_),
                   vmin=0,
                   vmax=1)
        # Save this scale's real image.
        plt.imsave('%s/real_scale.png' % (opt.outf),
                   functions.convert_image_np(reals[scale_num]),
                   vmin=0,
                   vmax=1)

        D_curr, G_curr = init_models(opt)
        # Same width as the previous scale: warm-start from its checkpoint.
        if nfc_prev == opt.nfc:
            G_curr.load_state_dict(
                torch.load('%s/%d/netG.pth' % (opt.out_, scale_num - 1)))
            D_curr.load_state_dict(
                torch.load('%s/%d/netD.pth' % (opt.out_, scale_num - 1)))

        # train_single_scale returns (z_opt, in_s, netG).
        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs,
                                                  Zs, in_s, NoiseAmp, opt)

        G_curr = functions.reset_grads(G_curr, False)
        print('G_curr', G_curr)
        # eval() returns the module itself, so this prints it once in eval mode.
        print(G_curr.eval())
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

        scale_num += 1
        nfc_prev = opt.nfc
        del D_curr, G_curr
    return
Пример #6
0
def test_generate(model_name,
                  anchor_image=None,
                  direction=None,
                  transfer=None,
                  noise_solutions=None,
                  factor=0.25,
                  base=None,
                  insert_limit=4,
                  basis_image='island_basis_0.jpg',
                  model_image='test1.jpg'):
    """Load a trained pyramid and run ``SinGAN_anchor_generate`` on it.

    Args:
        model_name: value of ``opt.input_name`` while ``post_config`` runs.
        anchor_image, direction, transfer, noise_solutions, factor, base,
            insert_limit: forwarded unchanged to ``SinGAN_anchor_generate``.
            ``direction`` is one of 'L', 'R', 'T', 'B' per the original note.
        basis_image: an existing input image whose size fixes the scale
            parameters and the ``reals`` pyramid dimensions (previously
            hard-coded).
        model_image: ``opt.input_name`` used to locate the trained pyramid
            loaded by ``functions.load_trained_pyramid`` (previously
            hard-coded).

    Returns:
        Whatever ``SinGAN_anchor_generate`` returns.
    """
    parser = get_arguments()
    parser.add_argument('--input_dir',
                        help='input image dir',
                        default='Input/Images')
    parser.add_argument('--mode',
                        help='random_samples | random_samples_arbitrary_sizes',
                        default='random_samples')
    # for random_samples:
    parser.add_argument('--gen_start_scale',
                        type=int,
                        help='generation start scale',
                        default=0)
    # Empty argument list: only defaults apply (sys.argv is ignored).
    opt = parser.parse_args([])
    opt.input_name = model_name
    opt = functions.post_config(opt)

    # Any existing image works here; it sets the scale parameters.
    opt.input_name = basis_image
    real = functions.read_image(opt)
    functions.adjust_scales2image(real, opt)

    # Selects which trained pyramid is loaded. The loaded reals are discarded.
    opt.input_name = model_image
    Gs, Zs, _, NoiseAmp = functions.load_trained_pyramid(opt)

    # Placeholder pyramid built from the basis image, used for its dimensions.
    real = imresize(real, opt.scale1, opt)
    reals = functions.creat_reals_pyramid(real, [], opt)

    return SinGAN_anchor_generate(Gs,
                                  Zs,
                                  reals,
                                  NoiseAmp,
                                  opt,
                                  gen_start_scale=opt.gen_start_scale,
                                  anchor_image=anchor_image,
                                  direction=direction,
                                  transfer=transfer,
                                  noise_solutions=noise_solutions,
                                  factor=factor,
                                  base=base,
                                  insert_limit=insert_limit)
Пример #7
0
def _visualize_weights(modules, fig_path):
    """Save a 100-bin histogram of the non-zero weights and biases of *modules*.

    Each parameter is detached and copied to the CPU before concatenation, so
    this works on any device (the earlier inline version hard-coded
    ``.cuda()``). All pieces are joined with one ``torch.cat``.
    """
    pieces = []
    for m in modules:
        pieces.append(m.weight.detach().flatten().cpu())
        pieces.append(m.bias.detach().flatten().cpu())
    values = torch.cat(pieces).numpy()
    plt.hist(values[values != 0], bins=100)
    plt.savefig(fig_path)
    plt.close()


def _generator_prune_modules(G):
    """Return the conv/norm layers of generator *G* that are pruned and plotted."""
    return [
        G.head.conv, G.head.norm,
        G.body.block1.conv, G.body.block1.norm,
        G.body.block2.conv, G.body.block2.norm,
        G.body.block3.conv, G.body.block3.norm,
        G.tail[0],
    ]


def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train a SinGAN-style pyramid, pruning each generator after its scale.

    After training a scale, the generator's weights and biases are globally
    L1-pruned by 20%, the pruning is made permanent, and the generator is cast
    to half precision before being appended to ``Gs``. Weight histograms
    before and after pruning are saved as ``ori.png`` / ``prune.png`` in the
    scale folder.

    Args:
        opt: configuration namespace. Overwrites ``opt.nfc``, ``opt.min_nfc``,
            ``opt.out_`` and ``opt.outf``.
        Gs, Zs, NoiseAmp: lists extended in place with one entry per scale.
        reals: seed list handed to ``functions.creat_reals_pyramid``.

    Returns:
        None
    """
    real_ = functions.read_image(opt)
    in_s = 0
    # cur_scale_level: current level, from coarsest (0) to finest.
    cur_scale_level = 0
    reals = functions.creat_reals_pyramid(real_, reals, opt)
    nfc_prev = 0

    # Train levels 0..opt.stop_scale inclusive.
    while cur_scale_level < opt.stop_scale + 1:
        # nfc: output channels per conv block; doubles every 4 levels, max 128.
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(cur_scale_level / 4)),
                      128)
        opt.min_nfc = min(
            opt.min_nfc_init * pow(2, math.floor(cur_scale_level / 4)), 128)

        # out_: output root directory; outf: per-level output folder.
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, cur_scale_level)
        # Only an already-existing directory is tolerated; real OS errors raise.
        os.makedirs(opt.outf, exist_ok=True)

        plt.imsave('%s/real_scale.png' % (opt.outf),
                   functions.convert_image_np(reals[cur_scale_level]),
                   vmin=0,
                   vmax=1)

        D_curr, G_curr = init_models(opt)
        # The architecture changes whenever nfc changes (every 4 levels); only
        # when unchanged can the previous level's checkpoint be reused.
        if nfc_prev == opt.nfc:
            G_curr.load_state_dict(
                torch.load('%s/%d/netG.pth' % (opt.out_, cur_scale_level - 1)))
            D_curr.load_state_dict(
                torch.load('%s/%d/netD.pth' % (opt.out_, cur_scale_level - 1)))

        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs,
                                                  Zs, in_s, NoiseAmp, opt)

        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()

        # ---- Prune the trained generator -----------------------------------
        modules = _generator_prune_modules(G_curr)
        # Same (module, name) pairs, in the same order, as the hand-written tuple.
        parameters_to_prune = tuple(
            (m, name) for m in modules for name in ('weight', 'bias'))

        _visualize_weights(modules, '%s/%s.png' % (opt.outf, 'ori'))

        prune.global_unstructured(
            parameters_to_prune,
            pruning_method=prune.L1Unstructured,
            amount=0.2,
        )
        # Make the pruning permanent: drop the *_orig / *_mask
        # re-parametrisation so the zeros live in plain parameters.
        for m in modules:
            prune.remove(m, 'weight')
            prune.remove(m, 'bias')

        _visualize_weights(modules, '%s/%s.png' % (opt.outf, 'prune'))
        # NOTE(review): the half-precision generator is appended to Gs, which
        # later scales pass to train_single_scale. Confirm that it feeds
        # matching (half) inputs, or this will raise a dtype mismatch.
        G_curr.half()
        # ---------------------------------------------------------------------
        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/pruned_Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

        cur_scale_level += 1
        nfc_prev = opt.nfc
        del D_curr, G_curr
        torch.cuda.empty_cache()
    return
Пример #8
0
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train a SinGAN-style pyramid, with optional iteration halving.

    Args:
        opt: configuration namespace. Overwrites ``opt.nfc``, ``opt.min_nfc``,
            ``opt.out_`` and ``opt.outf``. When ``opt.fast_training`` is set,
            ``opt.niter`` is halved at every 4th scale after scale 0.
        Gs, Zs, NoiseAmp: lists extended in place with one entry per scale.
        reals: seed list handed to ``functions.creat_reals_pyramid``.

    Returns:
        None
    """
    real_ = functions.read_image(opt)
    in_s = 0
    scale_num = 0
    real = imresize(real_, opt.scale1, opt)
    reals = functions.creat_reals_pyramid(real, reals, opt)
    nfc_prev = 0

    while scale_num < opt.stop_scale + 1:
        # Channel width doubles every 4 scales, capped at 128.
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(scale_num / 4)),
                          128)
        if opt.fast_training and scale_num > 0 and scale_num % 4 == 0:
            opt.niter = opt.niter // 2
        # Disabled experiment (formerly a no-op string statement):
        # if scale_num == opt.stop_scale:
        #     opt.nfc = 128
        #     opt.min_nfc = 128
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale_num)
        # Only an already-existing directory is tolerated; real OS errors raise.
        os.makedirs(opt.outf, exist_ok=True)

        plt.imsave('%s/real_scale.png' % (opt.outf),
                   functions.convert_image_np(reals[scale_num]),
                   vmin=0,
                   vmax=1)

        D_curr, G_curr = init_models(opt)
        # Same width as the previous scale: warm-start from its checkpoint.
        if nfc_prev == opt.nfc:
            G_curr.load_state_dict(
                torch.load('%s/%d/netG.pth' % (opt.out_, scale_num - 1)))
            D_curr.load_state_dict(
                torch.load('%s/%d/netD.pth' % (opt.out_, scale_num - 1)))

        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs,
                                                  Zs, in_s, NoiseAmp, opt)

        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

        scale_num += 1
        nfc_prev = opt.nfc
        del D_curr, G_curr
    return
Пример #9
0
def get_reals(reals, opt, image_name):
    """Read *image_name*, rescale it by ``opt.scale1`` and build its pyramid.

    Returns the result of ``functions.creat_reals_pyramid`` for the rescaled
    image, seeded with *reals*.
    """
    scaled = imresize(functions.read_image(opt, image_name), opt.scale1, opt)
    return functions.creat_reals_pyramid(scaled, reals, opt)
Пример #10
0
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train a SinGAN-style pyramid while collecting loss curves for plotting.

    The lists ``errD2plot``, ``errG2plot``, ``D_real2plot``, ``D_fake2plot``
    and ``z_opt2plot`` are shared by all scales and passed to
    ``train_single_scale``. ``functions.my_plot`` is called after every scale.

    Args:
        opt: configuration namespace. Overwrites ``opt.nfc``, ``opt.min_nfc``,
            ``opt.out_`` and ``opt.outf``.
        Gs, Zs, NoiseAmp: lists extended in place with one entry per scale.
        reals: seed list handed to ``functions.creat_reals_pyramid``.

    Returns:
        None
    """
    real_ = functions.read_image(opt)
    in_s = 0
    scale_num = 0
    real = imresize(real_, opt.scale1, opt)
    reals = functions.creat_reals_pyramid(real, reals, opt)

    print('scale_num:', len(reals))
    for _reals in reals:
        print('image_size:', _reals.size())

    nfc_prev = 0

    errD2plot = []
    errG2plot = []
    D_real2plot = []
    D_fake2plot = []
    z_opt2plot = []

    while scale_num < opt.stop_scale + 1:
        # Channel width doubles every 4 scales, capped at 128.
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(scale_num / 4)),
                          128)

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale_num)
        # Only an already-existing directory is tolerated; real OS errors raise.
        os.makedirs(opt.outf, exist_ok=True)

        plt.imsave('%s/real_scale.png' % (opt.outf),
                   functions.convert_image_np(reals[scale_num]),
                   vmin=0,
                   vmax=1)

        D_curr, G_curr = init_models(opt)
        # Same width as the previous scale: warm-start from its checkpoint.
        if nfc_prev == opt.nfc:
            G_curr.load_state_dict(
                torch.load('%s/%d/netG.pth' % (opt.out_, scale_num - 1)))
            D_curr.load_state_dict(
                torch.load('%s/%d/netD.pth' % (opt.out_, scale_num - 1)))

        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs,
                                                  Zs, in_s, NoiseAmp,
                                                  errD2plot, errG2plot,
                                                  D_real2plot, D_fake2plot,
                                                  z_opt2plot, opt)

        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

        scale_num += 1
        nfc_prev = opt.nfc
        del D_curr, G_curr

        # Re-plot the accumulated curves after every scale. D_real2plot and
        # D_fake2plot are collected but not passed to the plot.
        functions.my_plot(errD2plot, errG2plot, z_opt2plot, opt)

    return
Пример #11
0
def SinGAN_generate(Gs,
                    Zs,
                    reals,
                    NoiseAmp,
                    opt,
                    in_s=None,
                    scale_v=1,
                    scale_h=1,
                    n=0,
                    gen_start_scale=0,
                    num_samples=50):
    """Generate ``num_samples`` random samples through a trained pyramid.

    The image from ``functions.read_image(opt)`` is resized to the shape of
    the finest entry of ``reals``. It is then re-pyramided with
    ``creat_reals_pyramid``, and each level is resized back to the matching
    entry of ``reals``. These replacement reals are saved as ``real_<i>.png``
    and supply the per-scale crop sizes used below.

    Args:
        Gs, Zs, NoiseAmp: per-scale generators, reconstruction noise maps and
            noise amplitudes, coarsest first (iterated in lockstep).
        reals: per-scale reference tensors (only their shapes are used).
        opt: configuration namespace.
        in_s: input to the coarsest scale. Zeros shaped like ``reals[0]`` when
            ``None``.
        scale_v, scale_h: vertical / horizontal size multipliers for the noise.
        n: index of the first scale in ``Gs`` (0 = coarsest).
        gen_start_scale: scales with index below this reuse ``Z_opt`` instead
            of fresh noise.
        num_samples: number of samples generated at every scale.

    Returns:
        The last sample of the last scale, detached.

    Raises:
        ValueError: if nothing was generated (empty ``Gs`` or
            ``num_samples < 1``). Previously this surfaced as UnboundLocalError.
    """
    real = functions.read_image(opt)

    # .cpu() because Tensor.numpy() only works on CPU tensors.
    real = torch.from_numpy(resize(real.cpu().numpy(), reals[-1].shape))
    new_reals = creat_reals_pyramid(real, [], opt)
    reals = [
        torch.from_numpy(resize(new_real.cpu().numpy(), ref.shape))
        for new_real, ref in zip(new_reals, reals)
    ]

    samples_dir = '%s/RandomSamples/%s/gen_start_scale=%d' % (
        opt.out, opt.input_name[:-4], gen_start_scale)
    # The directory must exist before the first imsave. Previously it was only
    # created later, while saving the finest-scale samples.
    os.makedirs(samples_dir, exist_ok=True)
    for i, real_img in enumerate(reals):
        plt.imsave('%s/%s_%d.png' % (samples_dir, "real", i),
                   functions.convert_image_np(real_img.detach()),
                   vmin=0,
                   vmax=1)

    if in_s is None:
        # Float fill value so the zeros are a floating-point tensor like the noise.
        in_s = torch.full(reals[0].shape, 0., device=opt.device)

    # Modes in which the finest-scale samples are not written to disk.
    no_save_modes = ("harmonization", "editing", "SR", "paint2image")
    images_cur = []
    I_curr = None

    for G, Z_opt, noise_amp in zip(Gs, Zs, NoiseAmp):
        pad1 = ((opt.ker_size - 1) * opt.num_layer) / 2
        m = nn.ZeroPad2d(int(pad1))
        nzx = (Z_opt.shape[2] - pad1 * 2) * scale_v
        nzy = (Z_opt.shape[3] - pad1 * 2) * scale_h

        images_prev = images_cur
        images_cur = []

        is_finest = n == len(reals) - 1
        if is_finest:
            if opt.mode == 'train':
                dir2save = samples_dir
            else:
                dir2save = functions.generate_dir2save(opt)
            os.makedirs(dir2save, exist_ok=True)

        for i in range(num_samples):
            if n == 0:
                # Coarsest scale: a single noise channel broadcast to 3 channels.
                z_curr = functions.generate_noise([1, nzx, nzy],
                                                  device=opt.device)
                z_curr = z_curr.expand(1, 3, z_curr.shape[2], z_curr.shape[3])
                z_curr = m(z_curr)
            else:
                z_curr = functions.generate_noise([opt.nc_z, nzx, nzy],
                                                  device=opt.device)
                z_curr = m(z_curr)

            if not images_prev:
                I_prev = m(in_s)
            else:
                # Upsample the matching sample from the previous scale.
                I_prev = images_prev[i]
                I_prev = imresize(I_prev, 1 / opt.scale_factor, opt)
                if opt.mode != "SR":
                    I_prev = I_prev[:, :, 0:round(scale_v * reals[n].shape[2]),
                                    0:round(scale_h * reals[n].shape[3])]
                    I_prev = m(I_prev)
                    I_prev = I_prev[:, :, 0:z_curr.shape[2], 0:z_curr.shape[3]]
                    I_prev = functions.upsampling(I_prev, z_curr.shape[2],
                                                  z_curr.shape[3])
                else:
                    I_prev = m(I_prev)

            if n < gen_start_scale:
                z_curr = Z_opt

            z_in = noise_amp * (z_curr) + I_prev
            if opt.skip != '' and int(opt.skip) == n:
                I_curr = I_prev
            else:
                I_curr = G(z_in.detach(), I_prev)

            if is_finest and opt.mode not in no_save_modes:
                plt.imsave('%s/%d.png' % (dir2save, i),
                           functions.convert_image_np(I_curr.detach()),
                           vmin=0,
                           vmax=1)

            images_cur.append(I_curr)
        n += 1

    if I_curr is None:
        raise ValueError(
            'SinGAN_generate produced no samples (empty Gs or num_samples < 1)')
    return I_curr.detach()
Пример #12
0
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train a SinGAN-style pyramid, optionally preparing an inpainting mask.

    When ``opt.mode == "inpainting"``, the mask image ``opt.mask_name`` is
    read from ``opt.input_dir`` (prefixed by ``opt.on_drive`` when set). It
    is converted so that 0 marks the masked-out area and 1 everything else,
    reshaped to ``(1, C, H, W)``, rescaled by ``opt.scale1`` and stored as a
    pyramid in ``opt.masks``.

    Args:
        opt: configuration namespace. Overwrites ``opt.nfc``, ``opt.min_nfc``,
            ``opt.out_``, ``opt.outf`` and, for inpainting, ``opt.masks``.
        Gs, Zs, NoiseAmp: lists extended in place with one entry per scale.
        reals: seed list handed to ``functions.creat_reals_pyramid``.

    Returns:
        None
    """
    real_ = functions.read_image(opt)
    in_s = 0
    scale_num = 0
    real = imresize(real_, opt.scale1, opt)
    reals = functions.creat_reals_pyramid(real, reals, opt)
    nfc_prev = 0

    if opt.mode == "inpainting":
        # Mask image expected in [0, 255]. NOTE(review): matplotlib's imread
        # returns floats in [0, 1] for PNG files; confirm the loader/format.
        if opt.on_drive is not None:
            mask = img.imread('%s/%s/%s' %
                              (opt.on_drive, opt.input_dir, opt.mask_name))
        else:
            mask = img.imread('%s/%s' % (opt.input_dir, opt.mask_name))

        # Keep exactly three colour channels (drops an alpha channel if present).
        mask = mask[:, :, :3]
        # Convert to [0, 1]: 0 is the masked-out area, 1 everywhere else.
        mask = 1 - (mask / 255)
        mask = torch.from_numpy(mask)

        # (H, W, C) -> (1, C, H, W). The former `.view([1, 3, H, W])`
        # reinterpreted the channel-interleaved HWC buffer as CHW, scrambling
        # pixels across channels; permute moves the axes instead.
        mask = mask[:, :, :, None].permute(3, 2, 0, 1).contiguous()
        mask = imresize(mask, opt.scale1, opt)

        # Mask pyramid aligned with the image pyramid.
        opt.masks = functions.creat_reals_pyramid(mask, [], opt)

    while scale_num < opt.stop_scale + 1:
        # Channel width doubles every 4 scales, capped at 128.
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(scale_num / 4)),
                          128)

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale_num)
        # Only an already-existing directory is tolerated; real OS errors raise.
        os.makedirs(opt.outf, exist_ok=True)

        D_curr, G_curr = init_models(opt)
        # Same width as the previous scale: warm-start from its checkpoint.
        if nfc_prev == opt.nfc:
            G_curr.load_state_dict(
                torch.load('%s/%d/netG.pth' % (opt.out_, scale_num - 1)))
            D_curr.load_state_dict(
                torch.load('%s/%d/netD.pth' % (opt.out_, scale_num - 1)))

        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs,
                                                  Zs, in_s, NoiseAmp, opt)

        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

        scale_num += 1
        nfc_prev = opt.nfc
        del D_curr, G_curr
    return
Пример #13
0
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train every pyramid level, conditioning each on the upsampled real image.

    Variant of the SinGAN trainer. ``train_single_scale`` receives the list of
    upsampled real images and the index of the level being trained, instead
    of the previously trained generators. Weights are never copied from the
    previous level.

    Args:
        opt: option namespace. ``stop_scale``, ``nfc_init``, ``min_nfc_init``
            and ``noise_amp`` are read; ``nfc``, ``min_nfc``, ``out_`` and
            ``outf`` are overwritten at every level.
        Gs: list that receives the frozen generator of each level.
        Zs: list that receives the noise map returned for each level.
        reals: list handed to ``functions.creat_reals_pyramid``; it is
            replaced locally by the pyramid that call returns.
        NoiseAmp: list that receives ``opt.noise_amp`` after each level.

    Returns:
        None. Progress is checkpointed under ``opt.out_`` after every level.
    """
    full_image = functions.read_image(opt)
    in_s = 0
    # No imresize by opt.scale1 here: the pyramid is built from the image as read.
    reals = functions.creat_reals_pyramid(full_image, reals, opt)

    # upsamples[k] is reals[k - 1] resized to the shape of reals[k].
    # diffs[k] is |reals[k] - upsamples[k]| - 1.
    # Level 0 has no coarser neighbour, so both lists start with reals[0].
    upsamples = [reals[0]]
    diffs = [reals[0]]
    for fine_idx in range(1, opt.stop_scale + 1):
        coarse = reals[fine_idx - 1]
        fine = reals[fine_idx]
        _, n_ch, height, width = fine.shape
        lifted = imresize_to_shape(coarse, (height, width, n_ch), opt)
        upsamples.append(lifted)
        # Assuming images lie in [-1, 1]: abs() is in [0, 2], and shifting
        # by -1 brings the residual back to [-1, 1].
        diffs.append((fine - lifted).abs() - 1)

    level = 0
    while level < opt.stop_scale + 1:
        # Channel width doubles every 4 levels, capped at 128.
        width_mult = pow(2, math.floor(level / 4))
        opt.nfc = min(opt.nfc_init * width_mult, 128)
        opt.min_nfc = min(opt.min_nfc_init * width_mult, 128)

        # out_ is the run directory, outf the per-level subfolder.
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, level)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass

        # Dump the real, residual and upsampled images for this level.
        previews = (
            ('%s/real_scale.png', reals[level]),
            ('%s/diff.png', diffs[level].detach()),
            ('%s/upsampled.png', upsamples[level].detach()),
        )
        for path_fmt, picture in previews:
            plt.imsave(path_fmt % (opt.outf),
                       functions.convert_image_np(picture),
                       vmin=0,
                       vmax=1)

        # Fresh networks for every level; nothing is inherited from level - 1.
        D_curr, G_curr = init_models(opt)
        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals,
                                                  upsamples, level, in_s,
                                                  NoiseAmp, opt)

        # Freeze both networks before storing them.
        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        checkpoints = (
            (Zs, '%s/Zs.pth'),
            (Gs, '%s/Gs.pth'),
            (reals, '%s/reals.pth'),
            (NoiseAmp, '%s/NoiseAmp.pth'),
        )
        for payload, path_fmt in checkpoints:
            torch.save(payload, path_fmt % (opt.out_))

        level += 1
        del D_curr, G_curr
        torch.cuda.empty_cache()
    return
Пример #14
0
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train the SinGAN pyramid scale by scale, from coarsest to finest.

    Args:
        opt: option namespace. ``scale1``, ``stop_scale``, ``nfc_init``,
            ``min_nfc_init``, ``fast_training`` and ``noise_amp`` are read.
            ``nfc``, ``min_nfc``, ``out_`` and ``outf`` are overwritten,
            and so is ``niter`` when ``fast_training`` is set.
        Gs: list receiving the frozen generator of each scale.
        Zs: list receiving the noise map returned for each scale.
        reals: list passed to ``functions.creat_reals_pyramid``; it is
            replaced locally by the pyramid that call returns.
        NoiseAmp: list receiving ``opt.noise_amp`` after each scale.

    Returns:
        None. Checkpoints are written to ``opt.out_`` after every scale.
    """
    real_ = functions.read_image(opt)
    in_s = 0
    scale_num = 0
    real = imresize(real_, opt.scale1, opt)
    print("real.shape:", real.shape)  # e.g. torch.Size([1, 3, 186, 248])
    reals = functions.creat_reals_pyramid(real, reals, opt)
    for level_img in reals:
        print(level_img.shape)
    # Example pyramid for the 186x248 image above, coarsest to finest:
    #   [1, 3, 20, 27]   [1, 3, 27, 36]   [1, 3, 35, 47]
    #   [1, 3, 47, 62]   [1, 3, 61, 82]   [1, 3, 81, 108]
    #   [1, 3, 107, 143] [1, 3, 141, 188] [1, 3, 186, 248]
    nfc_prev = 0
    # Scales 0..stop_scale inclusive are trained, i.e. stop_scale + 1 of them
    # (previously this printed stop_scale, one less than the real count).
    print("total %d scales.." % (opt.stop_scale + 1))
    while scale_num < opt.stop_scale + 1:
        # Channel width doubles every 4 scales, capped at 128.
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(scale_num / 4)),
                          128)
        # fast_training halves the iteration budget every 4 scales. The halving
        # is persistent because it writes back to opt.niter.
        if opt.fast_training and scale_num > 0 and scale_num % 4 == 0:
            opt.niter //= 2
        # NOTE: an earlier variant forced opt.nfc = opt.min_nfc = 128 at the
        # final scale (scale_num == opt.stop_scale).

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale_num)
        print("out dir:", opt.outf)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass

        plt.imsave('%s/real_scale.png' % (opt.outf),
                   functions.convert_image_np(reals[scale_num]),
                   vmin=0,
                   vmax=1)

        # Build D and G. Per the original author's note, both are fully
        # convolutional, so they do not depend on the input size.
        D_curr, G_curr = init_models(opt)
        # Same width as the previous scale: warm-start from its saved weights.
        if nfc_prev == opt.nfc:
            G_curr.load_state_dict(
                torch.load('%s/%d/netG.pth' % (opt.out_, scale_num - 1)))
            D_curr.load_state_dict(
                torch.load('%s/%d/netD.pth' % (opt.out_, scale_num - 1)))

        # Train this scale's networks. A fresh pair per scale keeps GPU memory
        # low (the original note reports under 3 GB).
        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs,
                                                  Zs, in_s, NoiseAmp, opt)

        # Freeze both networks.
        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        # Keep the trained generator, its noise map and noise amplitude.
        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

        scale_num += 1
        nfc_prev = opt.nfc

        # Drop the local references to D and G (presumably to free GPU memory).
        del D_curr, G_curr
    return
Пример #15
0
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train one generator/discriminator pair per pyramid scale, with progress bar.

    Scales run from 0 (coarsest) to ``opt.stop_scale`` inclusive. A notebook
    tqdm bar labelled with ``opt.input_name`` tracks them. When a scale has
    the same channel width as the previous one, both networks start from the
    weights saved for that previous scale.

    Args:
        opt: option namespace. ``scale1`` (set earlier by the scale-adjustment
            step) is read, and ``nfc``, ``min_nfc``, ``out_`` and ``outf`` are
            overwritten at every scale.
        Gs: list receiving the frozen generator of each scale.
        Zs: list receiving the noise map returned for each scale.
        reals: list passed to ``functions.creat_reals_pyramid``; replaced
            locally by the pyramid it returns.
        NoiseAmp: list receiving ``opt.noise_amp`` after each scale.

    Returns:
        None. All lists are checkpointed under ``opt.out_`` after each scale.
    """
    source = functions.read_image(opt)
    in_s = 0
    pyramid_base = imresize(source, opt.scale1, opt)
    reals = functions.creat_reals_pyramid(pyramid_base, reals, opt)
    nfc_prev = 0

    levels = tqdm_notebook(range(opt.stop_scale + 1),
                           desc=opt.input_name,
                           leave=True)
    for scale_num in levels:
        # Channel counts double every 4 scales and are capped at 128.
        growth = pow(2, math.floor(scale_num / 4))
        opt.nfc = min(opt.nfc_init * growth, 128)
        opt.min_nfc = min(opt.min_nfc_init * growth, 128)

        # Run directory plus one subdirectory per scale.
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = f'{opt.out_}/{scale_num}'
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass

        # Keep a copy of the real image at this scale.
        plt.imsave(f'{opt.outf}/real_scale.png',
                   functions.convert_image_np(reals[scale_num]),
                   vmin=0,
                   vmax=1)

        D_curr, G_curr = init_models(opt)
        if opt.nfc == nfc_prev:
            # Unchanged width: reuse the previous scale's saved weights.
            prev_dir = (opt.out_, scale_num - 1)
            G_curr.load_state_dict(torch.load('%s/%d/netG.pth' % prev_dir))
            D_curr.load_state_dict(torch.load('%s/%d/netD.pth' % prev_dir))

        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs,
                                                  Zs, in_s, NoiseAmp, opt)

        # Disable gradients and switch both networks to eval mode.
        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        for payload, path_fmt in ((Zs, '%s/Zs.pth'),
                                  (Gs, '%s/Gs.pth'),
                                  (reals, '%s/reals.pth'),
                                  (NoiseAmp, '%s/NoiseAmp.pth')):
            torch.save(payload, path_fmt % (opt.out_))

        nfc_prev = opt.nfc
        del D_curr, G_curr
    return
Пример #16
0
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train each pyramid scale, then prune and optionally fine-tune its generator.

    Besides its arguments, this function reads module-level names not defined
    here: ``fine_tune``, ``warmup_steps``, ``structured``, ``pruning_amount``,
    ``prune`` (presumably ``torch.nn.utils.prune`` -- confirm) and
    ``prune_SinGAN_generate``.

    Args:
        opt: option namespace; ``nfc``, ``min_nfc``, ``out_`` and ``outf`` are
            overwritten at every scale.
        Gs: list receiving the pruned, frozen generator of each scale.
        Zs: list receiving the noise map returned for each scale.
        reals: list passed to ``functions.creat_reals_pyramid``; it is
            replaced locally by the pyramid that call returns.
        NoiseAmp: list receiving ``opt.noise_amp`` after each scale.

    Returns:
        None. Checkpoints go to ``opt.out_``; the generators are saved as
        ``pruned_Gs.pth``.
    """
    real_ = functions.read_image(opt)
    in_s = 0
    # cur_scale_level: current level, from coarsest (0) to finest.
    cur_scale_level = 0
    # The pyramid is built from the image as read (no imresize by opt.scale1).
    reals = functions.creat_reals_pyramid(real_, reals, opt)
    nfc_prev = 0

    def visualize_weights(modules, fig_name):
        """Print the zero fraction of the modules' weights and plot the non-zeros.

        The histogram is saved to ``<opt.outf>/<fig_name>.png``. ``opt.outf``
        is read at call time, so it tracks the current scale's directory.
        """
        # Concatenate once, on the weights' own device. The previous version
        # grew a tensor on a hard-coded CUDA device, which failed on CPU-only
        # runs, and re-copied the accumulated weights for every module.
        ori_weights = torch.cat([m.weight.data.flatten() for m in modules])
        sparsity = torch.sum(ori_weights == 0) * 1.0 / (ori_weights.nelement())
        print(sparsity, ori_weights.nelement())
        ori_weights = ori_weights.cpu().numpy()
        plt.hist(ori_weights[ori_weights != 0], bins=100)
        plt.savefig("%s/%s.png" % (opt.outf, fig_name))
        plt.close()

    # Train including opt.stop_scale.
    while cur_scale_level < opt.stop_scale + 1:
        # nfc: number of output channels per conv block. It doubles every
        # 4 levels and is capped at 128.
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(cur_scale_level / 4)),
                      128)
        opt.min_nfc = min(
            opt.min_nfc_init * pow(2, math.floor(cur_scale_level / 4)), 128)

        # out_: output directory; outf: per-level output folder.
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, cur_scale_level)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass

        plt.imsave('%s/real_scale.png' % (opt.outf),
                   functions.convert_image_np(reals[cur_scale_level]),
                   vmin=0,
                   vmax=1)

        D_curr, G_curr = init_models(opt)
        # Warm-start from the previous level only when the width is unchanged.
        if nfc_prev == opt.nfc:
            G_curr.load_state_dict(
                torch.load('%s/%d/netG.pth' % (opt.out_, cur_scale_level - 1)))
            D_curr.load_state_dict(
                torch.load('%s/%d/netD.pth' % (opt.out_, cur_scale_level - 1)))

        # With fine-tuning, only warmup_steps are trained before pruning;
        # otherwise the full opt.niter budget runs here.
        train_steps = warmup_steps if fine_tune else opt.niter
        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs,
                                                  Zs, in_s, NoiseAmp, opt,
                                                  train_steps)

        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        # D_curr is left trainable; it is passed to train_single_scale again
        # if fine-tuning runs below.

        # Prune weights, structured or non-structured.
        if not structured:
            modules = [G_curr.head.conv, G_curr.head.norm,
                       G_curr.body.block1.conv, G_curr.body.block1.norm,
                       G_curr.body.block2.conv, G_curr.body.block2.norm,
                       G_curr.body.block3.conv, G_curr.body.block3.norm,
                       G_curr.tail[0]]
            # Every weight, then every bias, pooled into one global L1 ranking.
            parameters_to_prune = tuple((module, param_name)
                                        for param_name in ('weight', 'bias')
                                        for module in modules)

            visualize_weights(modules, 'ori')

            prune.global_unstructured(
                parameters_to_prune,
                pruning_method=prune.L1Unstructured,
                amount=pruning_amount,
            )
        else:
            modules = [G_curr.head.conv,
                       G_curr.body.block1.conv,
                       G_curr.body.block2.conv,
                       G_curr.body.block3.conv]

            visualize_weights(modules, 'ori')

            # Prune whole slices along dim 0 of each conv weight, ranked by L1 norm.
            for module in modules:
                prune.ln_structured(module, name="weight",
                                    amount=pruning_amount, n=1, dim=0)

        torch.save(G_curr.state_dict(), '%s/raw_prune_netG.pth' % (opt.outf))
        visualize_weights(modules, 'raw-prune')
        # Sample the partial pyramid with the freshly pruned generator appended.
        if cur_scale_level > 0:
            fake_Gs = Gs.copy()
            fake_Gs.append(G_curr)
            fake_Zs = Zs.copy()
            fake_Zs.append(z_curr)
            fake_noise = NoiseAmp.copy()
            fake_noise.append(opt.noise_amp)
            fake_reals = reals[:cur_scale_level + 1].copy()
            prune_SinGAN_generate(fake_Gs, fake_Zs, fake_reals, fake_noise,
                                  opt, gen_start_scale=0, num_samples=1,
                                  level=cur_scale_level)

        # Fine-tuning.
        if fine_tune:
            G_curr = functions.reset_grads(G_curr, True)
            G_curr.train()

            if not structured:
                # Keep training the inherited weights for the rest of the budget.
                z_curr, in_s, G_curr = train_single_scale(
                    D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt,
                    opt.niter - warmup_steps, prune=True)
            else:
                # Train for a full opt.niter budget (re-initialisation is disabled).
                z_curr, in_s, G_curr = train_single_scale(
                    D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt,
                    opt.niter, prune=True)
        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        visualize_weights(modules, 'fine-tune')

        # Make the pruning permanent by removing the reparametrisation.
        for m in modules:
            prune.remove(m, 'weight')
            if not structured:
                prune.remove(m, 'bias')

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/pruned_Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

        cur_scale_level += 1
        nfc_prev = opt.nfc
        del D_curr, G_curr
    return
Пример #17
0
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Fit the generator/discriminator pyramid one scale at a time.

    Each scale gets a new network pair from ``init_models``. When the scale's
    channel width equals the previous one, both networks start from the
    weights saved for the previous scale. After training, the generator is
    frozen and appended to ``Gs``, together with its noise map (``Zs``) and
    noise amplitude (``NoiseAmp``), and all lists are checkpointed.

    Args:
        opt: option namespace. ``nfc``, ``min_nfc``, ``out_`` and ``outf``
            are rewritten every scale. When ``opt.fast_training`` is set,
            ``niter`` is halved at scales 4, 8, 12, ...
        Gs, Zs, NoiseAmp: output lists, appended to once per scale.
        reals: list handed to ``functions.creat_reals_pyramid``; replaced
            locally by the pyramid it returns.

    Returns:
        None.
    """
    source_img = functions.read_image(opt)
    in_s = 0
    scale_idx = 0  # position in the pyramid, 0 = coarsest
    scaled_img = imresize(source_img, opt.scale1, opt)
    reals = functions.creat_reals_pyramid(scaled_img, reals, opt)
    nfc_prev = 0

    while scale_idx < opt.stop_scale + 1:
        # Filters double every 4 scales, never exceeding 128.
        width_factor = pow(2, math.floor(scale_idx / 4))
        opt.nfc = min(opt.nfc_init * width_factor, 128)
        opt.min_nfc = min(opt.min_nfc_init * width_factor, 128)
        # Finer scales get fewer iterations: halve the budget every 4 scales.
        if opt.fast_training and scale_idx > 0 and scale_idx % 4 == 0:
            opt.niter = opt.niter // 2

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale_idx)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass

        plt.imsave('%s/real_scale.png' % (opt.outf),
                   functions.convert_image_np(reals[scale_idx]),
                   vmin=0,
                   vmax=1)

        D_curr, G_curr = init_models(opt)
        if nfc_prev == opt.nfc:
            # Matching width: initialise from the previous scale's weights.
            prev_idx = scale_idx - 1
            G_curr.load_state_dict(
                torch.load('%s/%d/netG.pth' % (opt.out_, prev_idx)))
            D_curr.load_state_dict(
                torch.load('%s/%d/netD.pth' % (opt.out_, prev_idx)))

        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs,
                                                  Zs, in_s, NoiseAmp, opt)

        # Freeze: no gradients, eval mode.
        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        for payload, path_fmt in ((Zs, '%s/Zs.pth'),
                                  (Gs, '%s/Gs.pth'),
                                  (reals, '%s/reals.pth'),
                                  (NoiseAmp, '%s/NoiseAmp.pth')):
            torch.save(payload, path_fmt % (opt.out_))

        scale_idx += 1
        nfc_prev = opt.nfc
        del D_curr, G_curr
    return
Пример #18
0
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train the SinGAN pyramid, optionally building an inpainting-mask pyramid.

    With ``opt.inpainting`` set, the mask image ``<stem>_mask<ext>`` from
    ``opt.ref_dir`` is resized and turned into a pyramid stored on ``opt.m_s``.
    It is built the same way as the image pyramid, so each scale can be
    restricted to the valid pixels.

    Args:
        opt: option namespace; ``nfc``, ``min_nfc``, ``out_`` and ``outf`` are
            overwritten at every scale, and ``m_s`` is set when inpainting.
        Gs: list receiving the frozen generator of each scale.
        Zs: list receiving the noise map returned for each scale.
        reals: list passed to ``functions.creat_reals_pyramid``; replaced
            locally by the pyramid that call returns.
        NoiseAmp: list receiving ``opt.noise_amp`` after each scale.

    Returns:
        None. Checkpoints are written to ``opt.out_`` after every scale.
    """
    real_ = functions.read_image(opt)
    in_s = 0
    scale_num = 0
    real = imresize(real_, opt.scale1, opt)
    reals = functions.creat_reals_pyramid(real, reals, opt)
    nfc_prev = 0

    if opt.inpainting:
        # Split on the real extension. Slicing [:-4]/[-4:] only worked for
        # 3-character extensions such as ".png" and broke on ".jpeg"/".tiff".
        stem, ext = os.path.splitext(opt.input_name)
        mask = functions.read_image_dir(
            '%s/%s_mask%s' % (opt.ref_dir, stem, ext), opt)
        mask = imresize(mask, opt.scale1, opt)
        opt.m_s = functions.creat_reals_pyramid(mask, [], opt)  # mask pyramid

    while scale_num < opt.stop_scale + 1:
        # Channel width doubles every 4 scales, capped at 128.
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(scale_num / 4)),
                          128)

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale_num)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass

        plt.imsave('%s/real_scale.png' % (opt.outf),
                   functions.convert_image_np(reals[scale_num]),
                   vmin=0,
                   vmax=1)

        D_curr, G_curr = init_models(opt)
        # Same width as the previous scale: warm-start from its saved weights.
        if nfc_prev == opt.nfc:
            G_curr.load_state_dict(
                torch.load('%s/%d/netG.pth' % (opt.out_, scale_num - 1)))
            D_curr.load_state_dict(
                torch.load('%s/%d/netD.pth' % (opt.out_, scale_num - 1)))

        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs,
                                                  Zs, in_s, NoiseAmp, opt)

        # Freeze both networks before storing them.
        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

        scale_num += 1
        nfc_prev = opt.nfc
        del D_curr, G_curr
    return
Пример #19
0
def SinGAN_anchor_generate(Gs,
                           Zs,
                           reals,
                           NoiseAmp,
                           opt,
                           in_s=None,
                           scale_v=1,
                           scale_h=1,
                           n=0,
                           gen_start_scale=0,
                           num_samples=1,
                           anchor_image=None,
                           direction=None,
                           transfer=None,
                           noise_solutions=None,
                           factor=None,
                           base=None,
                           insert_limit=0):
    """Generate through a trained pyramid, optionally steered by extra images.

    Args:
        Gs, Zs, NoiseAmp: per-scale generators, reconstruction noise maps and
            noise amplitudes.
        reals: real-image pyramid; its shapes bound the upsampled outputs.
        opt: option namespace (``ker_size``, ``num_layer``, ``nc_z``,
            ``scale1``, ``scale_factor``, ``mode``, ``stop_scale``, ``device``).
        in_s: starting image at the coarsest scale; zeros shaped like
            ``reals[0]`` when omitted.
        scale_v, scale_h: vertical / horizontal size multipliers.
        n: index of the first scale (incremented per generator).
        gen_start_scale: scales below this reuse ``Z_opt`` instead of noise.
        num_samples: samples drawn per scale.
        anchor_image, direction: numpy images. When both are given, every
            scale's output is passed through ``reinforcement`` with their
            pyramids.
        transfer: accepted but unused.
        noise_solutions, factor: optional per-scale noise to blend with the
            fresh noise as ``(1 - factor) * solution + factor * fresh``.
        base, insert_limit: optional numpy image blended in with weight
            ``factor`` at scale ``insert_limit``.

    Returns:
        ``functions.convert_image_np`` of the last sample produced at scale
        ``opt.stop_scale``.

    Raises:
        ValueError: if generation never reaches scale ``opt.stop_scale``.
    """

    def to_pyramid(image):
        # Rescale by opt.scale1 and build the same multi-scale list as reals.
        return functions.creat_reals_pyramid(imresize(image, opt.scale1, opt),
                                             [], opt)

    #### Loading in auxiliary images if needed ####
    anchor = anchor_image
    if anchor is not None:
        anchor = functions.np2torch(anchor_image, opt)
        anchors = to_pyramid(anchor)
    if direction is not None:
        direction = functions.np2torch(direction, opt)
        directions = to_pyramid(direction)
    if base is not None:
        base = functions.np2torch(base, opt)
        bases = to_pyramid(base)

    if in_s is None:
        # Float zeros: an integer fill value makes torch.full return an int64
        # tensor on current PyTorch, which is wrong for an image.
        in_s = torch.zeros(reals[0].shape, device=opt.device)
    images_cur = []
    array = None
    for G, Z_opt, noise_amp in zip(Gs, Zs, NoiseAmp):
        pad1 = ((opt.ker_size - 1) * opt.num_layer) / 2
        m = nn.ZeroPad2d(int(pad1))
        nzx = (Z_opt.shape[2] - pad1 * 2) * scale_v
        nzy = (Z_opt.shape[3] - pad1 * 2) * scale_h

        images_prev = images_cur
        images_cur = []

        for i in range(num_samples):
            if n == 0:
                # Coarsest scale: single-channel noise broadcast to 3 channels.
                z_curr = functions.generate_noise([1, nzx, nzy],
                                                  device=opt.device)
                z_curr = z_curr.expand(1, 3, z_curr.shape[2], z_curr.shape[3])
            else:
                z_curr = functions.generate_noise([opt.nc_z, nzx, nzy],
                                                  device=opt.device)
            z_curr = m(z_curr)

            z_orig = z_curr

            if not images_prev:
                # First scale generated: start from the padded initial image.
                I_prev = m(in_s)
            else:
                # Finer scale: upsample this sample's previous-scale output.
                I_prev = imresize(images_prev[i], 1 / opt.scale_factor, opt)
                if opt.mode != "SR":
                    I_prev = I_prev[:, :, 0:round(scale_v * reals[n].shape[2]),
                                    0:round(scale_h * reals[n].shape[3])]
                    I_prev = m(I_prev)
                    I_prev = I_prev[:, :, 0:z_curr.shape[2], 0:z_curr.shape[3]]
                    # Make it fit the padded noise exactly.
                    I_prev = functions.upsampling(I_prev, z_curr.shape[2],
                                                  z_curr.shape[3])
                else:
                    I_prev = m(I_prev)

            if n < gen_start_scale:
                z_curr = Z_opt  # reuse the trained reconstruction noise
            z_in = noise_amp * z_curr + I_prev

            if noise_solutions is not None:
                z_curr = noise_solutions[n]
                z_in = ((1 - factor) * noise_amp * z_curr + I_prev +
                        factor * noise_amp * z_orig)

            I_curr = G(z_in.detach(), I_prev)
            if base is not None and n == insert_limit:
                I_curr = bases[n] * factor + I_curr * (1 - factor)

            if anchor is not None and direction is not None:
                I_curr = reinforcement(anchors[n], I_curr, directions[n])

            if n == opt.stop_scale:
                if anchor is not None and direction is not None:
                    # NOTE(review): reinforcement is applied a second time here,
                    # with the full-size `direction` instead of directions[n].
                    # Confirm this is intended.
                    I_curr = reinforcement(anchors[n], I_curr, direction)
                array = functions.convert_image_np(I_curr.detach())
            images_cur.append(I_curr)
        n += 1

    if array is None:
        raise ValueError(
            'generation never reached opt.stop_scale; no image to return')
    return array
Пример #20
0
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train the image pyramid scale by scale, coarsest first.

    Args:
        opt: option namespace. ``nfc``, ``min_nfc``, ``out_`` and ``outf`` are
            recomputed for every scale.
        Gs: list receiving the frozen generator of each scale.
        Zs: list receiving the noise map of each scale.
        reals: list passed to ``functions.creat_reals_pyramid``; replaced
            locally by the pyramid it returns.
        NoiseAmp: list receiving ``opt.noise_amp`` after each scale.

    Returns:
        None. The lists are saved under ``opt.out_`` after every scale.
    """
    # Read the input image as a tensor. Per the original author's note, it is
    # normalised to [-1, 1] (norm + clamp).
    loaded = functions.read_image(opt)
    in_s = 0
    current = 0
    # Rescale by opt.scale1, then build one real image per scale.
    reals = functions.creat_reals_pyramid(imresize(loaded, opt.scale1, opt),
                                          reals, opt)
    nfc_prev = 0

    while current < opt.stop_scale + 1:
        multiplier = pow(2, math.floor(current / 4))
        opt.nfc = min(opt.nfc_init * multiplier, 128)
        opt.min_nfc = min(opt.min_nfc_init * multiplier, 128)

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, current)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass

        plt.imsave('%s/real_scale.png' % (opt.outf),
                   functions.convert_image_np(reals[current]),
                   vmin=0,
                   vmax=1)

        D_curr, G_curr = init_models(opt)
        reuse_previous = nfc_prev == opt.nfc
        if reuse_previous:
            G_curr.load_state_dict(
                torch.load('%s/%d/netG.pth' % (opt.out_, current - 1)))
            D_curr.load_state_dict(
                torch.load('%s/%d/netD.pth' % (opt.out_, current - 1)))

        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs,
                                                  Zs, in_s, NoiseAmp, opt)

        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        Gs.append(G_curr)
        Zs.append(z_curr)  # Zs: noise maps
        NoiseAmp.append(opt.noise_amp)

        for payload, path_fmt in ((Zs, '%s/Zs.pth'),
                                  (Gs, '%s/Gs.pth'),
                                  (reals, '%s/reals.pth'),
                                  (NoiseAmp, '%s/NoiseAmp.pth')):
            torch.save(payload, path_fmt % (opt.out_))

        current += 1
        nfc_prev = opt.nfc
        del D_curr, G_curr
    return
Пример #21
0
def invert_model(test_image,
                 model_name,
                 scales2invert=None,
                 penalty=1e-3,
                 show=True):
    '''Recover, scale by scale, noise maps that make a trained pyramid reproduce an image.

    For each scale, a noise tensor is optimised with RMSprop to minimise the
    MSE between the generator output and that scale of ``test_image``. A
    standard-normal log-likelihood penalty keeps the noise plausible.

    Args:
        test_image: image array (converted with ``functions.np2torch``).
        model_name: input name of the trained model to load.
        scales2invert: number of scales to invert; all scales
            (``opt.stop_scale + 1``) when None.
        penalty: weight of the Gaussian prior penalty (stored as ``opt.reg``).
        show: plot losses and show target/output/noise images per scale.

    Returns:
        (Noise_Solutions, reconstructed_image, REC_ERROR):
        - Noise_Solutions: the padded noise input found for each scale.
        - reconstructed_image: what ``show_image`` returned for the last output.
        - REC_ERROR: the final MSE, as a detached tensor (0 if nothing ran).
    '''
    Noise_Solutions = []

    parser = get_arguments()
    parser.add_argument('--input_dir',
                        help='input image dir',
                        default='Input/Images')

    parser.add_argument('--mode', default='RandomSamples')
    opt = parser.parse_args("")
    opt.input_name = model_name
    opt.reg = penalty

    if model_name == 'islands2_basis_2.jpg':  # HARDCODED
        opt.scale_factor = 0.6

    opt = functions.post_config(opt)

    ### Loading in generators (frozen, eval mode)
    Gs, Zs, reals, NoiseAmp = functions.load_trained_pyramid(opt)
    for G in Gs:
        G = functions.reset_grads(G, False)
        G.eval()

    ### Replace the training pyramid with one built from the test image
    real = functions.np2torch(test_image, opt)
    functions.adjust_scales2image(real, opt)

    real_ = functions.np2torch(test_image, opt)
    real = imresize(real_, opt.scale1, opt)
    reals = functions.creat_reals_pyramid(real, [], opt)

    ### General padding
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    m_noise = nn.ZeroPad2d(int(pad_noise))

    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    m_image = nn.ZeroPad2d(int(pad_image))

    # Prior on the optimised noise: standard normal log-likelihood.
    pdf = torch.distributions.Normal(0, 1)
    alpha = opt.reg

    # Per-scale learning rates and iteration counts. Scales past the end of a
    # table reuse its last entry instead of raising IndexError.
    LR = [2e-3, 2e-2, 2e-1, 2e-1, 2e-1, 2e-1, 2e-1, 2e-1, 2e-1, 2e-1, 2e-1]
    niter = [
        200, 400, 400, 400, 400, 400, 400, 400, 400, 400, 400, 400, 400,
        400, 400
    ]

    I_prev = None
    REC_ERROR = 0

    if scales2invert is None:
        scales2invert = opt.stop_scale + 1

    for scale in range(scales2invert):
        # Get X, G
        X = reals[scale]
        G = Gs[scale]
        noise_amp = NoiseAmp[scale]

        # Defining dimensions
        opt.nc_z = X.shape[1]
        opt.nzx = X.shape[2]
        opt.nzy = X.shape[3]

        # Defining Z: 1-channel noise at the coarsest scale, 3-channel above.
        if scale == 0:
            z_init = functions.generate_noise([1, opt.nzx, opt.nzy],
                                              device=opt.device)
        else:
            z_init = functions.generate_noise([3, opt.nzx, opt.nzy],
                                              device=opt.device)

        # Leaf tensor to optimise, on the configured device (not forced to CUDA).
        z_init = Variable(z_init.to(opt.device), requires_grad=True)

        # Building I_prev
        if I_prev is None:
            # First scale: padded float zeros. An integer fill would give an
            # int64 tensor on current PyTorch.
            in_s = torch.zeros(reals[0].shape, device=opt.device)
            I_prev = m_image(in_s)
        else:
            # Upsample the previous output, pad it, then crop/resize to exactly
            # the padded size of X so imresize rounding cannot break the shapes.
            # This used a hard-coded +10, which only matched 2 * pad_image for
            # one ker_size/num_layer configuration.
            padded_h = X.shape[2] + 2 * pad_image
            padded_w = X.shape[3] + 2 * pad_image
            I_prev = imresize(I_prev, 1 / opt.scale_factor, opt)
            I_prev = m_image(I_prev)
            I_prev = I_prev[:, :, 0:padded_h, 0:padded_w]
            I_prev = functions.upsampling(I_prev, padded_h, padded_w)

        Zoptimizer = torch.optim.RMSprop([z_init],
                                         lr=LR[min(scale, len(LR) - 1)])
        x_loss = []  # for plotting
        epochs = []  # for plotting

        for epoch in range(niter[min(scale, len(niter) - 1)]):  # gradient descent on Z

            if scale == 0:
                noise_input = m_noise(z_init.expand(1, 3, opt.nzx,
                                                    opt.nzy))  # expand and pad
            else:
                noise_input = m_noise(z_init)  # padding

            z_in = noise_amp * noise_input + I_prev
            G_z = G(z_in, I_prev)

            x_recLoss = F.mse_loss(G_z, X)  # reconstruction (MSE) loss

            logProb = pdf.log_prob(z_init).mean()  # Gaussian prior term

            loss = x_recLoss - alpha * logProb

            Zoptimizer.zero_grad()
            loss.backward()
            Zoptimizer.step()

            x_loss.append(loss.item())
            epochs.append(epoch)

            # Detached so the returned error does not hold on to the graph.
            REC_ERROR = x_recLoss.detach()

        if show:
            plt.plot(epochs, x_loss, label='x_loss')
            plt.legend()
            plt.show()

        # The final output of this scale seeds the next one.
        I_prev = G_z.detach()

        _ = show_image(X, show, 'target')
        reconstructed_image = show_image(I_prev, show, 'output')
        _ = show_image(noise_input.detach().cpu(), show, 'noise')

        Noise_Solutions.append(noise_input.detach())
    return Noise_Solutions, reconstructed_image, REC_ERROR
Пример #22
0
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train one generator/discriminator pair per scale on a single image.

    Pyramid construction:
        The image from ``functions.read_image(opt)`` is resized by
        ``opt.scale1`` and expanded into a pyramid with
        ``functions.creat_reals_pyramid``.

    Per-scale loop (scales 0 .. ``opt.stop_scale``):
        * A fresh (D, G) pair is built.
        * It is warm-started from the previous scale's checkpoint when the
          channel width did not change.
        * It is trained by ``train_single_scale`` and then frozen.
        * The generator, its noise map and ``opt.noise_amp`` are appended
          to ``Gs``/``Zs``/``NoiseAmp``.
        * The running state is checkpointed under ``opt.out_``.

    The ``[mbs, percent]`` pair returned by ``train_single_scale`` and the
    wall-clock time spent in it are collected per scale and printed at the
    end.

    Args:
        opt: options namespace; mutated in place (``nfc``, ``min_nfc``,
            ``out_``, ``outf`` and whatever ``train_single_scale`` sets).
        Gs: list extended in place with one generator per scale.
        Zs: list extended in place with one noise map per scale.
        reals: list handed to ``creat_reals_pyramid``; the pyramid it returns
            is what is indexed per scale and saved as ``reals.pth``.
        NoiseAmp: list extended in place with ``opt.noise_amp`` per scale.
    """
    source_img = functions.read_image(opt)
    in_s = 0
    scaled_img = imresize(source_img, opt.scale1, opt)
    reals = functions.creat_reals_pyramid(scaled_img, reals, opt)

    prev_width = 0
    mem_log = []  # [mbs, percent] from train_single_scale, one entry per scale
    durations = []  # datetime.timedelta spent in train_single_scale per scale

    scale_num = 0
    while scale_num < opt.stop_scale + 1:
        # Channel counts double every 4 scales, capped at 128.
        growth = pow(2, math.floor(scale_num / 4))
        opt.nfc = min(opt.nfc_init * growth, 128)
        opt.min_nfc = min(opt.min_nfc_init * growth, 128)

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale_num)
        try:
            os.makedirs(opt.outf)
        except OSError:
            # Best effort: most likely the directory already exists.
            pass

        plt.imsave('%s/real_scale.png' % opt.outf,
                   functions.convert_image_np(reals[scale_num]),
                   vmin=0,
                   vmax=1)

        D_curr, G_curr = init_models(opt)
        if opt.nfc == prev_width:
            # Same width as the previous scale, so its weights fit: warm-start.
            prev_dir = '%s/%d' % (opt.out_, scale_num - 1)
            G_curr.load_state_dict(torch.load('%s/netG.pth' % prev_dir))
            D_curr.load_state_dict(torch.load('%s/netD.pth' % prev_dir))

        t_start = datetime.datetime.now()
        z_curr, in_s, G_curr, mbs, percent = train_single_scale(
            D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt)
        mem_log.append([mbs, percent])
        took = datetime.datetime.now() - t_start
        durations.append(took)
        print(f'time: {took}')

        # reset_grads(..., False) presumably disables requires_grad; the
        # networks are then put in eval mode before being stored.
        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        # Checkpoint everything accumulated so far after every scale.
        out_dir = opt.out_
        torch.save(Zs, '%s/Zs.pth' % out_dir)
        torch.save(Gs, '%s/Gs.pth' % out_dir)
        torch.save(reals, '%s/reals.pth' % out_dir)
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % out_dir)
        # NOTE(review): full_memory / full_time are not defined in this
        # function; presumably module-level globals filled elsewhere (e.g. by
        # train_single_scale). Confirm, otherwise this raises NameError.
        torch.save(full_memory, '%s/full_memory.pth' % out_dir)
        torch.save(full_time, '%s/full_time.pth' % out_dir)

        scale_num += 1
        prev_width = opt.nfc
        del D_curr, G_curr

    print(mem_log)
    print(durations)
    print(full_memory)
    print(full_time)
    return
Пример #23
0
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Jointly train two single-image generator pyramids, scale by scale.

    Pyramid construction:
        ``functions.read_images(opt)`` supplies the source images; only the
        first two are used. Each is resized by ``opt.scale1`` and expanded
        into its own pyramid.

    Per-scale loop (scales 0 .. ``opt.stop_scale``):
        * One (D, G) pair is built per image.
        * Both pairs are warm-started from the previous scale's
          ``netG{1,2}.pth`` / ``netD{1,2}.pth`` when the channel width did
          not change.
        * ``train_single_scale`` trains them together, and all four networks
          are then frozen.
        * The returned generators, noise map and ``opt.noise_amp`` are
          appended to ``Gs``/``Zs``/``NoiseAmp``.
        * Everything is checkpointed under ``opt.out_``.

    Args:
        opt: options namespace; mutated in place (``nfc``, ``min_nfc``,
            ``out_``, ``outf`` and whatever ``train_single_scale`` sets).
        Gs: list extended in place with the per-scale generator result of
            ``train_single_scale`` (indexed as ``[G1, G2]`` here).
        Zs: list extended in place with one noise entry per scale.
        reals: caller's list, used as the container for the FIRST image's
            pyramid. The local name is then rebound to
            ``[pyramid_1, pyramid_2]``, which is what ``train_single_scale``
            receives and what is saved as ``reals.pth``.
        NoiseAmp: list extended in place with ``opt.noise_amp`` per scale.
    """
    n_images = 2  # this variant trains exactly two images side by side

    real_ = functions.read_images(opt)
    in_s = 0
    scale_num = 0
    real = [imresize(real_[i], opt.scale1, opt) for i in range(n_images)]

    # Each image needs its OWN pyramid list. creat_reals_pyramid fills and
    # returns the list it is given, so passing ``reals`` to both calls made
    # reals[0] and reals[1] the same list (image 1's levels followed by
    # image 2's), and reals[1][k] silently returned image 1's level k. The
    # second pyramid therefore gets a fresh list.
    reals = [
        functions.creat_reals_pyramid(real[0], reals, opt),
        functions.creat_reals_pyramid(real[1], [], opt),
    ]
    nfc_prev = 0

    while scale_num < opt.stop_scale + 1:
        # Channel counts double every 4 scales, capped at 128.
        growth = pow(2, math.floor(scale_num / 4))
        opt.nfc = min(opt.nfc_init * growth, 128)
        opt.min_nfc = min(opt.min_nfc_init * growth, 128)

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale_num)
        try:
            os.makedirs(opt.outf)
        except OSError:
            # Best effort: most likely the directory already exists.
            pass

        # real_scale1.png / real_scale2.png: this scale's level of each image.
        for i in range(n_images):
            plt.imsave('%s/real_scale%d.png' % (opt.outf, i + 1),
                       functions.convert_image_np(reals[i][scale_num]),
                       vmin=0,
                       vmax=1)

        D_curr = []
        G_curr = []
        for _ in range(n_images):
            D_new, G_new = init_models(opt)
            D_curr.append(D_new)
            G_curr.append(G_new)

        if nfc_prev == opt.nfc:
            # Same width as the previous scale, so its weights fit: warm-start
            # each image's pair from its own checkpoint (netG1/netD1, ...).
            prev_dir = '%s/%d' % (opt.out_, scale_num - 1)
            for i in range(n_images):
                G_curr[i].load_state_dict(
                    torch.load('%s/netG%d.pth' % (prev_dir, i + 1)))
                D_curr[i].load_state_dict(
                    torch.load('%s/netD%d.pth' % (prev_dir, i + 1)))

        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs,
                                                  Zs, in_s, NoiseAmp, opt)

        # reset_grads(..., False) presumably disables requires_grad; the
        # networks are then put in eval mode before being stored.
        for i in range(n_images):
            G_curr[i] = functions.reset_grads(G_curr[i], False)
            G_curr[i].eval()
            D_curr[i] = functions.reset_grads(D_curr[i], False)
            D_curr[i].eval()

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        # Checkpoint everything accumulated so far after every scale.
        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

        scale_num += 1
        nfc_prev = opt.nfc
        del D_curr, G_curr
    return