Ejemplo n.º 1
0
def generate(model_name,
             anchor_image=None,
             direction=None,
             transfer=None,
             noise_solutions=None,
             factor=0.25,
             base=None,
             insert_limit=4):
    """Load the trained SinGAN pyramid for ``model_name`` and run anchored generation.

    The option namespace comes from ``get_arguments()`` plus the extra
    generation options below, parsed with an empty argv (defaults only), and
    is then pointed at ``model_name``.

    Args:
        model_name: input image name of the trained model (stored in
            ``opt.input_name``).
        anchor_image, direction, transfer, noise_solutions, factor, base,
        insert_limit: forwarded verbatim to ``SinGAN_anchor_generate``.
            The author annotated ``direction`` as 'L, R, T, B' (presumably
            left/right/top/bottom; confirm against the callee).

    Returns:
        Whatever ``SinGAN_anchor_generate`` returns.
    """
    arg_parser = get_arguments()
    arg_parser.add_argument('--input_dir',
                            help='input image dir',
                            default='Input/Images')
    arg_parser.add_argument('--mode',
                            help='random_samples | random_samples_arbitrary_sizes',
                            default='random_samples')
    # Scale at which random_samples generation starts.
    arg_parser.add_argument('--gen_start_scale',
                            type=int,
                            help='generation start scale',
                            default=0)
    opt = arg_parser.parse_args("")
    opt.input_name = model_name

    # HARDCODED: this particular model uses a custom scale factor.
    if model_name == 'islands2_basis_2.jpg':
        opt.scale_factor = 0.6

    opt = functions.post_config(opt)

    real = functions.read_image(opt)
    # NOTE(review): this second read uses the same opt.input_name as `real`
    # and its result is never used; the caller's `anchor_image` is what gets
    # forwarded. The read is kept so I/O behaviour is unchanged.
    functions.read_image(opt)

    functions.adjust_scales2image(real, opt)
    Gs, Zs, reals, NoiseAmp = functions.load_trained_pyramid(opt)
    # Result unused; the call is kept so any side effects are preserved.
    functions.generate_in2coarsest(reals, 1, 1, opt)

    return SinGAN_anchor_generate(Gs, Zs, reals, NoiseAmp, opt,
                                  gen_start_scale=opt.gen_start_scale,
                                  anchor_image=anchor_image,
                                  direction=direction,
                                  transfer=transfer,
                                  noise_solutions=noise_solutions,
                                  factor=factor,
                                  base=base,
                                  insert_limit=insert_limit)
Ejemplo n.º 2
0
def train_model(input_name):
    """Train a SinGAN model on ``input_name`` unless a saved model already exists.

    Args:
        input_name: image file name, looked up under ``--input_dir``
            (default 'Input/Images').

    The parser declares ``--input_name`` as required. The previous
    ``parse_args("")`` therefore made argparse exit with "the following
    arguments are required", and ``input_name`` was never used. The value is
    now passed through argparse so ``opt.input_name`` reflects the caller.
    """
    parser = get_arguments()
    parser.add_argument('--input_dir',
                        help='input image dir',
                        default='Input/Images')
    parser.add_argument('--input_name', help='input image name', required=True)
    parser.add_argument('--mode', help='task to be done', default='train')
    # `--opt=value` form so names starting with '-' are not mistaken for flags.
    opt = parser.parse_args([f'--input_name={input_name}'])
    opt = functions.post_config(opt)
    Gs = []
    Zs = []
    reals = []
    NoiseAmp = []
    dir2save = functions.generate_dir2save(opt)

    if os.path.exists(dir2save):
        print('trained model already exist')
        return

    try:
        os.makedirs(dir2save)
    except OSError:
        pass
    real = functions.read_image(opt)
    functions.adjust_scales2image(real, opt)
    train(opt, Gs, Zs, reals, NoiseAmp)
    SinGAN_generate(Gs, Zs, reals, NoiseAmp, opt)
Ejemplo n.º 3
0
    def __init__(self, input_name, load_existing_model):
        """Prepare a model wrapper for ``input_name``.

        Builds the ``Config``, resolves the save directory, reads the input
        image and adjusts the scale settings to it. Then it does one of two
        things:

        * If ``load_existing_model`` is true, it loads ``Gs``/``Zs``/
          ``reals``/``NoiseAmp`` from ``<dir2save>/<name>.pth``.
        * Otherwise it prepares an empty save directory, asking on stdin
          before deleting an existing one.

        Raises:
            AssertionError: if loading is requested but the directory is
                missing, or if the user does not confirm an overwrite.
        """
        self.input_name = input_name
        self.load_existing_model = load_existing_model
        self.opt = Config(input_name)
        self.dir2save = functions.generate_dir2save(self.opt)
        self.real = functions.read_image(self.opt)
        functions.adjust_scales2image(self.real, self.opt)
        model_dir_found = os.path.exists(self.dir2save)
        assert model_dir_found or not load_existing_model, "cannot find trained model"

        if not load_existing_model:
            # Fresh training: confirm, then wipe any previous run.
            if model_dir_found:
                reply = input("Trained model has been found, type \"yes\" to overwrite: ")
                assert reply == 'yes', "train aborted"
                rmtree(self.dir2save)
                print("train directory has been deleted")
            try:
                os.makedirs(self.dir2save)
            except OSError:
                pass
        else:
            print("Trained model has been loaded (not really)")
            for attr_name in ('Gs', 'Zs', 'reals', 'NoiseAmp'):
                setattr(self, attr_name, torch.load(f'{self.dir2save}/{attr_name}.pth'))
            self.is_loaded = True
Ejemplo n.º 4
0
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train one generator/discriminator pair per pyramid scale.

    Scales 0..``opt.stop_scale`` (inclusive) are trained in order.

    * Channel widths double every 4 scales, capped at 128.
    * When the width equals the previous scale's, the new networks start
      from that scale's saved ``netG.pth``/``netD.pth``.
    * After each scale, the generator (after ``reset_grads(..., False)`` and
      ``.eval()``), its noise map and ``opt.noise_amp`` are appended to
      ``Gs``/``Zs``/``NoiseAmp``, and all four lists are saved under
      ``opt.out_``.

    Args:
        opt: configuration namespace; ``nfc``, ``min_nfc``, ``out_`` and
            ``outf`` are overwritten on every scale.
        Gs, Zs, NoiseAmp: lists extended in place.
        reals: list handed to ``creat_reals_pyramid``; its result is the
            pyramid that is trained on and saved.
    """
    full_image = functions.read_image(opt)
    in_s = 0
    level = 0
    base_image = imresize(full_image, opt.scale1, opt)
    reals = functions.creat_reals_pyramid(base_image, reals, opt)
    prev_width = 0

    while level < opt.stop_scale + 1:
        width_mult = 2 ** (level // 4)
        opt.nfc = min(opt.nfc_init * width_mult, 128)
        opt.min_nfc = min(opt.min_nfc_init * width_mult, 128)

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, level)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass

        plt.imsave('%s/in_scale.png' % (opt.outf),
                   functions.convert_image_np(reals[level]),
                   vmin=0, vmax=1)

        D_curr, G_curr = init_models(opt)
        if prev_width == opt.nfc:
            # Same architecture as the previous scale: warm-start from it.
            prev_dir = '%s/%d' % (opt.out_, level - 1)
            G_curr.load_state_dict(torch.load('%s/netG.pth' % prev_dir))
            D_curr.load_state_dict(torch.load('%s/netD.pth' % prev_dir))

        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs,
                                                  Zs, in_s, NoiseAmp, opt)

        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        for tag, payload in (('Zs', Zs), ('Gs', Gs), ('reals', reals),
                             ('NoiseAmp', NoiseAmp)):
            torch.save(payload, '%s/%s.pth' % (opt.out_, tag))

        level += 1
        prev_width = opt.nfc
        del D_curr, G_curr
    return
Ejemplo n.º 5
0
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train the SinGAN pyramid level by level, coarsest (level 0) first.

    Unlike variants that first rescale by ``opt.scale1``, this one builds the
    pyramid directly from the image returned by ``functions.read_image``.
    Levels 0..``opt.stop_scale`` are all trained.

    Args:
        opt: configuration namespace; ``nfc`` (output channels per conv
            block), ``min_nfc``, ``out_`` (output root) and ``outf``
            (per-level folder) are overwritten each level.
        Gs, Zs, NoiseAmp: lists extended in place, one entry per level.
        reals: list handed to ``creat_reals_pyramid``.
    """
    source = functions.read_image(opt)
    # Threaded through train_single_scale; starts as the scalar 0.
    in_s = 0
    level = 0
    reals = functions.creat_reals_pyramid(source, reals, opt)
    prev_width = 0

    while level < opt.stop_scale + 1:
        # Block width grows 2x every 4 levels, capped at 128.
        width_mult = 2 ** (level // 4)
        opt.nfc = min(opt.nfc_init * width_mult, 128)
        opt.min_nfc = min(opt.min_nfc_init * width_mult, 128)

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, level)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass

        plt.imsave('%s/real_scale.png' % (opt.outf),
                   functions.convert_image_np(reals[level]),
                   vmin=0, vmax=1)

        D_curr, G_curr = init_models(opt)
        # The architecture can change between levels; only reuse the previous
        # level's weights when the width is unchanged.
        if prev_width == opt.nfc:
            prev_dir = '%s/%d' % (opt.out_, level - 1)
            G_curr.load_state_dict(torch.load('%s/netG.pth' % prev_dir))
            D_curr.load_state_dict(torch.load('%s/netD.pth' % prev_dir))

        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs,
                                                  Zs, in_s, NoiseAmp, opt)

        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        for tag, payload in (('Zs', Zs), ('Gs', Gs), ('reals', reals),
                             ('NoiseAmp', NoiseAmp)):
            torch.save(payload, '%s/%s.pth' % (opt.out_, tag))

        level += 1
        prev_width = opt.nfc
        del D_curr, G_curr
        torch.cuda.empty_cache()
    return
Ejemplo n.º 6
0
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train the SinGAN pyramid with TensorFlow/Keras models.

    One pair of Adam optimizers is created up front and reused at every
    scale. For each scale 0..``opt.stop_scale``:

    * D/G are warm-started from the previous scale's saved weights when the
      channel width is unchanged.
    * The scale is trained.
    * ``Zs``, ``reals`` and ``NoiseAmp`` are pickled to ``opt.out_``.
      ``Gs`` is extended in place but not pickled here.

    Returns:
        None.
    """
    source = functions.read_image(opt)
    in_s = 0
    level = 0
    reals = functions.creat_reals_pyramid(imresize(source, opt.scale1, opt),
                                          reals, opt)
    prev_width = 0
    netD_optimizer = tf.keras.optimizers.Adam(learning_rate=opt.lr_d,
                                              beta_1=opt.beta1,
                                              beta_2=0.999)
    netG_optimizer = tf.keras.optimizers.Adam(learning_rate=opt.lr_g,
                                              beta_1=opt.beta1,
                                              beta_2=0.999)

    while level < opt.stop_scale + 1:
        width_mult = 2 ** (level // 4)
        opt.nfc = min(opt.nfc_init * width_mult, 128)
        opt.min_nfc = min(opt.min_nfc_init * width_mult, 128)
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, level)

        try:
            os.makedirs(opt.outf)
        except OSError:
            pass

        plt.imsave('%s/real_scale.png' % (opt.outf),
                   functions.convert_image_np(reals[level]),
                   vmin=0, vmax=1)
        D_curr, G_curr = init_models(opt)
        if prev_width == opt.nfc:
            prev_dir = '%s/%d' % (opt.out_, level - 1)
            D_curr.load_weights('%s/netD' % prev_dir)
            G_curr.load_weights('%s/netG' % prev_dir)

        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs,
                                                  Zs, in_s, NoiseAmp, opt,
                                                  level, netG_optimizer,
                                                  netD_optimizer)

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)
        for stem, payload in (('Zs', Zs), ('reals', reals),
                              ('NoiseAmp', NoiseAmp)):
            with open('%s/%s.pkl' % (opt.out_, stem), 'wb') as f:
                pickle.dump(payload, f)
        level += 1
        prev_width = opt.nfc
        del D_curr, G_curr
    return None
Ejemplo n.º 7
0
def main(opt, generate=True):
    """Train a two-image SinGAN model and optionally sample from it.

    Does nothing beyond a message if the model's output directory already
    exists. Otherwise it:

    * creates the directory and attaches the file logger there;
    * dumps ``opt`` to ``config.json``, with every value stringified;
    * trains on ``opt.input_name1`` / ``opt.input_name2``;
    * when ``generate`` is true, writes random samples.

    Args:
        opt: parsed configuration namespace.
        generate: whether to run ``SinGAN_generate`` after training.

    Raises:
        Any exception from the config dump, image loading, training or
        generation is logged and re-raised. The logger is always cleaned up.
    """
    Gs = []
    Zs: List[Tuple] = []
    reals1 = []
    reals2 = []
    NoiseAmp = []
    dir2save = functions.generate_dir2save(opt)

    if os.path.exists(dir2save):
        print('trained model already exist')
        return

    try:
        os.makedirs(dir2save)
    except OSError:
        pass
    _configure_logger(dir2save)

    # Everything after _configure_logger is inside the try so that
    # _cleanup_logger() also runs when writing config.json fails.
    try:
        # dump configuration file to json
        with open(os.path.join(dir2save, "config.json"), "w") as fp:
            config_dict = {k: str(v) for k, v in opt.__dict__.items()}
            json.dump(config_dict, fp)

        real1 = functions.read_image(opt, image_name=opt.input_name1)
        real2 = functions.read_image(opt, image_name=opt.input_name2)
        functions.adjust_scales2image(real1, opt)
        functions.adjust_scales2image(real2, opt)
        train(opt, Gs, Zs, reals1, reals2, NoiseAmp)
        logger.info("Done training")
        if generate:
            logger.info("Generating random samples")
            SinGAN_generate(Gs, Zs, reals1, reals2, NoiseAmp, opt)
    except Exception:
        logger.exception("Failed")
        raise
    finally:
        logger.info("Cleaning logger")
        _cleanup_logger()
Ejemplo n.º 8
0
def preprocess_content_image(opt, reals, scale):
    """Build the injected input ``in_s`` for pyramid index ``scale`` from a reference image.

    Steps:

    1. Read the training image; this also updates ``opt`` through
       ``adjust_scales2image``.
    2. Read the reference image ``opt.ref_dir/opt.ref_name`` and resize it
       to the training image's height and width if needed.
    3. Downscale it by ``opt.scale_factor ** (N - scale + 1)`` and crop it to
       ``reals[scale - 1]``.
    4. Upscale it once by ``1 / opt.scale_factor`` and crop it to
       ``reals[scale]``.

    Args:
        opt: configuration namespace (mutated by ``adjust_scales2image``).
        reals: image pyramid. Index ``len(reals) - 1`` is treated as the
            full-resolution level; ``shape[2]``/``shape[3]`` are used as
            height/width.
        scale: target index ``n``. ``reals[n - 1]`` is read, so this is
            presumably meant to be >= 1 (TODO confirm with callers).

    Returns:
        Tensor cropped to ``reals[scale]``'s height and width.
    """
    real = functions.read_image(opt)
    functions.adjust_scales2image(real, opt)
    ref = functions.read_image_dir('%s/%s' % (opt.ref_dir, opt.ref_name), opt)
    # Resize when EITHER spatial dimension differs. Checking only the width
    # let a reference with matching width but a different height through
    # unresized, producing a wrongly-shaped in_s.
    if ref.shape[2] != real.shape[2] or ref.shape[3] != real.shape[3]:
        ref = imresize_to_shape(ref, [real.shape[2], real.shape[3]], opt)
        ref = ref[:, :, :real.shape[2], :real.shape[3]]

    N = len(reals) - 1
    n = scale
    # Downscale to the level just below n, then upscale one step to level n.
    in_s = imresize(ref, pow(opt.scale_factor, (N - n + 1)), opt)
    in_s = in_s[:, :, :reals[n - 1].shape[2], :reals[n - 1].shape[3]]
    in_s = imresize(in_s, 1 / opt.scale_factor, opt)
    in_s = in_s[:, :, :reals[n].shape[2], :reals[n].shape[3]]

    return in_s
Ejemplo n.º 9
0
def main(opt):
    """Train a SinGAN model for ``opt``'s input image, then write random samples.

    If the output directory from ``functions.generate_dir2save(opt)``
    already exists, this only logs a message and returns.
    """
    Gs, Zs, reals, NoiseAmp = [], [], [], []
    dir2save = functions.generate_dir2save(opt)

    if os.path.exists(dir2save):
        logger.info("Trained model directory already exists")
        return

    try:
        os.makedirs(dir2save)
    except OSError:
        pass  # creation errors are deliberately ignored

    image = functions.read_image(opt)
    functions.adjust_scales2image(image, opt)
    train(opt, Gs, Zs, reals, NoiseAmp)
    SinGAN_generate(Gs, Zs, reals, NoiseAmp, opt)
Ejemplo n.º 10
0
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train the SinGAN pyramid one scale at a time, coarsest first.

    Scales 0..``opt.stop_scale`` (inclusive) are trained in order.

    * Channel widths double every 4 scales, capped at 128.
    * With ``opt.fast_training``, ``opt.niter`` is halved at scales 4, 8, ...
    * When the width matches the previous scale, the new G/D are initialised
      from that scale's ``netG.pth``/``netD.pth`` (fine-tuning).

    Args:
        opt: configuration namespace; ``nfc``, ``min_nfc``, ``niter``,
            ``out_`` and ``outf`` are overwritten.
        Gs, Zs, NoiseAmp: lists extended in place with one entry per scale.
        reals: list handed to ``creat_reals_pyramid``.
    """
    real_ = functions.read_image(opt)
    in_s = 0
    scale_num = 0
    real = imresize(real_, opt.scale1, opt)
    print("real.shape:", real.shape)  # e.g. real.shape: torch.Size([1, 3, 186, 248])
    reals = functions.creat_reals_pyramid(real, reals, opt)
    for level_img in reals:
        print(level_img.shape)
    # Example output for the 186x248 image above (coarsest first):
    #   [1, 3, 20, 27], [1, 3, 27, 36], [1, 3, 35, 47], [1, 3, 47, 62],
    #   [1, 3, 61, 82], [1, 3, 81, 108], [1, 3, 107, 143], [1, 3, 141, 188],
    #   [1, 3, 186, 248]
    nfc_prev = 0
    # Scales are indexed 0..opt.stop_scale inclusive, i.e. stop_scale + 1 of
    # them. The old message printed stop_scale, one too few.
    print("total %d scales.." % (opt.stop_scale + 1))
    while scale_num < opt.stop_scale + 1:
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        # Halve the iteration budget every 4 scales (never at scale 0).
        if opt.fast_training and scale_num > 0 and scale_num % 4 == 0:
            opt.niter = opt.niter // 2

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale_num)
        print("out dir:", opt.outf)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass

        plt.imsave('%s/real_scale.png' % (opt.outf),
                   functions.convert_image_np(reals[scale_num]),
                   vmin=0, vmax=1)
        # Create D and G. They are fully convolutional, so the input size does
        # not need to be fixed.
        D_curr, G_curr = init_models(opt)
        # Same network width as the previous scale: fine-tune from its weights.
        if nfc_prev == opt.nfc:
            G_curr.load_state_dict(torch.load('%s/%d/netG.pth' % (opt.out_, scale_num - 1)))
            D_curr.load_state_dict(torch.load('%s/%d/netD.pth' % (opt.out_, scale_num - 1)))

        # Train this scale's networks. Fresh networks are created per scale;
        # the author notes this keeps GPU memory use low (unverified here).
        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt)

        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        # Keep the trained generator for later scales / generation.
        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

        scale_num += 1
        nfc_prev = opt.nfc

        # Drop local references to D and G (author's intent: reduce GPU memory).
        del D_curr, G_curr
    return
Ejemplo n.º 11
0
def _load_inpainting_mask(opt):
    """Read ``opt.mask_name`` and return its multi-scale pyramid.

    The mask is read from ``opt.input_dir``, prefixed by ``opt.on_drive``
    when that is set. It is mapped to 0 for the masked-out area and 1
    everywhere else. This assumes pixel values in [0, 255]; TODO confirm for
    image formats that ``img.imread`` returns as floats in [0, 1].
    """
    if opt.on_drive is not None:
        mask_path = '%s/%s/%s' % (opt.on_drive, opt.input_dir, opt.mask_name)
    else:
        mask_path = '%s/%s' % (opt.input_dir, opt.mask_name)
    mask = img.imread(mask_path)

    # Convert mask to [0,1] space, 0 is masked-out area, 1 everywhere else.
    mask = 1 - (mask / 255)
    mask = torch.from_numpy(mask)

    # H x W x C -> 1 x C x H x W. The old `.view([1, 3, H, W])` reinterpreted
    # the HWC memory as CHW, scrambling the pixels; permute moves the axes.
    mask = mask[:, :, :, None].permute(3, 2, 0, 1).contiguous()
    mask = imresize(mask, opt.scale1, opt)

    return functions.creat_reals_pyramid(mask, [], opt)


def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train the SinGAN pyramid, optionally preparing an inpainting mask.

    When ``opt.mode == "inpainting"``, a mask pyramid matching the image
    pyramid is stored on ``opt.masks`` before training starts. Scales
    0..``opt.stop_scale`` are then trained in order. A scale warm-starts
    from the previous scale's ``netG.pth``/``netD.pth`` when the channel
    width is unchanged.

    Args:
        opt: configuration namespace; ``nfc``, ``min_nfc``, ``out_``,
            ``outf`` (and ``masks`` in inpainting mode) are overwritten.
        Gs, Zs, NoiseAmp: lists extended in place, one entry per scale.
        reals: list handed to ``creat_reals_pyramid``.
    """
    real_ = functions.read_image(opt)
    in_s = 0
    scale_num = 0
    real = imresize(real_, opt.scale1, opt)
    reals = functions.creat_reals_pyramid(real, reals, opt)
    nfc_prev = 0

    if opt.mode == "inpainting":
        opt.masks = _load_inpainting_mask(opt)

    while scale_num < opt.stop_scale + 1:
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(scale_num / 4)),
                          128)

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale_num)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass

        D_curr, G_curr = init_models(opt)
        if nfc_prev == opt.nfc:
            G_curr.load_state_dict(
                torch.load('%s/%d/netG.pth' % (opt.out_, scale_num - 1)))
            D_curr.load_state_dict(
                torch.load('%s/%d/netD.pth' % (opt.out_, scale_num - 1)))

        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs,
                                                  Zs, in_s, NoiseAmp, opt)

        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

        scale_num += 1
        nfc_prev = opt.nfc
        del D_curr, G_curr
    return
Ejemplo n.º 12
0
def train(opt, Gs, Zs, reals, crops, masks, NoiseAmp):
    """Train the SinGAN pyramid alongside mask and crop-size pyramids.

    Setup:

    * Builds the image pyramid from the ``opt.scale1``-rescaled input image.
    * Builds a mask pyramid from ``functions.read_mask(opt)``.
    * Builds a crop pyramid from a zero tensor of size
      ``opt.crop_size x opt.crop_size``, whose downscaled shapes give the
      crop size at each scale.
    * Stores ``functions.get_eye_color`` of the image on ``opt.eye_color``.

    Scales 0..``opt.stop_scale`` are then trained, warm-starting from the
    previous scale when the channel width is unchanged.

    Args:
        opt: configuration namespace (mutated).
        Gs, Zs, NoiseAmp: lists extended in place, one entry per scale.
        reals, crops, masks: lists handed to ``create_pyramid``.
    """
    source = functions.read_image(opt)
    base_image = imresize(source, opt.scale1, opt)

    base_mask = functions.read_mask(opt)
    # Used only for its size: downscaling it yields each scale's crop size.
    crop_template = torch.zeros((1, 1, opt.crop_size, opt.crop_size))
    eye_color = functions.get_eye_color(base_image)
    opt.eye_color = eye_color

    in_s = 0
    level = 0

    reals = functions.create_pyramid(base_image, reals, opt)
    masks = functions.create_pyramid(base_mask, masks, opt, mode="mask")
    crops = functions.create_pyramid(crop_template, crops, opt, mode="mask")

    prev_width = 0

    while level < opt.stop_scale + 1:
        width_mult = 2 ** (level // 4)
        opt.nfc = min(opt.nfc_init * width_mult, 128)
        opt.min_nfc = min(opt.min_nfc_init * width_mult, 128)

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, level)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass

        plt.imsave('%s/real_scale.png' % (opt.outf),
                   functions.convert_image_np(reals[level]),
                   vmin=0, vmax=1)

        D_curr, G_curr = init_models(opt)
        if prev_width == opt.nfc:
            prev_dir = '%s/%d' % (opt.out_, level - 1)
            G_curr.load_state_dict(torch.load('%s/netG.pth' % prev_dir))
            D_curr.load_state_dict(torch.load('%s/netD.pth' % prev_dir))

        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, crops,
                                                  masks, eye_color, Gs, Zs,
                                                  in_s, NoiseAmp, opt)

        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        for tag, payload in (('Zs', Zs), ('Gs', Gs), ('reals', reals),
                             ('NoiseAmp', NoiseAmp)):
            torch.save(payload, '%s/%s.pth' % (opt.out_, tag))

        level += 1
        prev_width = opt.nfc
        del D_curr, G_curr

    return
Ejemplo n.º 13
0
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train the SinGAN pyramid one scale at a time, with debug output.

    Scales 0..``opt.stop_scale`` (inclusive) are trained in order.

    * Channel widths double every 4 scales, capped at 128.
    * With ``opt.fast_training``, ``opt.niter`` is halved at scales 4, 8, ...
    * The original image is saved once per output directory, and each
      scale's real image is saved in that scale's folder.

    Args:
        opt: configuration namespace; ``nfc``, ``min_nfc``, ``niter``,
            ``out_`` and ``outf`` are overwritten.
        Gs, Zs, NoiseAmp: lists extended in place, one entry per scale.
        reals: list handed to ``creat_reals_pyramid``.
    """
    real_ = functions.read_image(opt)
    in_s = 0
    scale_num = 0
    real = imresize(real_, opt.scale1, opt)

    # Multi-scale pyramid of the training image, one tensor per scale.
    reals = functions.creat_reals_pyramid(real, reals, opt)

    nfc_prev = 0
    # real_ never changes, so original.png only needs writing once per output
    # directory instead of being rewritten identically on every scale.
    original_saved_in = None

    # Runs opt.stop_scale + 1 times (scales 0..opt.stop_scale).
    while scale_num < opt.stop_scale + 1:
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        print('opt.nfc', opt.nfc)
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(scale_num / 4)),
                          128)
        print('opt.min_nfc', opt.min_nfc)
        if opt.fast_training and scale_num > 0 and scale_num % 4 == 0:
            opt.niter = opt.niter // 2

        # out_ is the output root directory; outf is this scale's folder.
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale_num)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass

        # Save the original (un-resized) image.
        if opt.out_ != original_saved_in:
            plt.imsave('%s/original.png' % (opt.out_),
                       functions.convert_image_np(real_),
                       vmin=0,
                       vmax=1)
            original_saved_in = opt.out_
        # Save this scale's real image.
        plt.imsave('%s/real_scale.png' % (opt.outf),
                   functions.convert_image_np(reals[scale_num]),
                   vmin=0,
                   vmax=1)

        # init_models returns the current (netD, netG) pair.
        D_curr, G_curr = init_models(opt)
        if nfc_prev == opt.nfc:
            G_curr.load_state_dict(
                torch.load('%s/%d/netG.pth' % (opt.out_, scale_num - 1)))
            D_curr.load_state_dict(
                torch.load('%s/%d/netD.pth' % (opt.out_, scale_num - 1)))

        # train_single_scale returns (z_opt, in_s, netG).
        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs,
                                                  Zs, in_s, NoiseAmp, opt)

        G_curr = functions.reset_grads(G_curr, False)
        print('G_curr', G_curr)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

        scale_num += 1
        nfc_prev = opt.nfc
        del D_curr, G_curr
    return
Ejemplo n.º 14
0
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train the SinGAN pyramid, with experimental support for resuming.

    Progress is tracked on ``opt`` itself. ``opt.scale_num`` is the next
    scale to train and ``opt.nfc_prev`` is the previous scale's channel
    width. If ``opt`` already has a positive ``scale_num``, the loop resumes
    from that scale instead of 0.

    Args:
        opt: configuration namespace; mutated (``scale_num``, ``nfc_prev``,
            ``nfc``, ``min_nfc``, ``out_``, ``outf``).
        Gs, Zs, NoiseAmp: lists extended in place, one entry per trained
            scale. In resume mode they presumably already hold the earlier
            scales (TODO confirm with the caller).
        reals: list handed to ``create_reals_pyramid``. In resume mode
            ``reals[0]`` is read before that call, so it must already be
            populated.

    Returns:
        None. The checkpoint lists are written to ``opt.out_`` after every
        scale.
    """
    print('train() current parameters')
    print(opt)
    real_ = functions.read_image(opt)
    in_s = 0
    if 'scale_num' in opt and opt.scale_num > 0:
        # EXPERIMENTAL: if we are in 'continue' mode
        # Replace the scalar start value with a zero tensor shaped like
        # reals[0]. NOTE(review): the int fill value may give an integer
        # dtype on recent PyTorch; confirm this is intended.
        in_s = torch.full(reals[0].shape, 0, device=opt.device)
    else:
        opt.scale_num = 0
    real = imresize(real_, opt.scale1, opt)
    # NOTE(review): in resume mode `reals` is already populated; whether
    # create_reals_pyramid appends to or rebuilds it is not visible here.
    reals = functions.create_reals_pyramid(real, reals, opt)
    if 'nfc_prev' not in opt:
        opt.nfc_prev = 0

    # Trains scales opt.scale_num..opt.stop_scale inclusive.
    while opt.scale_num < opt.stop_scale + 1:
        # Channel widths double every 4 scales, capped at 128.
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(opt.scale_num / 4)),
                      128)
        opt.min_nfc = min(
            opt.min_nfc_init * pow(2, math.floor(opt.scale_num / 4)), 128)

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, opt.scale_num)
        try:
            os.makedirs(opt.outf)
        except OSError:
            # Any OSError is reported as "already exists" and ignored.
            print('directory %s already exists' % opt.outf)
            pass

        #plt.imsave('%s/in.png' %  (opt.out_), functions.convert_image_np(real), vmin=0, vmax=1)
        #plt.imsave('%s/original.png' %  (opt.out_), functions.convert_image_np(real_), vmin=0, vmax=1)
        plt.imsave('%s/real_scale.png' % opt.outf,
                   functions.convert_image_np(reals[opt.scale_num]),
                   vmin=0,
                   vmax=1)

        D_curr, G_curr = init_models(opt)
        # Same width as the previous scale: warm-start from its saved weights.
        if opt.nfc_prev == opt.nfc:
            G_curr.load_state_dict(
                torch.load('%s/%d/netG.pth' % (opt.out_, opt.scale_num - 1)))
            D_curr.load_state_dict(
                torch.load('%s/%d/netD.pth' % (opt.out_, opt.scale_num - 1)))

        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs,
                                                  Zs, in_s, NoiseAmp, opt)

        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        # Checkpoint after every scale; these files allow resuming.
        torch.save(Zs, '%s/Zs.pth' % opt.out_)
        torch.save(Gs, '%s/Gs.pth' % opt.out_)
        torch.save(reals, '%s/reals.pth' % opt.out_)
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % opt.out_)

        # Progress is stored on opt so a later call can continue from here.
        opt.scale_num += 1
        opt.nfc_prev = opt.nfc
        del D_curr, G_curr
    return
Ejemplo n.º 15
0
    if dir2save is None:
        print('task does not exist')
    elif (os.path.exists(dir2save)):
        if opt.mode == 'random_samples':
            print(
                'random samples for image %s, start scale=%d, already exist' %
                (dir_name, opt.gen_start_scale))
        elif opt.mode == 'random_samples_arbitrary_sizes':
            print(
                'random samples for image %s at size: scale_h=%f, scale_v=%f, already exist'
                % (dir_name, opt.scale_h, opt.scale_v))
    else:
        try:
            os.makedirs(dir2save)
        except OSError:
            pass
        if opt.mode == 'random_samples':
            real1 = functions.read_image(opt, opt.input_name1)
            real2 = functions.read_image(opt, opt.input_name2)
            functions.adjust_scales2image(real1, opt)
            functions.adjust_scales2image(real2, opt)
            Gs, Zs, reals1, reals2, NoiseAmp = functions.load_trained_pyramid(
                opt)
            SinGAN_generate(Gs,
                            Zs,
                            reals1,
                            reals2,
                            NoiseAmp,
                            opt,
                            gen_start_scale=opt.gen_start_scale)
    elif (os.path.exists(dir2save)):  # 이미 폴더가 있는 경우
        if opt.mode == 'random_samples':
            print(
                'random samples for image %s, start scale=%d, already exist' %
                (opt.input_name, opt.gen_start_scale))
        elif opt.mode == 'random_samples_arbitrary_sizes':
            print(
                'random samples for image %s at size: scale_h=%f, scale_v=%f, already exist'
                % (opt.input_name, opt.scale_h, opt.scale_v))
    else:
        try:
            os.makedirs(dir2save)  # 폴더를 만들어줌
        except OSError:
            pass
        if opt.mode == 'random_samples':
            real = functions.read_image(
                opt)  # opt.input_dir 과 opt.input_name을 이용하여 이미지를 array형식으로 받아옴
            functions.adjust_scales2image(real, opt)
            Gs, Zs, reals, NoiseAmp = functions.load_trained_pyramid(opt)
            in_s = functions.generate_in2coarsest(reals, 1, 1, opt)
            SinGAN_generate(Gs,
                            Zs,
                            reals,
                            NoiseAmp,
                            opt,
                            gen_start_scale=opt.gen_start_scale)

        elif opt.mode == 'random_samples_arbitrary_sizes':
            real = functions.read_image(
                opt)  # opt.input_dir 과 opt.input_name을 이용하여 이미지를 array형식으로 받아옴
            functions.adjust_scales2image(real, opt)  #opt를 설정해줌
            Gs, Zs, reals, NoiseAmp = functions.load_trained_pyramid(
Ejemplo n.º 17
0
    elif (os.path.exists(dir2save)):
        if opt.mode == 'random_samples':
            print(
                'random samples for image %s, start scale=%d, already exist' %
                (opt.input_name, opt.gen_start_scale))
        elif opt.mode == 'random_samples_arbitrary_sizes':
            print(
                'random samples for image %s at size: scale_h=%f, scale_v=%f, already exist'
                % (opt.input_name, opt.scale_h, opt.scale_v))
    else:
        try:
            os.makedirs(dir2save)
        except OSError:
            pass
        if opt.mode == 'random_samples':
            real = functions.read_image(opt)
            functions.adjust_scales2image(real, opt)
            Gs, Zs, reals, NoiseAmp = functions.load_trained_pyramid(opt)

            # 生成最粗糙的图像
            in_s = functions.generate_in2coarsest(reals, 1, 1, opt)
            SinGAN_generate(Gs,
                            Zs,
                            reals,
                            NoiseAmp,
                            opt,
                            gen_start_scale=opt.gen_start_scale)

        # elif opt.mode == 'random_samples_arbitrary_sizes':
        #     real = functions.read_image(opt)
        #     functions.adjust_scales2image(real, opt)
Ejemplo n.º 18
0
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train one generator per pyramid level, conditioning on upsampled reals.

    The pyramid is built directly from the full-size input image (no
    ``opt.scale1`` resize).  For every level above the coarsest, the previous
    real level is upsampled to the current level's shape; those upsamples and
    the shifted absolute residuals are kept in parallel lists, saved as images
    and handed to ``train_single_scale`` together with the level index.

    Args:
        opt: options namespace; ``nfc``, ``min_nfc``, ``out_`` and ``outf``
            are overwritten for every level.
        Gs: list that receives the frozen generator of every level.
        Zs: list that receives the ``z_curr`` returned for every level.
        reals: list seeded into ``functions.creat_reals_pyramid``.
        NoiseAmp: list that receives ``opt.noise_amp`` after every level.

    Returns:
        None.  ``Zs``, ``Gs``, the pyramid and ``NoiseAmp`` are re-saved under
        ``opt.out_`` after every level.
    """
    source = functions.read_image(opt)
    in_s = 0
    reals = functions.creat_reals_pyramid(source, reals, opt)

    # upsamples[k] and diffs[k] line up with reals[k]; the coarsest level has
    # no coarser neighbour, so both lists start with reals[0] itself.
    upsamples = [reals[0]]
    diffs = [reals[0]]
    for idx in range(opt.stop_scale):
        coarse = reals[idx]
        fine = reals[idx + 1]
        _, channels, height, width = fine.shape
        lifted = imresize_to_shape(coarse, (height, width, channels), opt)
        upsamples.append(lifted)
        # |fine - lifted| - 1 maps a [0, 2] residual onto [-1, 1]
        # (assumes reals are normalized to [-1, 1] — TODO confirm).
        diffs.append((fine - lifted).abs() - 1)

    # Levels 0 .. opt.stop_scale inclusive, coarsest first.
    for level in range(opt.stop_scale + 1):
        # Conv block width doubles every 4 levels, capped at 128.
        widen = pow(2, math.floor(level / 4))
        opt.nfc = min(opt.nfc_init * widen, 128)
        opt.min_nfc = min(opt.min_nfc_init * widen, 128)

        # out_ is the run directory; outf the per-level sub-directory.
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, level)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass

        plt.imsave('%s/real_scale.png' % (opt.outf),
                   functions.convert_image_np(reals[level]),
                   vmin=0, vmax=1)
        plt.imsave('%s/diff.png' % (opt.outf),
                   functions.convert_image_np(diffs[level].detach()),
                   vmin=0, vmax=1)
        plt.imsave('%s/upsampled.png' % (opt.outf),
                   functions.convert_image_np(upsamples[level].detach()),
                   vmin=0, vmax=1)

        # Every level starts from freshly initialised networks; weights are
        # intentionally not inherited from the previous level.
        D_curr, G_curr = init_models(opt)
        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals,
                                                  upsamples, level, in_s,
                                                  NoiseAmp, opt)

        # Freeze both networks now that this level is trained.
        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        # Checkpoint everything trained so far after each level.
        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

        del D_curr, G_curr
        torch.cuda.empty_cache()
Ejemplo n.º 19
0
def get_reals(reals, opt, image_name):
    """Return the image pyramid built from ``image_name``.

    The image is loaded with ``functions.read_image(opt, image_name)``,
    rescaled by ``opt.scale1`` and expanded by
    ``functions.creat_reals_pyramid``, which is seeded with ``reals``.
    """
    scaled = imresize(functions.read_image(opt, image_name), opt.scale1, opt)
    return functions.creat_reals_pyramid(scaled, reals, opt)
Ejemplo n.º 20
0
def SinGAN_generate(Gs,
                    Zs,
                    reals,
                    NoiseAmp,
                    opt,
                    in_s=None,
                    scale_v=1,
                    scale_h=1,
                    n=0,
                    gen_start_scale=0,
                    num_samples=50):
    """Sample images by running the generator pyramid from coarse to fine.

    Before sampling, the input image named by ``opt`` is re-read, resized to
    the finest level's shape, rebuilt into a pyramid, and each rebuilt level is
    resized to the shape of the matching entry of ``reals``.  This rebuilt
    pyramid replaces ``reals`` for the rest of the function, and each of its
    levels is saved as ``real_<i>.png``.

    Args:
        Gs, Zs, NoiseAmp: per-level generators, reconstruction noise maps and
            noise amplitudes, coarsest first (iterated in lockstep).
        reals: per-level real images; only their shapes are used.
        opt: options namespace (reads ``out``, ``input_name``, ``device``,
            ``ker_size``, ``num_layer``, ``nc_z``, ``scale_factor``, ``mode``,
            ``skip``).
        in_s: canvas fed to the coarsest level; all zeros if None.
        scale_v, scale_h: vertical / horizontal scale of the generated samples.
        n: index of the level the first generator corresponds to.
        gen_start_scale: levels with index below this reuse ``Z_opt`` instead of
            fresh noise.
        num_samples: number of samples generated at every level.

    Returns:
        The last sample produced at the last level, detached.

    Raises:
        ValueError: if nothing was generated (``Gs``/``Zs``/``NoiseAmp`` empty
            or ``num_samples == 0``); previously this surfaced as NameError.
    """
    real = functions.read_image(opt)

    real = real.numpy()
    real = resize(real, reals[-1].shape)
    real = torch.from_numpy(real)

    # Rebuild the pyramid from the input image, then force every level to the
    # exact shape of the corresponding trained level.
    new_reals = creat_reals_pyramid(real, [], opt)
    buffer = []
    for new_real, trained_real in zip(new_reals, reals):
        ele = resize(new_real.numpy(), trained_real.shape)
        buffer.append(torch.from_numpy(ele))
    reals = buffer

    # plt.imsave does not create parent directories, and nothing upstream
    # guarantees this one exists (e.g. in 'train' mode), so create it here.
    reals_dir = '%s/RandomSamples/%s/gen_start_scale=%d' % (
        opt.out, opt.input_name[:-4], gen_start_scale)
    if reals:
        os.makedirs(reals_dir, exist_ok=True)
    for i, real_img in enumerate(reals):
        plt.imsave('%s/%s_%d.png' % (reals_dir, "real", i),
                   functions.convert_image_np(real_img.detach()),
                   vmin=0,
                   vmax=1)

    if in_s is None:
        in_s = torch.full(reals[0].shape, 0, device=opt.device)
    images_cur = []
    I_curr = None  # stays None if no level / sample is generated

    for G, Z_opt, noise_amp in zip(Gs, Zs, NoiseAmp):
        pad1 = ((opt.ker_size - 1) * opt.num_layer) / 2
        m = nn.ZeroPad2d(int(pad1))
        nzx = (Z_opt.shape[2] - pad1 * 2) * scale_v
        nzy = (Z_opt.shape[3] - pad1 * 2) * scale_h
        # For Section IV
        # if n == 0:
        #     images_prev = images_cur
        # else:
        #     new_img_prev = []
        #     for img in images_cur:
        #         ele = reals[n].numpy()
        #         ele = resize(ele, img.shape)
        #         ele = torch.from_numpy(ele)
        #         new_img_prev.append(ele)
        #     images_prev = new_img_prev

        images_prev = images_cur
        images_cur = []

        # The output directory only matters at the last level and does not
        # depend on the sample index, so resolve/create it once per level.
        is_finest = n == len(reals) - 1
        if is_finest:
            if opt.mode == 'train':
                dir2save = '%s/RandomSamples/%s/gen_start_scale=%d' % (
                    opt.out, opt.input_name[:-4], gen_start_scale)
            else:
                dir2save = functions.generate_dir2save(opt)
            os.makedirs(dir2save, exist_ok=True)

        for i in range(num_samples):
            if n == 0:
                z_curr = functions.generate_noise([1, nzx, nzy],
                                                  device=opt.device)
                # Coarsest level: a single noise channel expanded to 3 channels.
                z_curr = z_curr.expand(1, 3, z_curr.shape[2], z_curr.shape[3])
                z_curr = m(z_curr)
            else:
                z_curr = functions.generate_noise([opt.nc_z, nzx, nzy],
                                                  device=opt.device)
                z_curr = m(z_curr)

            if not images_prev:
                I_prev = m(in_s)
            else:
                # Upsample this sample's output from the previous level.
                I_prev = images_prev[i]
                I_prev = imresize(I_prev, 1 / opt.scale_factor, opt)
                if opt.mode != "SR":
                    I_prev = I_prev[:, :, 0:round(scale_v * reals[n].shape[2]),
                                    0:round(scale_h * reals[n].shape[3])]
                    I_prev = m(I_prev)
                    I_prev = I_prev[:, :, 0:z_curr.shape[2], 0:z_curr.shape[3]]
                    I_prev = functions.upsampling(I_prev, z_curr.shape[2],
                                                  z_curr.shape[3])
                else:
                    I_prev = m(I_prev)

            if n < gen_start_scale:
                z_curr = Z_opt

            z_in = noise_amp * (z_curr) + I_prev
            if opt.skip != '' and int(opt.skip) == n:
                I_curr = I_prev
            else:
                I_curr = G(z_in.detach(), I_prev)

            if is_finest and opt.mode not in ("harmonization", "editing",
                                              "SR", "paint2image"):
                plt.imsave('%s/%d.png' % (dir2save, i),
                           functions.convert_image_np(I_curr.detach()),
                           vmin=0,
                           vmax=1)

            # For Section VI
            # if opt.mode == 'train':
            #     dir2save = '%s/RandomSamples/%s/gen_start_scale=%d' % (opt.out, opt.input_name[:-4], gen_start_scale)
            # else:
            #     dir2save = functions.generate_dir2save(opt)
            # try:
            #     os.makedirs(dir2save)
            # except OSError:
            #     pass
            # if (opt.mode != "harmonization") & (opt.mode != "editing") & (opt.mode != "SR") & (opt.mode != "paint2image"):
            #     plt.imsave('%s/%d_%d.png' % (dir2save, i, n), functions.convert_image_np(I_curr.detach()), vmin=0,vmax=1)

            images_cur.append(I_curr)
        n += 1

    if I_curr is None:
        raise ValueError('SinGAN_generate produced no sample: Gs/Zs/NoiseAmp '
                         'is empty or num_samples is 0')
    return I_curr.detach()
Ejemplo n.º 21
0
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train the SinGAN pyramid level by level, coarsest first.

    The input image is resized by ``opt.scale1`` and expanded into a pyramid.
    Levels 0..``opt.stop_scale`` are then trained in order.  When a level's
    conv width equals the previous level's, both networks are warm-started
    from the previous level's saved ``netG.pth`` / ``netD.pth``.  With
    ``opt.fast_training`` the iteration budget ``opt.niter`` is halved at
    every 4th level; the halving compounds because ``opt.niter`` is
    overwritten.

    Args:
        opt: options namespace; ``nfc``, ``min_nfc``, ``out_``, ``outf`` and
            possibly ``niter`` are overwritten.
        Gs, Zs, NoiseAmp: lists that receive each level's frozen generator,
            ``z_curr`` and ``opt.noise_amp`` (mutated in place).
        reals: list seeded into ``functions.creat_reals_pyramid``.

    Returns:
        None.  All lists are checkpointed under ``opt.out_`` after every level.
    """
    base = imresize(functions.read_image(opt), opt.scale1, opt)
    reals = functions.creat_reals_pyramid(base, reals, opt)
    in_s = 0
    prev_width = 0

    for scale_num in range(opt.stop_scale + 1):
        # Filter count doubles every 4 levels, capped at 128.
        growth = pow(2, math.floor(scale_num / 4))
        opt.nfc = min(opt.nfc_init * growth, 128)
        opt.min_nfc = min(opt.min_nfc_init * growth, 128)

        # Finer levels get fewer iterations under fast training.
        if opt.fast_training and scale_num > 0 and scale_num % 4 == 0:
            opt.niter = opt.niter // 2

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale_num)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass

        plt.imsave('%s/real_scale.png' % (opt.outf),
                   functions.convert_image_np(reals[scale_num]),
                   vmin=0, vmax=1)

        D_curr, G_curr = init_models(opt)
        if prev_width == opt.nfc:
            # Same architecture as the previous level: reuse its weights.
            G_curr.load_state_dict(
                torch.load('%s/%d/netG.pth' % (opt.out_, scale_num - 1)))
            D_curr.load_state_dict(
                torch.load('%s/%d/netD.pth' % (opt.out_, scale_num - 1)))

        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs,
                                                  Zs, in_s, NoiseAmp, opt)

        # Disable gradients and switch to eval mode once the level is trained.
        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

        prev_width = opt.nfc
        del D_curr, G_curr
Ejemplo n.º 22
0
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train the pyramid level by level, pruning (and optionally fine-tuning) each generator.

    For every level, coarsest first, the generator is trained, then pruned.
    Pruning is either global unstructured L1 over the weights and biases of
    the conv/norm layers, or L1-norm structured pruning along dim 0 of each
    conv weight.  The generator is then optionally fine-tuned, and the pruning
    is made permanent with ``prune.remove`` before it is appended to ``Gs``.
    Weight histograms are saved before pruning ('ori'), right after pruning
    ('raw-prune') and after fine-tuning ('fine-tune').

    NOTE(review): the names ``fine_tune``, ``warmup_steps``, ``structured`` and
    ``pruning_amount`` are read from module scope; they are not defined in
    this function and must exist at module level.

    Args:
        opt: options namespace; ``nfc``, ``min_nfc``, ``out_`` and ``outf`` are
            overwritten for every level.
        Gs, Zs, NoiseAmp: lists that receive each level's pruned generator,
            ``z_curr`` and ``opt.noise_amp`` (mutated in place).
        reals: list seeded into ``functions.creat_reals_pyramid``.

    Returns:
        None.  Checkpoints are written under ``opt.out_`` after every level
        (generators as ``pruned_Gs.pth``).
    """
    real_ = functions.read_image(opt)
    in_s = 0
    # The pyramid is built from the full-size image (no opt.scale1 resize).
    reals = functions.creat_reals_pyramid(real_, reals, opt)
    nfc_prev = 0

    def visualize_weights(modules, fig_name):
        """Print the zero fraction of the modules' weights and save a histogram of the non-zero ones."""
        # Concatenate on whatever device the weights live on; seeding with a
        # hard-coded .cuda() tensor broke CPU runs.
        all_weights = torch.cat([mod.weight.data.flatten() for mod in modules])
        sparsity = torch.sum(all_weights == 0) * 1.0 / (all_weights.nelement())
        print(sparsity, all_weights.nelement())
        all_weights = all_weights.cpu().numpy()
        plt.hist(all_weights[all_weights != 0], bins=100)
        plt.savefig("%s/%s.png" % (opt.outf, fig_name))
        plt.close()

    # Train levels 0 .. opt.stop_scale inclusive, coarsest first.
    for cur_scale_level in range(opt.stop_scale + 1):
        # Conv block width doubles every 4 levels, capped at 128.
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(cur_scale_level / 4)), 128)
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(cur_scale_level / 4)), 128)

        # out_: run directory; outf: per-level sub-directory.
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, cur_scale_level)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass

        plt.imsave('%s/real_scale.png' % (opt.outf), functions.convert_image_np(reals[cur_scale_level]), vmin=0, vmax=1)

        D_curr, G_curr = init_models(opt)
        # Warm-start from the previous level when the conv width is unchanged.
        if nfc_prev == opt.nfc:
            G_curr.load_state_dict(torch.load('%s/%d/netG.pth' % (opt.out_, cur_scale_level - 1)))
            D_curr.load_state_dict(torch.load('%s/%d/netD.pth' % (opt.out_, cur_scale_level - 1)))

        # With fine-tuning, only `warmup_steps` iterations run before pruning.
        if fine_tune:
            z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt, warmup_steps)
        else:
            z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt, opt.niter)

        # D_curr is left trainable because the fine-tuning call below reuses it.
        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()

        #################################################################################
        # Pruning: structured or non-structured
        if not structured:
            modules = [G_curr.head.conv, G_curr.head.norm,
                       G_curr.body.block1.conv, G_curr.body.block1.norm,
                       G_curr.body.block2.conv, G_curr.body.block2.norm,
                       G_curr.body.block3.conv, G_curr.body.block3.norm,
                       G_curr.tail[0]]
            parameters_to_prune = (
                (G_curr.head.conv, 'weight'),
                (G_curr.head.norm, 'weight'),
                (G_curr.body.block1.conv, 'weight'),
                (G_curr.body.block1.norm, 'weight'),
                (G_curr.body.block2.conv, 'weight'),
                (G_curr.body.block2.norm, 'weight'),
                (G_curr.body.block3.conv, 'weight'),
                (G_curr.body.block3.norm, 'weight'),
                (G_curr.tail[0], 'weight'),
                (G_curr.head.conv, 'bias'),
                (G_curr.head.norm, 'bias'),
                (G_curr.body.block1.conv, 'bias'),
                (G_curr.body.block1.norm, 'bias'),
                (G_curr.body.block2.conv, 'bias'),
                (G_curr.body.block2.norm, 'bias'),
                (G_curr.body.block3.conv, 'bias'),
                (G_curr.body.block3.norm, 'bias'),
                (G_curr.tail[0], 'bias'),
            )

            visualize_weights(modules, 'ori')

            # Rank all listed tensors together and zero the smallest-|w| share.
            prune.global_unstructured(
                parameters_to_prune,
                pruning_method=prune.L1Unstructured,
                amount=pruning_amount,
            )
        else:
            modules = [G_curr.head.conv,
                       G_curr.body.block1.conv,
                       G_curr.body.block2.conv,
                       G_curr.body.block3.conv]

            visualize_weights(modules, 'ori')

            # Prune whole slices along dim 0 of each conv weight by L1 norm
            # (biases are not pruned in this mode).
            for module in modules:
                prune.ln_structured(module, name="weight", amount=pruning_amount, n=1, dim=0)

        torch.save(G_curr.state_dict(), '%s/raw_prune_netG.pth' % (opt.outf))
        visualize_weights(modules, 'raw-prune')
        # Preview the pruned pyramid up to this level (needs at least 2 levels).
        if cur_scale_level > 0:
            fake_Gs = Gs.copy()
            fake_Gs.append(G_curr)
            fake_Zs = Zs.copy()
            fake_Zs.append(z_curr)
            fake_noise = NoiseAmp.copy()
            fake_noise.append(opt.noise_amp)
            fake_reals = reals[:cur_scale_level + 1]  # slicing already copies
            prune_SinGAN_generate(fake_Gs, fake_Zs, fake_reals, fake_noise, opt, gen_start_scale=0, num_samples=1, level=cur_scale_level)

        # Fine-tuning
        if fine_tune:
            G_curr = functions.reset_grads(G_curr, True)
            G_curr.train()

            if not structured:
                # Keep training the inherited weights for the remaining budget.
                z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt, opt.niter - warmup_steps, prune=True)
            else:
                # Intended as training from scratch; the re-init is disabled:
                # G_curr.apply(models.weights_init)
                # D_curr.apply(models.weights_init)
                z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt, opt.niter, prune=True)
        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        visualize_weights(modules, 'fine-tune')

        # Make the pruning permanent (drop masks / reparametrization).
        for m in modules:
            prune.remove(m, 'weight')
            if not structured:
                prune.remove(m, 'bias')
        #################################################################################

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/pruned_Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

        nfc_prev = opt.nfc
        del D_curr, G_curr
    return
Ejemplo n.º 23
0
    Gs = []
    Zs = []
    reals = []
    NoiseAmp = []
    dir2save = functions.generate_dir2save(opt)
    if dir2save is None:
        print('task does not exist')
    else:
        try:
            os.makedirs(dir2save)
        except OSError:
            pass
        #real = functions.read_image(opt)
        #real = functions.adjust_scales2image(real, opt)
        #Gs, Zs, reals, NoiseAmp = functions.load_trained_pyramid(opt)
        real1, real2 = functions.read_image(opt)
        real1 = functions.adjust_scales2image(real1, opt)
        real2 = functions.adjust_scales2image(real2, opt)
        Gs, Zs, reals1, reals2, NoiseAmp = functions.load_trained_pyramid(opt)
        if (opt.paint_start_scale < 1) | (opt.paint_start_scale > (len(Gs)-1)):
            print("injection scale should be between 1 and %d" % (len(Gs)-1))
        else:
            ref = functions.read_image_dir('%s/%s' % (opt.ref_dir, opt.ref_name), opt)
            if ref.shape[3] != real1.shape[3]:
                ref = imresize_to_shape(ref, [real1.shape[2], real1.shape[3]], opt)
                ref = ref[:, :, :real1.shape[2], :real1.shape[3]]

            N = len(reals1) - 1
            n = opt.paint_start_scale
            in_s = imresize(ref, pow(opt.scale_factor, (N - n + 1)), opt)
            in_s = in_s[:, :, :reals1[n - 1].shape[2], :reals1[n - 1].shape[3]]
Ejemplo n.º 24
0
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train the SinGAN pyramid, optionally building an inpainting mask pyramid.

    The input image is resized by ``opt.scale1`` and expanded into a pyramid.
    When ``opt.inpainting`` is set, the mask image ``<name>_mask<ext>`` from
    ``opt.ref_dir`` is resized and expanded the same way and stored on
    ``opt.m_s``, so each level can be restricted to the valid pixels.
    Levels 0..``opt.stop_scale`` are trained coarsest first, warm-starting
    from the previous level's saved weights when the conv width is unchanged.

    Args:
        opt: options namespace; ``nfc``, ``min_nfc``, ``out_``, ``outf`` (and
            ``m_s`` when inpainting) are overwritten.
        Gs, Zs, NoiseAmp: lists that receive each level's frozen generator,
            ``z_curr`` and ``opt.noise_amp`` (mutated in place).
        reals: list seeded into ``functions.creat_reals_pyramid``.

    Returns:
        None.  All lists are checkpointed under ``opt.out_`` after every level.
    """
    source = functions.read_image(opt)
    reals = functions.creat_reals_pyramid(imresize(source, opt.scale1, opt),
                                          reals, opt)
    in_s = 0
    last_nfc = 0

    if opt.inpainting:
        mask_path = '%s/%s_mask%s' % (opt.ref_dir, opt.input_name[:-4],
                                      opt.input_name[-4:])
        mask = imresize(functions.read_image_dir(mask_path, opt),
                        opt.scale1, opt)
        # Mask pyramid built exactly like the image pyramid.
        opt.m_s = functions.creat_reals_pyramid(mask, [], opt)

    for scale_num in range(opt.stop_scale + 1):
        width_factor = pow(2, math.floor(scale_num / 4))
        opt.nfc = min(opt.nfc_init * width_factor, 128)
        opt.min_nfc = min(opt.min_nfc_init * width_factor, 128)

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale_num)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass

        plt.imsave('%s/real_scale.png' % (opt.outf),
                   functions.convert_image_np(reals[scale_num]),
                   vmin=0, vmax=1)

        D_curr, G_curr = init_models(opt)
        if last_nfc == opt.nfc:
            G_curr.load_state_dict(
                torch.load('%s/%d/netG.pth' % (opt.out_, scale_num - 1)))
            D_curr.load_state_dict(
                torch.load('%s/%d/netD.pth' % (opt.out_, scale_num - 1)))

        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs,
                                                  Zs, in_s, NoiseAmp, opt)

        # Freeze and switch to eval mode after training this level.
        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

        last_nfc = opt.nfc
        del D_curr, G_curr
Ejemplo n.º 25
0
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train the SinGAN pyramid level by level with a notebook progress bar.

    The input image is resized by ``opt.scale1`` (set earlier by the
    scale-adjustment step) and expanded into a pyramid.  Levels
    0..``opt.stop_scale`` are trained coarsest first, iterated through
    ``tqdm_notebook`` labelled with ``opt.input_name``.  When a level's conv
    width equals the previous one, weights are warm-started from the previous
    level's saved checkpoints.

    Args:
        opt: options namespace; ``nfc``, ``min_nfc``, ``out_`` and ``outf`` are
            overwritten for every level.
        Gs, Zs, NoiseAmp: lists that receive each level's frozen generator,
            ``z_curr`` and ``opt.noise_amp`` (mutated in place).
        reals: list seeded into ``functions.creat_reals_pyramid``.

    Returns:
        None.  All lists are checkpointed under ``opt.out_`` after every level.
    """
    picture = functions.read_image(opt)
    in_s = 0
    resized = imresize(picture, opt.scale1, opt)
    reals = functions.creat_reals_pyramid(resized, reals, opt)
    prev_channels = 0

    levels = tqdm_notebook(range(opt.stop_scale + 1),
                           desc=opt.input_name,
                           leave=True)
    for scale_num in levels:
        # Channel counts double every 4 levels, capped at 128.
        doubling = pow(2, math.floor(scale_num / 4))
        opt.nfc = min(opt.nfc_init * doubling, 128)
        opt.min_nfc = min(opt.min_nfc_init * doubling, 128)

        # Run directory plus one sub-directory per level.
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = f'{opt.out_}/{scale_num}'
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass

        # Keep a copy of the real image at this level's resolution.
        plt.imsave(f'{opt.outf}/real_scale.png',
                   functions.convert_image_np(reals[scale_num]),
                   vmin=0, vmax=1)

        D_curr, G_curr = init_models(opt)
        if prev_channels == opt.nfc:
            # Identical architecture: start from the previous level's weights.
            G_curr.load_state_dict(
                torch.load('%s/%d/netG.pth' % (opt.out_, scale_num - 1)))
            D_curr.load_state_dict(
                torch.load('%s/%d/netD.pth' % (opt.out_, scale_num - 1)))

        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs,
                                                  Zs, in_s, NoiseAmp, opt)

        # The trained level is frozen before being stored.
        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

        prev_channels = opt.nfc
        # Drop the local references to release memory before the next level.
        del D_curr, G_curr
Ejemplo n.º 26
0
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train one SinGAN generator/discriminator pair per pyramid level.

    The input image is loaded with ``functions.read_image``, resized by
    ``opt.scale1`` and expanded into a multi-resolution pyramid.  Levels
    0..``opt.stop_scale`` are trained coarsest first.  Each level's generator,
    noise map and ``opt.noise_amp`` are appended to ``Gs``, ``Zs`` and
    ``NoiseAmp``, and all lists are re-saved to ``opt.out_`` after every level.

    NOTE(review): a previous comment claimed ``read_image`` normalizes pixel
    values to [-1, 1]; that is not visible here — confirm in ``functions``.

    Args:
        opt: options namespace; ``nfc``, ``min_nfc``, ``out_`` and ``outf`` are
            overwritten for every level.
        Gs, Zs, NoiseAmp: output lists, mutated in place.
        reals: list seeded into ``functions.creat_reals_pyramid``.

    Returns:
        None.
    """
    image = functions.read_image(opt)
    in_s = 0
    pyramid_base = imresize(image, opt.scale1, opt)
    reals = functions.creat_reals_pyramid(pyramid_base, reals, opt)
    last_nfc = 0

    for level in range(opt.stop_scale + 1):
        # Conv width doubles every 4 levels, capped at 128.
        factor = pow(2, math.floor(level / 4))
        opt.nfc = min(opt.nfc_init * factor, 128)
        opt.min_nfc = min(opt.min_nfc_init * factor, 128)

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, level)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass

        plt.imsave('%s/real_scale.png' % (opt.outf),
                   functions.convert_image_np(reals[level]),
                   vmin=0, vmax=1)

        D_curr, G_curr = init_models(opt)
        if last_nfc == opt.nfc:
            # Same width as the previous level: initialise from its weights.
            G_curr.load_state_dict(
                torch.load('%s/%d/netG.pth' % (opt.out_, level - 1)))
            D_curr.load_state_dict(
                torch.load('%s/%d/netD.pth' % (opt.out_, level - 1)))

        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs,
                                                  Zs, in_s, NoiseAmp, opt)

        # Freeze both networks for this level.
        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        Gs.append(G_curr)
        Zs.append(z_curr)  # per-level noise map returned by training
        NoiseAmp.append(opt.noise_amp)

        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

        last_nfc = opt.nfc
        del D_curr, G_curr
Ejemplo n.º 27
0
def test_generate(model_name,
                  anchor_image=None,
                  direction=None,
                  transfer=None,
                  noise_solutions=None,
                  factor=0.25,
                  base=None,
                  insert_limit=4,
                  dims_image_name='island_basis_0.jpg',
                  pyramid_name='test1.jpg'):
    """Run ``SinGAN_anchor_generate`` on a trained pyramid, sizing it from a separate image.

    ``model_name`` is only assigned to ``opt.input_name`` before
    ``functions.post_config``.  The image used to size the pyramid and the
    trained pyramid that is loaded are selected by ``dims_image_name`` and
    ``pyramid_name``.  Both default to the values that used to be hard-coded,
    so existing callers behave the same.

    Args:
        model_name: name assigned to ``opt.input_name`` before post-config.
        anchor_image, direction, transfer, noise_solutions, factor, base,
            insert_limit: forwarded unchanged to ``SinGAN_anchor_generate``
            (direction is one of 'L', 'R', 'T', 'B').
        dims_image_name: existing input image whose size drives the scale
            set-up and the pyramid passed to the generator.
        pyramid_name: input name under which the trained pyramid is loaded.

    Returns:
        Whatever ``SinGAN_anchor_generate`` returns.
    """
    parser = get_arguments()
    parser.add_argument('--input_dir',
                        help='input image dir',
                        default='Input/Images')
    parser.add_argument('--mode',
                        help='random_samples | random_samples_arbitrary_sizes',
                        default='random_samples')
    # for random_samples:
    parser.add_argument('--gen_start_scale',
                        type=int,
                        help='generation start scale',
                        default=0)
    opt = parser.parse_args("")
    opt.input_name = model_name

    opt = functions.post_config(opt)

    # Read an existing image only to set up the scale parameters.
    opt.input_name = dims_image_name
    real = functions.read_image(opt)
    functions.adjust_scales2image(real, opt)

    # Load the trained pyramid we actually want to sample from.
    opt.input_name = pyramid_name
    Gs, Zs, reals, NoiseAmp = functions.load_trained_pyramid(opt)

    # Rebuild a pyramid from the sizing image so the level shapes match it.
    # A fresh list is passed so the loaded `reals` is not appended to.
    real = imresize(real, opt.scale1, opt)
    reals = functions.creat_reals_pyramid(real, [], opt)

    return SinGAN_anchor_generate(Gs,
                                  Zs,
                                  reals,
                                  NoiseAmp,
                                  opt,
                                  gen_start_scale=opt.gen_start_scale,
                                  anchor_image=anchor_image,
                                  direction=direction,
                                  transfer=transfer,
                                  noise_solutions=noise_solutions,
                                  factor=factor,
                                  base=base,
                                  insert_limit=insert_limit)
Ejemplo n.º 28
0
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train the generator/discriminator pyramid scale by scale, logging cost.

    The input image is read with ``functions.read_image(opt)``, resized by
    ``opt.scale1`` and expanded into a pyramid of real images. Scales
    ``0 .. opt.stop_scale`` are then trained in order. For each scale the
    trained generator, its noise map ``z_curr`` and ``opt.noise_amp`` are
    appended to ``Gs``, ``Zs`` and ``NoiseAmp`` (mutated in place), and the
    running lists are checkpointed under ``opt.out_``.

    The ``(mbs, percent)`` pair returned by ``train_single_scale`` and the
    wall-clock duration of each scale are collected locally and printed once
    all scales are done.

    NOTE(review): ``full_memory`` and ``full_time`` are not defined in this
    function; presumably they are module-level globals populated elsewhere.
    Confirm this, otherwise saving ``full_memory`` raises ``NameError``.

    Returns:
        None.
    """
    source_img = functions.read_image(opt)
    in_s = 0
    reals = functions.creat_reals_pyramid(imresize(source_img, opt.scale1, opt),
                                          reals, opt)

    nfc_prev = 0
    memory_usage = []  # one [mbs, percent] entry per scale
    durations = []  # one datetime.timedelta per scale

    scale_num = 0
    while scale_num < opt.stop_scale + 1:
        # Channel widths double every 4 scales and are capped at 128.
        width_factor = pow(2, math.floor(scale_num / 4))
        opt.nfc = min(opt.nfc_init * width_factor, 128)
        opt.min_nfc = min(opt.min_nfc_init * width_factor, 128)

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale_num)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass  # best effort: the directory may already exist

        plt.imsave('%s/real_scale.png' % (opt.outf),
                   functions.convert_image_np(reals[scale_num]),
                   vmin=0,
                   vmax=1)

        D_curr, G_curr = init_models(opt)
        if nfc_prev == opt.nfc:
            # Same width as the previous scale: warm-start from its weights.
            G_curr.load_state_dict(
                torch.load('%s/%d/netG.pth' % (opt.out_, scale_num - 1)))
            D_curr.load_state_dict(
                torch.load('%s/%d/netD.pth' % (opt.out_, scale_num - 1)))

        started = datetime.datetime.now()
        z_curr, in_s, G_curr, mbs, percent = train_single_scale(
            D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt)
        memory_usage.append([mbs, percent])
        elapsed = datetime.datetime.now() - started
        durations.append(elapsed)
        print(f'time: {elapsed}')

        # Freeze both networks before storing them.
        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        for obj, path_fmt in ((Zs, '%s/Zs.pth'),
                              (Gs, '%s/Gs.pth'),
                              (reals, '%s/reals.pth'),
                              (NoiseAmp, '%s/NoiseAmp.pth')):
            torch.save(obj, path_fmt % (opt.out_))
        torch.save(full_memory, '%s/full_memory.pth' % (opt.out_))
        torch.save(full_time, '%s/full_time.pth' % (opt.out_))

        scale_num += 1
        nfc_prev = opt.nfc
        del D_curr, G_curr

    print(memory_usage)
    print(durations)
    print(full_memory)
    print(full_time)
    return
Ejemplo n.º 29
0
def _generator_prunable_modules(G):
    """Return the generator sub-modules whose ``weight``/``bias`` get pruned.

    The list follows the head / body.block1-3 / tail[0] layout accessed here;
    every entry must expose both a ``weight`` and a ``bias`` attribute.
    """
    return [
        G.head.conv, G.head.norm,
        G.body.block1.conv, G.body.block1.norm,
        G.body.block2.conv, G.body.block2.norm,
        G.body.block3.conv, G.body.block3.norm,
        G.tail[0],
    ]


def _save_weight_histogram(modules, out_dir, fig_name):
    """Save a 100-bin histogram of the non-zero weights/biases of ``modules``.

    The figure is written to ``<out_dir>/<fig_name>.png``. Zeros are excluded
    so that pruned entries do not dominate the plot.
    """
    chunks = []
    for m in modules:
        chunks.append(m.weight.data.flatten())
        chunks.append(m.bias.data.flatten())
    # A single cat on the parameters' own device (no hard-coded .cuda()).
    values = torch.cat(chunks).cpu().numpy()
    plt.hist(values[values != 0], bins=100)
    plt.savefig("%s/%s.png" % (out_dir, fig_name))
    plt.close()


def _prune_generator(G, out_dir, amount):
    """Globally L1-prune G's weights and biases in place and make it permanent.

    Histograms of the parameters before ('ori') and after ('prune') pruning
    are saved into ``out_dir``.
    """
    modules = _generator_prunable_modules(G)
    # Derived from `modules`, so the pruned set and the removed set always match.
    parameters_to_prune = tuple(
        (m, name) for m in modules for name in ('weight', 'bias'))

    _save_weight_histogram(modules, out_dir, 'ori')

    prune.global_unstructured(
        parameters_to_prune,
        pruning_method=prune.L1Unstructured,
        amount=amount,
    )
    # Bake the masks into the tensors and drop the reparametrization.
    for m, name in parameters_to_prune:
        prune.remove(m, name)

    _save_weight_histogram(modules, out_dir, 'prune')


def train(opt, Gs, Zs, reals, NoiseAmp, prune_amount=0.2):
    """Train the pyramid scale by scale, pruning and fp16-casting each generator.

    Levels ``0 .. opt.stop_scale`` (coarsest to finest) are trained in order.
    After each level the generator is frozen, globally L1-pruned, cast to half
    precision and appended to ``Gs``. ``Zs`` and ``NoiseAmp`` are appended to
    as well. All three lists are mutated in place and checkpointed under
    ``opt.out_`` (the generators as ``pruned_Gs.pth``).

    Args:
        opt: configuration namespace; this function sets ``nfc``, ``min_nfc``,
            ``out_`` and ``outf`` on it.
        Gs, Zs, NoiseAmp: per-level lists that are appended to in place.
        reals: passed to ``functions.creat_reals_pyramid`` and rebound to its
            result.
        prune_amount: amount passed to ``prune.global_unstructured`` (a float
            is the fraction of generator weight/bias entries to zero). The
            default of 0.2 is the value that was previously hard-coded.

    Returns:
        None.
    """
    real_ = functions.read_image(opt)
    in_s = 0
    # NOTE(review): real_ enters the pyramid without an
    # `imresize(real_, opt.scale1, opt)` step; confirm this is intentional.
    reals = functions.creat_reals_pyramid(real_, reals, opt)
    nfc_prev = 0

    # cur_scale_level runs from the coarsest level (0) up to and including
    # opt.stop_scale.
    cur_scale_level = 0
    while cur_scale_level < opt.stop_scale + 1:
        # nfc: output channels per conv block; doubles every 4 levels, capped at 128.
        width_factor = pow(2, math.floor(cur_scale_level / 4))
        opt.nfc = min(opt.nfc_init * width_factor, 128)
        opt.min_nfc = min(opt.min_nfc_init * width_factor, 128)

        # out_: run output directory; outf: per-level subdirectory.
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, cur_scale_level)
        os.makedirs(opt.outf, exist_ok=True)

        plt.imsave('%s/real_scale.png' % (opt.outf),
                   functions.convert_image_np(reals[cur_scale_level]),
                   vmin=0,
                   vmax=1)

        D_curr, G_curr = init_models(opt)
        # The block width only changes when nfc changes; otherwise warm-start
        # from the previous level's saved weights.
        if nfc_prev == opt.nfc:
            G_curr.load_state_dict(
                torch.load('%s/%d/netG.pth' % (opt.out_, cur_scale_level - 1)))
            D_curr.load_state_dict(
                torch.load('%s/%d/netD.pth' % (opt.out_, cur_scale_level - 1)))

        # in_s is threaded through: train_single_scale returns the value
        # used for the next level.
        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs,
                                                  Zs, in_s, NoiseAmp, opt)

        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()

        _prune_generator(G_curr, opt.outf, prune_amount)

        # NOTE(review): the fp16 generator is appended to Gs, which
        # train_single_scale receives at the next level. Confirm it feeds
        # matching-dtype inputs to earlier generators, otherwise a
        # float/half dtype mismatch is likely.
        G_curr.half()

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/pruned_Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

        cur_scale_level += 1
        nfc_prev = opt.nfc
        del D_curr, G_curr
        torch.cuda.empty_cache()
    return