Ejemplo n.º 1
0
    def __init__(self, input_name, load_existing_model):
        """Prepare configuration, input image and model directory.

        Args:
            input_name: input image name, passed to ``Config``.
            load_existing_model: when True, load a previously trained pyramid
                (``Gs``, ``Zs``, ``reals``, ``NoiseAmp``) from
                ``self.dir2save``; when False, prepare the directory for a new
                training run, asking the user before deleting an existing one.

        Raises:
            AssertionError: if ``load_existing_model`` is True and the model
                directory does not exist, or if the user does not type "yes"
                when asked to overwrite an existing model.
        """
        self.input_name = input_name
        self.load_existing_model = load_existing_model
        self.opt = Config(input_name)
        self.dir2save = functions.generate_dir2save(self.opt)
        self.real = functions.read_image(self.opt)
        functions.adjust_scales2image(self.real, self.opt)
        dir_exists = os.path.exists(self.dir2save)
        # Explicit raise instead of `assert` so the check also runs under
        # `python -O`; AssertionError is kept for existing callers.
        if load_existing_model and not dir_exists:
            raise AssertionError("cannot find trained model")

        # Defined on both paths so callers can always read it.
        self.is_loaded = False

        if load_existing_model:
            print("Trained model has been loaded (not really)")
            self.Gs = torch.load(f'{self.dir2save}/Gs.pth')
            self.Zs = torch.load(f'{self.dir2save}/Zs.pth')
            self.reals = torch.load(f'{self.dir2save}/reals.pth')
            self.NoiseAmp = torch.load(f'{self.dir2save}/NoiseAmp.pth')
            self.is_loaded = True
        else:
            if dir_exists:
                user_input = input("Trained model has been found, type \"yes\" to overwrite: ")
                if user_input != 'yes':
                    raise AssertionError("train aborted")
                rmtree(self.dir2save)
                print("train directory has been deleted")

            # exist_ok tolerates an already-present directory. Unlike the old
            # blanket `except OSError: pass`, it no longer hides e.g.
            # permission errors that would otherwise surface only at save time.
            os.makedirs(self.dir2save, exist_ok=True)
Ejemplo n.º 2
0
def train_paint(opt, Gs, Zs, reals, NoiseAmp, centers, paint_inject_scale):
    """Retrain one scale of an existing pyramid for paint-to-image.

    Walks the scales ``0 .. opt.stop_scale`` and acts only on
    ``paint_inject_scale``. There, a fresh generator/discriminator pair is
    trained, conditioned on the coarser scales and on ``centers``. The result
    replaces entry ``paint_inject_scale`` of ``Gs``, ``Zs`` and ``NoiseAmp``
    in place, and the whole pyramid is saved to ``opt.out_``. If
    ``paint_inject_scale`` is not one of those scales, nothing is trained.
    """
    in_s = torch.full(reals[0].shape, 0, device=opt.device)
    scale_num = 0
    while scale_num < opt.stop_scale + 1:
        if scale_num == paint_inject_scale:
            in_s = _train_paint_scale(opt, Gs, Zs, reals, NoiseAmp, centers,
                                      scale_num, in_s)
        scale_num += 1
    return


def _train_paint_scale(opt, Gs, Zs, reals, NoiseAmp, centers, scale_num, in_s):
    """Train scale ``scale_num`` and store it in the pyramid lists in place."""
    # Channel width doubles every 4 scales, capped at 128.
    width_factor = pow(2, math.floor(scale_num / 4))
    opt.nfc = min(opt.nfc_init * width_factor, 128)
    opt.min_nfc = min(opt.min_nfc_init * width_factor, 128)

    opt.out_ = functions.generate_dir2save(opt)
    opt.outf = '%s/%d' % (opt.out_, scale_num)
    try:
        os.makedirs(opt.outf)
    except OSError:
        pass

    plt.imsave('%s/in_scale.png' % (opt.outf),
               functions.convert_image_np(reals[scale_num]),
               vmin=0,
               vmax=1)

    D_curr, G_curr = init_models(opt)
    # Only the coarser, already-trained scales are passed as context.
    z_curr, in_s, G_curr = train_single_scale(D_curr,
                                              G_curr,
                                              reals[:scale_num + 1],
                                              Gs[:scale_num],
                                              Zs[:scale_num],
                                              in_s,
                                              NoiseAmp[:scale_num],
                                              opt,
                                              centers=centers)

    # Freeze both networks and switch them to inference mode.
    G_curr = functions.reset_grads(G_curr, False)
    G_curr.eval()
    D_curr = functions.reset_grads(D_curr, False)
    D_curr.eval()

    Gs[scale_num] = G_curr
    Zs[scale_num] = z_curr
    NoiseAmp[scale_num] = opt.noise_amp

    torch.save(Zs, '%s/Zs.pth' % (opt.out_))
    torch.save(Gs, '%s/Gs.pth' % (opt.out_))
    torch.save(reals, '%s/reals.pth' % (opt.out_))
    torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))
    return in_s
Ejemplo n.º 3
0
def train_model(input_name):
    """Train SinGAN on one image and write random samples.

    Options come from the project parser (``get_arguments``), extended with
    ``--input_dir``, ``--input_name`` and ``--mode``. Only ``--input_name`` is
    set (to ``input_name``); every other option keeps its default. If the
    output directory for this configuration already exists, nothing is
    trained.

    Args:
        input_name: file name of the input image, looked up in
            ``--input_dir`` (default ``Input/Images``).
    """
    parser = get_arguments()
    parser.add_argument('--input_dir',
                        help='input image dir',
                        default='Input/Images')
    parser.add_argument('--input_name', help='input image name', required=True)
    parser.add_argument('--mode', help='task to be done', default='train')
    # Previously `parse_args("")`, i.e. an empty argv: argparse exited because
    # the required --input_name was missing, and `input_name` was ignored.
    # The `--opt=value` form also copes with names starting with "-".
    opt = parser.parse_args([f'--input_name={input_name}'])
    opt = functions.post_config(opt)
    Gs = []
    Zs = []
    reals = []
    NoiseAmp = []
    dir2save = functions.generate_dir2save(opt)

    if os.path.exists(dir2save):
        print('trained model already exist')
        return

    os.makedirs(dir2save, exist_ok=True)
    real = functions.read_image(opt)
    # Return value unused; presumably adjusts scale settings on opt in place.
    functions.adjust_scales2image(real, opt)
    train(opt, Gs, Zs, reals, NoiseAmp)
    SinGAN_generate(Gs, Zs, reals, NoiseAmp, opt)
Ejemplo n.º 4
0
def SinGAN_generate(Gs,Zs,reals,NoiseAmp,opt,in_s=None,scale_v=1,scale_h=1,n=0,gen_start_scale=0,num_samples=50):
    """Sample images from a trained SinGAN pyramid, coarse to fine.

    Each scale is one entry of ``Gs``/``Zs``/``NoiseAmp``. For every sample
    index ``i`` in ``range(num_samples)``, noise is drawn and scaled by the
    scale's noise amplitude. It is added to the previous scale's sample ``i``,
    resized by ``1 / opt.scale_factor``, and the sum is passed through the
    generator. At the first scale the "previous" image is the padded ``in_s``.

    Args:
        Gs, Zs, NoiseAmp: per-scale generators, stored noise maps and noise
            amplitudes, consumed in order.
        reals: per-scale real images; ``reals[n]`` sets the crop size of the
            resized previous sample (except in "SR" mode).
        opt: options namespace (device, ker_size, num_layer, nc_z,
            scale_factor, mode, out, input_name, ...).
        in_s: starting image for the first scale; defaults to zeros shaped
            like ``reals[0]``.
        scale_v, scale_h: vertical/horizontal multipliers for the noise and
            crop sizes.
        n: index of the first scale in ``reals``; incremented per scale.
        gen_start_scale: scales with index ``n`` below this use the stored
            ``Z_opt`` instead of freshly drawn noise.
        num_samples: number of samples per scale.

    Returns:
        The last sample of the last scale, detached.

    Note:
        Samples are written as ``<i>.png`` at every scale, except in the
        harmonization, editing, SR and paint2image modes. The directory is
        presumably the same at every scale, so finer scales overwrite coarser
        ones. If ``Gs`` is empty or ``num_samples`` is 0, ``I_curr`` is never
        assigned and the return raises UnboundLocalError.
    """
    #if torch.is_tensor(in_s) == False:
    if in_s is None:
        in_s = torch.full(reals[0].shape, 0, device=opt.device)
    images_cur = []
    for G,Z_opt,noise_amp in zip(Gs,Zs,NoiseAmp):
        # Per-side padding for this scale. NOTE(review): presumably the border
        # lost by opt.num_layer unpadded convs of size opt.ker_size. Float
        # (true division); only the ZeroPad2d size is cast to int.
        pad1 = ((opt.ker_size-1)*opt.num_layer)/2
        m = nn.ZeroPad2d(int(pad1))
        # Noise size = unpadded Z_opt size times scale_v / scale_h.
        # NOTE(review): these can be floats; generate_noise presumably handles that.
        nzx = (Z_opt.shape[2]-pad1*2)*scale_v
        nzy = (Z_opt.shape[3]-pad1*2)*scale_h

        images_prev = images_cur
        images_cur = []

        for i in range(0,num_samples,1):
            if n == 0:
                # First scale: one noise channel expanded to 3 channels.
                z_curr = functions.generate_noise([1,nzx,nzy], device=opt.device)
                z_curr = z_curr.expand(1,3,z_curr.shape[2],z_curr.shape[3])
                z_curr = m(z_curr)
            else:
                z_curr = functions.generate_noise([opt.nc_z,nzx,nzy], device=opt.device)
                z_curr = m(z_curr)

            if images_prev == []:
                # No previous scale yet: start from the padded in_s.
                I_prev = m(in_s)
                #I_prev = m(I_prev)
                #I_prev = I_prev[:,:,0:z_curr.shape[2],0:z_curr.shape[3]]
                #I_prev = functions.upsampling(I_prev,z_curr.shape[2],z_curr.shape[3])
            else:
                # Resize sample i of the previous scale by 1/opt.scale_factor.
                I_prev = images_prev[i]
                I_prev = imresize(I_prev,1/opt.scale_factor, opt)
                if opt.mode != "SR":
                    # Crop to the scaled size of reals[n], pad, then crop and
                    # upsample so it matches the noise map exactly.
                    I_prev = I_prev[:, :, 0:round(scale_v * reals[n].shape[2]), 0:round(scale_h * reals[n].shape[3])]
                    I_prev = m(I_prev)
                    I_prev = I_prev[:,:,0:z_curr.shape[2],0:z_curr.shape[3]]
                    I_prev = functions.upsampling(I_prev,z_curr.shape[2],z_curr.shape[3])
                else:
                    I_prev = m(I_prev)

            if n < gen_start_scale:
                # Below gen_start_scale: use the stored Z_opt instead of the
                # noise drawn above.
                z_curr = Z_opt

            z_in = noise_amp*(z_curr)+I_prev
            I_curr = G(z_in.detach(),I_prev)

            if opt.mode == 'train':
                # input_name[:-4] drops the last 4 characters (presumably the
                # file extension).
                dir2save = '%s/RandomSamples/%s/gen_start_scale=%d' % (opt.out, opt.input_name[:-4], gen_start_scale)
            else:
                dir2save = functions.generate_dir2save(opt)
            try:
                os.makedirs(dir2save)
            except OSError:
                pass
            if (opt.mode != "harmonization") & (opt.mode != "editing") & (opt.mode != "SR") & (opt.mode != "paint2image"):
                plt.imsave('%s/%d.png' % (dir2save, i), functions.convert_image_np(I_curr.detach()), vmin=0,vmax=1)
                #plt.imsave('%s/%d_%d.png' % (dir2save,i,n),functions.convert_image_np(I_curr.detach()), vmin=0, vmax=1)
                #plt.imsave('%s/in_s.png' % (dir2save), functions.convert_image_np(in_s), vmin=0,vmax=1)
            images_cur.append(I_curr)
        n+=1
    return I_curr.detach()
Ejemplo n.º 5
0
def SinGAN_SR(opt, Gs, Zs, reals, NoiseAmp):
    """Super-resolve the input image with a single-scale 'SR_train' SinGAN.

    If the 'SR_train' model directory exists, the model is loaded with
    ``functions.load_trained_pyramid``; otherwise it is trained with
    ``SR_train`` (``opt.stop_scale = 0``). The last generator ``Gs[-1]`` is
    then applied ``iter_num`` times, each step resizing by
    ``1 / opt.scale_factor``. The result is saved to
    ``<generate_dir2save(opt)>.png``.

    Side effects: overwrites ``opt.scale_factor``, ``opt.scale_factor_init``
    and ``opt.stop_scale``. ``opt.mode`` is switched to 'SR_train' while
    loading/training and restored afterwards, even on error.
    """
    mode = opt.mode
    in_scale, iter_num = functions.calc_init_scale(opt)
    opt.scale_factor = 1 / in_scale
    opt.scale_factor_init = 1 / in_scale
    opt.mode = 'SR_train'
    opt.stop_scale = 0
    try:
        dir2trained_model = functions.generate_dir2save(opt)
        if os.path.exists(dir2trained_model):
            Gs, Zs, reals, NoiseAmp = functions.load_trained_pyramid(opt)
        else:
            SR_train(opt, Gs, Zs, reals, NoiseAmp)
    finally:
        opt.mode = mode
    print('%f' % pow(in_scale, iter_num))

    # Pad the placeholder noise maps by the same per-side amount that
    # SinGAN_generate strips ((ker_size - 1) * num_layer / 2). This was
    # hard-coded to 5, which is only right when that expression equals 5.
    pad = nn.ZeroPad2d(int(((opt.ker_size - 1) * opt.num_layer) / 2))
    Zs_sr = []
    reals_sr = []
    NoiseAmp_sr = []
    Gs_sr = []
    real = reals[-1]  #read_image(opt)
    for j in range(1, iter_num + 1):
        # Cumulative resize factor after j steps.
        up = pow(1 / opt.scale_factor, j)
        real_ = imresize(real, up, opt)
        real_ = real_[:, :,
                      0:int(up * real.shape[2]),
                      0:int(up * real.shape[3])]
        reals_sr.append(real_)
        # The same generator and noise amplitude are reused at every step.
        Gs_sr.append(Gs[-1])
        NoiseAmp_sr.append(NoiseAmp[-1])
        # Zero placeholder. With gen_start_scale left at 0, SinGAN_generate
        # uses only its shape to size the random noise.
        z_opt = torch.full(real_.shape, 0, device=opt.device)
        Zs_sr.append(pad(z_opt))
    out = SinGAN_generate(Gs_sr,
                          Zs_sr,
                          reals_sr,
                          NoiseAmp_sr,
                          opt,
                          in_s=reals_sr[0],
                          num_samples=1)
    dir2save = functions.generate_dir2save(opt)
    plt.imsave('%s.png' % (dir2save),
               functions.convert_image_np(out.detach()),
               vmin=0,
               vmax=1)
    return
Ejemplo n.º 6
0
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train the SinGAN pyramid scale by scale, coarsest first.

    The input image is resized by ``opt.scale1`` and turned into a pyramid
    with ``functions.creat_reals_pyramid``. For each scale
    ``0 .. opt.stop_scale`` this function:

    - sets the channel widths;
    - creates the scale folder ``opt.outf`` and saves the scale's real image;
    - builds G/D, loading the previous scale's saved weights when ``opt.nfc``
      is unchanged;
    - trains with ``train_single_scale`` and freezes both networks;
    - appends to ``Gs``/``Zs``/``NoiseAmp`` in place and re-saves the
      pyramid to ``opt.out_``.
    """
    source = functions.read_image(opt)
    pyramid = functions.creat_reals_pyramid(imresize(source, opt.scale1, opt),
                                            reals, opt)

    state = 0  # threaded through train_single_scale as its in_s argument
    prev_width = 0
    level = 0
    while level < opt.stop_scale + 1:
        # Widths double every 4 scales, capped at 128.
        growth = pow(2, math.floor(level / 4))
        opt.nfc = min(opt.nfc_init * growth, 128)
        opt.min_nfc = min(opt.min_nfc_init * growth, 128)

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, level)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass

        real_img = functions.convert_image_np(pyramid[level])
        plt.imsave('%s/in_scale.png' % (opt.outf), real_img, vmin=0, vmax=1)

        netD, netG = init_models(opt)
        warm_start = prev_width == opt.nfc
        if warm_start:
            netG.load_state_dict(
                torch.load('%s/%d/netG.pth' % (opt.out_, level - 1)))
            netD.load_state_dict(
                torch.load('%s/%d/netD.pth' % (opt.out_, level - 1)))

        z_level, state, netG = train_single_scale(netD, netG, pyramid, Gs, Zs,
                                                  state, NoiseAmp, opt)

        netG = functions.reset_grads(netG, False)
        netG.eval()
        netD = functions.reset_grads(netD, False)
        netD.eval()

        Gs.append(netG)
        Zs.append(z_level)
        NoiseAmp.append(opt.noise_amp)

        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(pyramid, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

        prev_width = opt.nfc
        level += 1
        del netD, netG
    return
Ejemplo n.º 7
0
def train(opt,Gs,Zs,reals,NoiseAmp):
    """Train every scale from the coarsest (0) up to and including opt.stop_scale.

    In this variant, the image from ``functions.read_image`` goes straight to
    ``creat_reals_pyramid`` (no resize by ``opt.scale1``). Per scale this
    function:

    - sets channel widths and creates ``opt.outf``;
    - saves the scale's real image;
    - builds G/D, loading the previous scale's saved weights when ``opt.nfc``
      is unchanged;
    - trains, then freezes both networks;
    - appends to ``Gs``/``Zs``/``NoiseAmp`` and checkpoints the pyramid;
    - empties the CUDA cache.
    """
    pyramid = functions.creat_reals_pyramid(functions.read_image(opt), reals, opt)

    def freeze(net):
        # Disable gradients and switch the module to eval mode.
        net = functions.reset_grads(net, False)
        net.eval()
        return net

    noise_in = 0  # in_s value passed to / returned by train_single_scale
    last_nfc = 0
    level = 0
    while level < opt.stop_scale + 1:
        # nfc (conv block output channels) doubles every 4 scales, max 128.
        factor = pow(2, math.floor(level / 4))
        opt.nfc = min(opt.nfc_init * factor, 128)
        opt.min_nfc = min(opt.min_nfc_init * factor, 128)

        # out_: model directory; outf: per-scale subfolder.
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, level)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass

        plt.imsave('%s/real_scale.png' %  (opt.outf),
                   functions.convert_image_np(pyramid[level]),
                   vmin=0, vmax=1)

        D_level, G_level = init_models(opt)
        # Same nfc as the previous scale: start from that scale's saved weights.
        if last_nfc == opt.nfc:
            G_level.load_state_dict(torch.load('%s/%d/netG.pth' % (opt.out_, level - 1)))
            D_level.load_state_dict(torch.load('%s/%d/netD.pth' % (opt.out_, level - 1)))

        z_level, noise_in, G_level = train_single_scale(
            D_level, G_level, pyramid, Gs, Zs, noise_in, NoiseAmp, opt)

        G_level = freeze(G_level)
        D_level = freeze(D_level)

        Gs.append(G_level)
        Zs.append(z_level)
        NoiseAmp.append(opt.noise_amp)

        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(pyramid, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

        level += 1
        last_nfc = opt.nfc
        del D_level, G_level
        torch.cuda.empty_cache()
    return
Ejemplo n.º 8
0
def train(opt, Gs, Zs, reals, NoiseAmp):
    """TensorFlow/Keras variant of the scale-by-scale training loop.

    A single pair of Adam optimizers is created up front and shared by all
    scales: ``opt.lr_d`` for D and ``opt.lr_g`` for G, both with
    ``beta_1=opt.beta1`` and ``beta_2=0.999``. For each scale
    ``0 .. opt.stop_scale`` this function:

    - sets channel widths and saves the scale's real image;
    - builds D/G, loading the previous scale's weights when ``opt.nfc`` is
      unchanged;
    - trains with ``train_single_scale``;
    - appends to ``Gs``/``Zs``/``NoiseAmp`` and pickles ``Zs``, ``reals`` and
      ``NoiseAmp`` to ``opt.out_``.

    Generator weights are not written here; presumably ``train_single_scale``
    saves the ``netG``/``netD`` weights that the next scale loads.

    Returns:
        None
    """
    source = functions.read_image(opt)
    pyramid = functions.creat_reals_pyramid(imresize(source, opt.scale1, opt),
                                            reals, opt)

    def adam(lr):
        return tf.keras.optimizers.Adam(learning_rate=lr,
                                        beta_1=opt.beta1,
                                        beta_2=0.999)

    d_optimizer = adam(opt.lr_d)
    g_optimizer = adam(opt.lr_g)

    state = 0  # in_s value threaded through train_single_scale
    prev_nfc = 0
    level = 0
    while level < opt.stop_scale + 1:
        growth = pow(2, math.floor(level / 4))
        opt.nfc = min(opt.nfc_init * growth, 128)
        opt.min_nfc = min(opt.min_nfc_init * growth, 128)
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, level)

        try:
            os.makedirs(opt.outf)
        except OSError:
            pass

        plt.imsave('%s/real_scale.png' % (opt.outf),
                   functions.convert_image_np(pyramid[level]),
                   vmin=0,
                   vmax=1)
        netD, netG = init_models(opt)
        if prev_nfc == opt.nfc:
            netD.load_weights('%s/%d/netD' % (opt.out_, level - 1))
            netG.load_weights('%s/%d/netG' % (opt.out_, level - 1))

        z_level, state, netG = train_single_scale(netD, netG, pyramid, Gs,
                                                  Zs, state, NoiseAmp, opt,
                                                  level, g_optimizer,
                                                  d_optimizer)

        Gs.append(netG)
        Zs.append(z_level)
        NoiseAmp.append(opt.noise_amp)
        checkpoints = (
            ('%s/Zs.pkl' % (opt.out_), Zs),
            ('%s/reals.pkl' % (opt.out_), pyramid),
            ('%s/NoiseAmp.pkl' % (opt.out_), NoiseAmp),
        )
        for path, payload in checkpoints:
            with open(path, 'wb') as fh:
                pickle.dump(payload, fh)
        level += 1
        prev_nfc = opt.nfc
        del netD, netG
    return None
Ejemplo n.º 9
0
def save_gif(opt, images_cur, alpha, beta, start_scale=None):
    """Write a sequence of same-scale frames as an animated GIF (10 fps).

    The output file is
    ``<generate_dir2save(opt)>/start_scale=<NN>/alpha=<a>_beta=<b>__.gif``.
    ``NN`` is ``start_scale`` zero-padded to 2 digits; ``alpha`` and ``beta``
    use 2 decimals.

    Args:
        opt: options namespace used to derive the output directory.
        images_cur: list of time-series frames at one scale, in a form
            accepted by ``imageio.mimsave``.
        alpha, beta: numbers used only in the file name.
        start_scale: value for the directory name. When omitted, a
            module-level ``start_scale`` is used if one exists (the old code
            read that name without defining it). NOTE(review): confirm callers.

    Raises:
        NameError: if ``start_scale`` is neither passed nor defined at module
            level (same exception the old code raised).
    """
    if start_scale is None:
        try:
            start_scale = globals()['start_scale']
        except KeyError:
            raise NameError("name 'start_scale' is not defined; "
                            "pass start_scale= to save_gif()") from None
    dir2save = functions.generate_dir2save(opt)
    # ':02d' (zero-padded, width 2); the old ':.2d' is rejected for ints
    # ("Precision not allowed in integer format specifier").
    save_dir = os.path.join(f'{dir2save}', f'start_scale={start_scale:02d}')
    os.makedirs(save_dir, exist_ok=True)
    gif_path = os.path.join(save_dir,
                            f'alpha={alpha:.2f}_beta={beta:.2f}__.gif')
    imageio.mimsave(gif_path, images_cur, fps=10)
Ejemplo n.º 10
0
def main(opt):
    """Train a SinGAN on the image described by ``opt`` and sample from it.

    If the output directory for ``opt`` already exists, the model is treated
    as trained: an info message is logged and nothing else happens.
    """
    dir2save = functions.generate_dir2save(opt)
    if os.path.exists(dir2save):
        logger.info("Trained model directory already exists")
        return

    try:
        os.makedirs(dir2save)
    except OSError:
        pass

    image = functions.read_image(opt)
    functions.adjust_scales2image(image, opt)
    Gs, Zs, reals, NoiseAmp = [], [], [], []
    train(opt, Gs, Zs, reals, NoiseAmp)
    SinGAN_generate(Gs, Zs, reals, NoiseAmp, opt)
Ejemplo n.º 11
0
def main(opt, generate=True):
    """Train the two-image SinGAN variant and optionally sample from it.

    If the output directory for ``opt`` already exists, only a message is
    printed. Otherwise this function:

    - creates the directory and configures a logger for it;
    - dumps ``opt`` (all values stringified) to ``config.json``;
    - reads ``opt.input_name1`` and ``opt.input_name2`` and trains;
    - generates samples when ``generate`` is True.

    Any failure after the logger is configured is logged and re-raised, and
    the logger is always cleaned up.

    Args:
        opt: options namespace; must provide ``input_name1``/``input_name2``.
        generate: whether to call ``SinGAN_generate`` after training.
    """
    Gs = []
    Zs: List[Tuple] = []
    reals1 = []
    reals2 = []
    NoiseAmp = []
    dir2save = functions.generate_dir2save(opt)

    if os.path.exists(dir2save):
        print('trained model already exist')
        return

    try:
        os.makedirs(dir2save)
    except OSError:
        pass
    _configure_logger(dir2save)

    try:
        # Inside the try so the logger is cleaned up even if the dump fails.
        with open(os.path.join(f"{dir2save}", "config.json"), "w") as fp:
            config_dict = {k: str(v) for k, v in opt.__dict__.items()}
            json.dump(config_dict, fp)

        real1 = functions.read_image(opt, image_name=opt.input_name1)
        real2 = functions.read_image(opt, image_name=opt.input_name2)
        functions.adjust_scales2image(real1, opt)
        functions.adjust_scales2image(real2, opt)
        train(opt, Gs, Zs, reals1, reals2, NoiseAmp)
        logger.info("Done training")
        if generate:
            logger.info("Generating random samples")
            SinGAN_generate(Gs, Zs, reals1, reals2, NoiseAmp, opt)
    except Exception:
        logger.exception("Failed")
        raise
    finally:
        logger.info("Cleaning logger")
        _cleanup_logger()
Ejemplo n.º 12
0
def train(opt,Gs,Zs,reals1, reals2,NoiseAmp):
    """Jointly train the pyramid on two input images.

    Both pyramids come from ``get_reals``. For each scale
    ``0 .. opt.stop_scale`` this function:

    - builds one generator plus three discriminators (main, mask1, mask2),
      loading all four from the previous scale when ``opt.nfc`` is unchanged;
    - trains them with ``train_single_scale`` and freezes them;
    - appends ``(z1, z2)`` to ``Zs`` and ``(noise_amp1, noise_amp2)`` to
      ``NoiseAmp``;
    - checkpoints everything to ``opt.out_``.
    """
    logger.info("Starting to train...")

    pyramid1 = get_reals(reals1, opt, opt.input_name1)
    pyramid2 = get_reals(reals2, opt, opt.input_name2)

    def frozen(net):
        # Disable gradients and switch the module to eval mode.
        net = functions.reset_grads(net, False)
        net.eval()
        return net

    in_s1, in_s2 = 0, 0
    last_nfc = 0
    level = 0
    while level < opt.stop_scale + 1:
        factor = pow(2, math.floor(level / 4))
        opt.nfc = min(opt.nfc_init * factor, 128)
        opt.min_nfc = min(opt.min_nfc_init * factor, 128)

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, level)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass

        plt.imsave('%s/real_scale1.png' %  (opt.outf), functions.convert_image_np(pyramid1[level]), vmin=0, vmax=1)
        plt.imsave('%s/real_scale2.png' % (opt.outf), functions.convert_image_np(pyramid2[level]), vmin=0, vmax=1)

        # NOTE(review): sized from reals1[len(Gs)], which equals
        # reals1[level] only if Gs starts out empty.
        netD, netD_mask1, netD_mask2, netG = init_models(opt, pyramid1[len(Gs)].shape)
        if last_nfc == opt.nfc:
            prev = level - 1
            netG.load_state_dict(torch.load('%s/%d/netG.pth' % (opt.out_, prev)))
            netD.load_state_dict(torch.load('%s/%d/netD.pth' % (opt.out_, prev)))
            netD_mask1.load_state_dict(torch.load('%s/%d/netD_mask1.pth' % (opt.out_, prev)))
            netD_mask2.load_state_dict(torch.load('%s/%d/netD_mask2.pth' % (opt.out_, prev)))

        logger.info(f"Starting to train scale {level}")
        z_pair, (in_s1, in_s2), netG = train_single_scale(
            netD, netD_mask1, netD_mask2, netG, pyramid1, pyramid2,
            Gs, Zs, in_s1, in_s2, NoiseAmp, opt)
        logger.info(f"Done training scale {level}")

        netG, netD, netD_mask1, netD_mask2 = [
            frozen(net) for net in (netG, netD, netD_mask1, netD_mask2)]

        Gs.append(netG)
        Zs.append(z_pair)
        NoiseAmp.append((opt.noise_amp1, opt.noise_amp2))

        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(pyramid1, '%s/reals1.pth' % (opt.out_))
        torch.save(pyramid2, '%s/reals2.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

        level += 1
        last_nfc = opt.nfc
        del netD, netD_mask1, netD_mask2, netG
    return
Ejemplo n.º 13
0
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Scale-by-scale training loop with optional inpainting masks.

    This is the plain loop: resize by ``opt.scale1``, build the pyramid, then
    train/freeze/checkpoint each scale. When ``opt.inpainting`` is set, the
    mask ``<opt.ref_dir>/<stem>_mask<ext>`` is read, resized by
    ``opt.scale1`` and turned into a pyramid in the same way as the image.
    That pyramid is stored on ``opt.m_s``; per the original comment, it lets
    training use only the valid pixels at every scale.
    """
    source = functions.read_image(opt)
    pyramid = functions.creat_reals_pyramid(imresize(source, opt.scale1, opt),
                                            reals, opt)

    if opt.inpainting:
        stem, ext = opt.input_name[:-4], opt.input_name[-4:]
        mask_path = '%s/%s_mask%s' % (opt.ref_dir, stem, ext)
        mask = imresize(functions.read_image_dir(mask_path, opt), opt.scale1, opt)
        opt.m_s = functions.creat_reals_pyramid(mask, [], opt)

    carried_in_s = 0
    previous_nfc = 0
    level = 0
    while level < opt.stop_scale + 1:
        multiplier = pow(2, math.floor(level / 4))
        opt.nfc = min(opt.nfc_init * multiplier, 128)
        opt.min_nfc = min(opt.min_nfc_init * multiplier, 128)

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, level)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass

        plt.imsave('%s/real_scale.png' % (opt.outf),
                   functions.convert_image_np(pyramid[level]),
                   vmin=0,
                   vmax=1)

        discriminator, generator = init_models(opt)
        if previous_nfc == opt.nfc:
            g_path = '%s/%d/netG.pth' % (opt.out_, level - 1)
            generator.load_state_dict(torch.load(g_path))
            d_path = '%s/%d/netD.pth' % (opt.out_, level - 1)
            discriminator.load_state_dict(torch.load(d_path))

        z_level, carried_in_s, generator = train_single_scale(
            discriminator, generator, pyramid, Gs, Zs, carried_in_s, NoiseAmp, opt)

        generator = functions.reset_grads(generator, False)
        generator.eval()
        discriminator = functions.reset_grads(discriminator, False)
        discriminator.eval()

        Gs.append(generator)
        Zs.append(z_level)
        NoiseAmp.append(opt.noise_amp)

        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(pyramid, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

        level += 1
        previous_nfc = opt.nfc
        del discriminator, generator
    return
Ejemplo n.º 14
0
def SinGAN_generate(Gs,
                    Zs,
                    reals1,
                    reals2,
                    NoiseAmp,
                    opt,
                    in_s1=None,
                    in_s2=None,
                    scale_v=1,
                    scale_h=1,
                    n=0,
                    gen_start_scale=0,
                    num_samples=100):
    """Sample from a two-image SinGAN pyramid, coarse to fine.

    Each of the ``num_samples`` samples gets one noise mode up front
    (``NoiseMode.Z1`` or ``NoiseMode.Z2``, probability 1/2 each) and keeps it
    at every scale. The mode is passed to the noise generator and selects
    which of the scale's two noise amplitudes is used. Only last-scale samples
    (``n == len(reals1) - 1``) are written, as ``<i>_<mode name>.png``, and
    not in the harmonization, editing, SR or paint2image modes.

    Args:
        Gs, Zs, NoiseAmp: per-scale generators, ``(Z_opt1, Z_opt2)`` pairs and
            ``(noise_amp1, noise_amp2)`` pairs, consumed in order.
        reals1: per-scale images of the first input; used for cropping and to
            detect the last scale. ``reals2`` is not used here.
        in_s1, in_s2: currently ignored; generation always starts from zeros
            shaped like ``reals1[0]``.
        scale_v, scale_h: vertical/horizontal size multipliers.
        n: index of the first scale in ``reals1``; incremented per scale.
        gen_start_scale: scales with ``n`` below this use the stored noise
            maps (merged per mode) instead of fresh noise.
        num_samples: number of samples per scale.

    Returns:
        The last sample of the last processed scale, detached.
    """
    #if torch.is_tensor(in_s) == False:
    # if in_s is None:
    in_s = torch.full(reals1[0].shape, 0, device=opt.device)
    images_cur = []

    # assert len(reals1) == len(Gs)

    def random_noise_mode():
        prob = torch.rand(1)
        if prob < 0.5:
            noise_mode = NoiseMode.Z1
        else:
            noise_mode = NoiseMode.Z2
        return noise_mode

    noise_modes = [random_noise_mode() for _ in range(num_samples)]

    for G, (Z_opt1, Z_opt2), (noise_amp1, noise_amp2) in zip(Gs, Zs, NoiseAmp):
        pad1 = ((opt.ker_size - 1) * opt.num_layer) / 2
        m = nn.ZeroPad2d(int(pad1))
        # assumption: Z_opt1 and Z_opt2 have the same size
        nzx = (Z_opt1.shape[2] - pad1 * 2) * scale_v
        nzy = (Z_opt1.shape[3] - pad1 * 2) * scale_h

        images_prev = images_cur
        images_cur = []

        for i in range(0, num_samples, 1):
            z_curr = _generate_noise_for_sampling(m, n, nzx, nzy, opt,
                                                  noise_modes[i])

            if images_prev == []:
                I_prev = m(in_s)
            else:
                # Resize sample i of the previous scale, then crop and pad it
                # so it matches the noise map of this scale.
                I_prev = images_prev[i]
                I_prev = imresize(I_prev, 1 / opt.scale_factor, opt)
                I_prev = I_prev[:, :, 0:round(scale_v * reals1[n].shape[2]),
                                0:round(scale_h * reals1[n].shape[3])]
                I_prev = m(I_prev)
                I_prev = I_prev[:, :, 0:z_curr.shape[2], 0:z_curr.shape[3]]
                I_prev = functions.upsampling(I_prev, z_curr.shape[2],
                                              z_curr.shape[3])

            if n < gen_start_scale:
                # zeros_like keeps Z_opt1's device and dtype. The former
                # torch.zeros(shape) always lived on the CPU, so it could not
                # be combined with noise maps on a CUDA device.
                zero = torch.zeros_like(Z_opt1)
                if noise_modes[i] == NoiseMode.Z1:
                    z_curr = functions.merge_noise_vectors(
                        Z_opt1, zero, opt.noise_vectors_merge_method)
                elif noise_modes[i] == NoiseMode.Z2:
                    z_curr = functions.merge_noise_vectors(
                        zero, Z_opt2, opt.noise_vectors_merge_method)
                else:
                    # Not reached with random_noise_mode (only Z1/Z2).
                    z_curr = functions.merge_noise_vectors(
                        Z_opt1, Z_opt2, opt.noise_vectors_merge_method)

            noise_amp = noise_amp1 if noise_modes[
                i] == NoiseMode.Z1 else noise_amp2
            z_in = noise_amp * (z_curr) + I_prev
            I_curr = G(z_in.detach(), I_prev)[0]

            if n == len(reals1) - 1:
                if opt.mode == 'train':
                    dir2save = '%s/RandomSamples/%s/gen_start_scale=%d' % (
                        opt.out, opt.exp_name, gen_start_scale)
                else:
                    dir2save = functions.generate_dir2save(opt)
                try:
                    os.makedirs(dir2save)
                except OSError:
                    pass
                if (opt.mode != "harmonization") & (opt.mode != "editing") & (
                        opt.mode != "SR") & (opt.mode != "paint2image"):
                    print(f"Saving image: {i}")
                    plt.imsave(f'%s/%d_{noise_modes[i].name}.png' %
                               (dir2save, i),
                               functions.convert_image_np(I_curr.detach()),
                               vmin=0,
                               vmax=1)
            images_cur.append(I_curr)
        print(f"Done Generating level: {n}")
        n += 1
    return I_curr.detach()
Ejemplo n.º 15
0
def train(opt,Gs,Zs,reals,NoiseAmp):
    """Train the pyramid scale by scale, pruning each generator after training.

    For every scale this function:

    - trains G/D (only ``warmup_steps`` iterations when ``fine_tune`` is set,
      otherwise ``opt.niter``) and freezes G;
    - prunes G. Unstructured mode uses global L1 pruning over the conv/norm
      weights and biases. Structured mode uses L1-norm pruning of the conv
      weights along dim 0;
    - saves the raw pruned state dict and, from scale 1 on, renders a preview
      with ``prune_SinGAN_generate``;
    - optionally fine-tunes G;
    - makes the pruning permanent with ``prune.remove`` and checkpoints the
      pyramid (generators go to ``pruned_Gs.pth``).

    NOTE(review): ``fine_tune``, ``warmup_steps``, ``structured``,
    ``pruning_amount`` and ``prune_SinGAN_generate`` are not defined here;
    presumably module-level globals. Confirm.
    NOTE(review): ``visualize_weights`` calls ``.cuda()``, so this needs CUDA.
    """
    real_ = functions.read_image(opt)
    in_s = 0
    # cur_scale_level: current level, from coarsest (0) to finest.
    cur_scale_level = 0
    # The image is turned into a pyramid as read (no resize by opt.scale1 here).
    reals = functions.creat_reals_pyramid(real_,reals,opt)
    nfc_prev = 0

    # Train every scale up to and including opt.stop_scale.
    while cur_scale_level < opt.stop_scale+1:
        # nfc: number of output channels per conv block; doubles every 4
        # scales, capped at 128.
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(cur_scale_level / 4)), 128)
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(cur_scale_level / 4)), 128)

        # out_: output directory of the whole model
        # outf: output folder of this scale (out_/<scale>)
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_,cur_scale_level)
        try:
            os.makedirs(opt.outf)
        except OSError:
                pass

        #plt.imsave('%s/in.png' %  (opt.out_), functions.convert_image_np(real), vmin=0, vmax=1)
        #plt.imsave('%s/original.png' %  (opt.out_), functions.convert_image_np(real_), vmin=0, vmax=1)
        plt.imsave('%s/real_scale.png' %  (opt.outf), functions.convert_image_np(reals[cur_scale_level]), vmin=0, vmax=1)

        D_curr,G_curr = init_models(opt)
        # nfc changes every 4 scales; while it is unchanged, start from the
        # previous scale's saved weights.
        if (nfc_prev==opt.nfc):
            G_curr.load_state_dict(torch.load('%s/%d/netG.pth' % (opt.out_,cur_scale_level-1)))
            D_curr.load_state_dict(torch.load('%s/%d/netD.pth' % (opt.out_,cur_scale_level-1)))

        # in_s starts as 0 and is replaced by what train_single_scale returns
        # (NOTE(review): presumably the coarsest-scale input image; confirm).
        # With fine_tune, only warmup_steps iterations run here; the remaining
        # opt.niter - warmup_steps run after pruning (unstructured case).
        if fine_tune:
          z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt, warmup_steps)
        else:
          z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt, opt.niter)


        # Only G is frozen; D keeps its gradients so it can be reused below.
        G_curr = functions.reset_grads(G_curr,False)
        # D_curr = functions.reset_grads(D_curr,False)
        G_curr.eval()
        # D_curr.eval()

        #################################################################################
        # Visualize weights: flatten and concatenate the `weight` of each
        # module, print the fraction of exact zeros and the element count, and
        # save a 100-bin histogram of the non-zero values to
        # <opt.outf>/<fig_name>.png. Redefined on every loop iteration.
        def visualize_weights(modules, fig_name):
            ori_weights = torch.tensor([]).cuda()
            for m in modules:
                cur_params = m.weight.data.flatten()
                ori_weights = torch.cat((ori_weights, cur_params))
                # cur_params = m.bias.data.flatten()
                # ori_weights = torch.cat((ori_weights, cur_params))
            sparsity = torch.sum(ori_weights == 0) * 1.0 / (ori_weights.nelement())
            print(sparsity, ori_weights.nelement())
            ori_weights = ori_weights.cpu().numpy()
            ori_weights = plt.hist(ori_weights[ori_weights != 0], bins=100)
            plt.savefig("%s/%s.png" % (opt.outf, fig_name))
            plt.close()

        # Pruning weights Structured or Non-structured
        if not structured:
            # Unstructured: conv and norm layers of head/body plus tail[0];
            # weights and biases share one global L1 threshold.
            modules = [G_curr.head.conv, G_curr.head.norm,
                    G_curr.body.block1.conv, G_curr.body.block1.norm,
                    G_curr.body.block2.conv, G_curr.body.block2.norm,
                    G_curr.body.block3.conv, G_curr.body.block3.norm,
                    G_curr.tail[0]]
            parameters_to_prune = (
                (G_curr.head.conv, 'weight'),
                (G_curr.head.norm, 'weight'),
                (G_curr.body.block1.conv, 'weight'),
                (G_curr.body.block1.norm, 'weight'),
                (G_curr.body.block2.conv, 'weight'),
                (G_curr.body.block2.norm, 'weight'),
                (G_curr.body.block3.conv, 'weight'),
                (G_curr.body.block3.norm, 'weight'),
                (G_curr.tail[0], 'weight'),
                (G_curr.head.conv, 'bias'),
                (G_curr.head.norm, 'bias'),
                (G_curr.body.block1.conv, 'bias'),
                (G_curr.body.block1.norm, 'bias'),
                (G_curr.body.block2.conv, 'bias'),
                (G_curr.body.block2.norm, 'bias'),
                (G_curr.body.block3.conv, 'bias'),
                (G_curr.body.block3.norm, 'bias'),
                (G_curr.tail[0], 'bias'),
            )

            visualize_weights(modules, 'ori')

            # Prune weights
            prune.global_unstructured(
                parameters_to_prune,
                pruning_method=prune.L1Unstructured,
                amount=pruning_amount,
            )
        else:
            # Structured: only the head/body conv weights, L1 norm along dim 0
            # (output channels if these are Conv2d; biases are not pruned).
            modules = [G_curr.head.conv,
            G_curr.body.block1.conv,
            G_curr.body.block2.conv,
            G_curr.body.block3.conv]

            visualize_weights(modules, 'ori')
            # pytorch_total_params = sum(p.numel() for p in G_curr.parameters())
            # print(pytorch_total_params)

            for module in modules:
                m = prune.ln_structured(module, name="weight", amount=pruning_amount, n=1, dim=0)
                # m = prune.ln_structured(module, name="bias", amount=pruning_amount, n=1, dim=0)

        torch.save(G_curr.state_dict(), '%s/raw_prune_netG.pth' % (opt.outf))
        visualize_weights(modules, 'raw-prune')
        # From scale 1 on, preview the pyramid with the freshly pruned G using
        # copies, so Gs/Zs/NoiseAmp themselves are not modified here.
        if cur_scale_level > 0:
            fake_Gs = Gs.copy()
            fake_Gs.append(G_curr) 
            fake_Zs = Zs.copy()
            fake_Zs.append(z_curr)
            fake_noise = NoiseAmp.copy()
            fake_noise.append(opt.noise_amp)
            fake_reals = reals[:cur_scale_level+1].copy()
            prune_SinGAN_generate(fake_Gs, fake_Zs, fake_reals, fake_noise, opt, gen_start_scale=0, num_samples=1, level=cur_scale_level)

        # Fine-tuning: re-enable gradients on the pruned G and keep training
        # against the same (unfrozen) D_curr.
        if fine_tune:
            G_curr = functions.reset_grads(G_curr, True)
            G_curr.train()

            if not structured:
                # Keep training using inherited weights
                z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt, opt.niter - warmup_steps, prune=True)
            else:
                # Full opt.niter iterations; re-initialization is commented
                # out, so this continues from the pruned weights.
                # G_curr.apply(models.weights_init)
                # D_curr.apply(models.weights_init)
                z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt, opt.niter, prune=True)
        G_curr = functions.reset_grads(G_curr,False)
        G_curr.eval()
        visualize_weights(modules, 'fine-tune')

        # Make the pruning permanent (drop the pruning reparametrization).
        # Biases were pruned only in the unstructured case.
        for m in modules:
            prune.remove(m, 'weight')
            if not structured:
              prune.remove(m, 'bias')
        
        # pytorch_total_params = sum(p.numel() for p in G_curr.parameters())
        # print(pytorch_total_params)

        #################################################################################
        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/pruned_Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

        cur_scale_level+=1
        nfc_prev = opt.nfc
        del D_curr,G_curr
    return
Ejemplo n.º 16
0
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train the SinGAN pyramid scale by scale, recording memory use and wall time.

    For every scale ``0..opt.stop_scale`` a fresh D/G pair is built (warm-started
    from the previous scale's saved weights when ``opt.nfc`` is unchanged),
    trained with ``train_single_scale`` (which here also returns ``mbs`` and
    ``percent``), frozen, appended to ``Gs``/``Zs``/``NoiseAmp`` and checkpointed
    under ``opt.out_``. The per-scale ``[mbs, percent]`` pairs and training
    durations are printed at the end.
    """
    real_ = functions.read_image(opt)
    real = imresize(real_, opt.scale1, opt)
    reals = functions.creat_reals_pyramid(real, reals, opt)
    in_s = 0
    nfc_prev = 0

    memory = []     # [mbs, percent] per scale, as returned by train_single_scale
    durations = []  # datetime.timedelta spent in train_single_scale per scale

    scale_num = 0
    while scale_num < opt.stop_scale + 1:
        # Double the filter count every 4 scales, capped at 128.
        growth = pow(2, math.floor(scale_num / 4))
        opt.nfc = min(opt.nfc_init * growth, 128)
        opt.min_nfc = min(opt.min_nfc_init * growth, 128)

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale_num)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass

        plt.imsave('%s/real_scale.png' % (opt.outf),
                   functions.convert_image_np(reals[scale_num]),
                   vmin=0,
                   vmax=1)

        D_curr, G_curr = init_models(opt)
        if nfc_prev == opt.nfc:
            prev_dir = '%s/%d' % (opt.out_, scale_num - 1)
            G_curr.load_state_dict(torch.load(prev_dir + '/netG.pth'))
            D_curr.load_state_dict(torch.load(prev_dir + '/netD.pth'))

        started = datetime.datetime.now()
        z_curr, in_s, G_curr, mbs, percent = train_single_scale(
            D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt)
        memory.append([mbs, percent])
        elapsed = datetime.datetime.now() - started
        durations.append(elapsed)
        print(f'time: {elapsed}')

        # Freeze both networks for inference.
        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))
        # NOTE(review): full_memory / full_time are not defined in this function;
        # presumably module-level globals filled elsewhere — confirm, otherwise
        # these lines raise NameError.
        torch.save(full_memory, '%s/full_memory.pth' % (opt.out_))
        torch.save(full_time, '%s/full_time.pth' % (opt.out_))

        scale_num += 1
        nfc_prev = opt.nfc
        del D_curr, G_curr

    print(memory)
    print(durations)
    print(full_memory)
    print(full_time)
    return
Ejemplo n.º 17
0
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train the SinGAN pyramid one scale at a time, starting at scale 0.

    For each ``scale_num`` in ``0..opt.stop_scale`` this builds a fresh D/G pair
    (warm-started from the previous scale's saved ``netG.pth``/``netD.pth`` when
    the channel count is unchanged), trains it with ``train_single_scale``,
    freezes both networks and checkpoints the accumulated lists.

    Args:
        opt: options namespace, mutated in place: ``nfc``, ``min_nfc``,
            ``out_``, ``outf`` and, when ``opt.fast_training`` is set,
            ``niter``. ``opt.noise_amp`` is read after each scale (presumably
            set by ``train_single_scale`` — confirm).
        Gs, Zs, NoiseAmp: lists that receive, per scale, the frozen generator,
            the noise map returned by ``train_single_scale`` and
            ``opt.noise_amp``.
        reals: list handed to ``functions.creat_reals_pyramid``; the local name
            is rebound to its return value.
    """
    real_ = functions.read_image(opt)
    in_s = 0
    scale_num = 0  # index into the pyramid
    real = imresize(real_, opt.scale1, opt)
    reals = functions.creat_reals_pyramid(real, reals, opt)
    nfc_prev = 0

    while scale_num < opt.stop_scale + 1:
        # Every 4 scales double the filter count, capped at 128.
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(scale_num / 4)),
                          128)
        # Fast training: every 4 scales halve the iteration count so finer
        # scales train less. opt.niter is halved cumulatively across scales.
        if opt.fast_training and scale_num > 0 and scale_num % 4 == 0:
            opt.niter = opt.niter // 2

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale_num)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass  # directory already exists

        plt.imsave('%s/real_scale.png' % (opt.outf),
                   functions.convert_image_np(reals[scale_num]),
                   vmin=0,
                   vmax=1)

        D_curr, G_curr = init_models(opt)
        # If the channel count matches the previous scale, initialise from
        # that scale's saved weights.
        if nfc_prev == opt.nfc:
            G_curr.load_state_dict(
                torch.load('%s/%d/netG.pth' % (opt.out_, scale_num - 1)))
            D_curr.load_state_dict(
                torch.load('%s/%d/netD.pth' % (opt.out_, scale_num - 1)))

        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs,
                                                  Zs, in_s, NoiseAmp, opt)

        # Freeze G (no gradients, eval mode); it is kept for later scales.
        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        # D is frozen too but is not stored in any list.
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        Gs.append(G_curr)
        Zs.append(z_curr)  # noise map returned for this scale
        NoiseAmp.append(opt.noise_amp)

        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

        scale_num += 1
        nfc_prev = opt.nfc
        del D_curr, G_curr
    return
Ejemplo n.º 18
0
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train one generator/discriminator pair per pyramid scale, starting at scale 0.

    The input image is read, resized by ``opt.scale1`` and expanded into the
    per-scale pyramid. At each scale the networks are built (warm-started from
    the previous scale's checkpoint when ``opt.nfc`` did not change), trained,
    frozen, and the growing ``Gs``/``Zs``/``reals``/``NoiseAmp`` lists are
    checkpointed under ``opt.out_``.
    """
    # Per the original author's note, read_image returns a normalized tensor
    # with values in [-1, 1] (not verifiable from this block).
    real_ = functions.read_image(opt)
    # Resize, then build the list of real images at every scale.
    pyramid = functions.creat_reals_pyramid(imresize(real_, opt.scale1, opt),
                                            reals, opt)
    in_s = 0
    nfc_prev = 0

    scale_num = 0
    while scale_num < opt.stop_scale + 1:
        channel_mult = pow(2, math.floor(scale_num / 4))
        opt.nfc = min(opt.nfc_init * channel_mult, 128)
        opt.min_nfc = min(opt.min_nfc_init * channel_mult, 128)

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale_num)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass

        plt.imsave('%s/real_scale.png' % (opt.outf),
                   functions.convert_image_np(pyramid[scale_num]),
                   vmin=0,
                   vmax=1)

        D_curr, G_curr = init_models(opt)
        if nfc_prev == opt.nfc:
            prev_dir = '%s/%d' % (opt.out_, scale_num - 1)
            G_curr.load_state_dict(torch.load(prev_dir + '/netG.pth'))
            D_curr.load_state_dict(torch.load(prev_dir + '/netD.pth'))

        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, pyramid, Gs,
                                                  Zs, in_s, NoiseAmp, opt)

        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        Gs.append(G_curr)
        Zs.append(z_curr)  # noise map for this scale
        NoiseAmp.append(opt.noise_amp)

        for payload, stem in ((Zs, 'Zs'), (Gs, 'Gs'), (pyramid, 'reals'),
                              (NoiseAmp, 'NoiseAmp')):
            torch.save(payload, '%s/%s.pth' % (opt.out_, stem))

        scale_num += 1
        nfc_prev = opt.nfc
        del D_curr, G_curr
    return
Ejemplo n.º 19
0
def generate_gif(Gs,Zs,reals,NoiseAmp,opt,alpha=0.1,beta=0.9,start_scale=2,fps=10,num_frames=100):
    """Render an animated GIF by random-walking the noise maps around ``Z_opt``.

    At every scale a sequence of ``num_frames`` noise maps is produced by a
    momentum-style walk:
    ``z = alpha*Z_opt + (1-alpha)*(z_prev1 + beta*(z_prev1-z_prev2) + (1-beta)*noise)``.
    Frame ``i`` at a scale is generated from frame ``i`` of the previous scale,
    resized by ``1/opt.scale_factor``, cropped to ``real`` and zero-padded.
    Scales with index below ``start_scale`` use ``Z_opt`` unchanged.

    Args:
        Gs, Zs, NoiseAmp, reals: per-scale generators, reconstruction noise maps,
            noise amplitudes and real images, iterated in lockstep.
        opt: options namespace (device, ker_size, num_layer, nc_z, scale_factor).
        alpha: pull of each step towards ``Z_opt``.
        beta: weight of the previous step's difference.
        start_scale: first scale at which the noise walk is applied.
        fps: frame rate passed to ``imageio.mimsave``.
        num_frames: number of frames per scale / in the GIF (formerly fixed at 100).

    Writes ``<dir2save>/start_scale=<s>/alpha=<a>_beta=<b>.gif``.
    """
    in_s = torch.full(Zs[0].shape, 0, device=opt.device)
    images_cur = []

    for count, (G, Z_opt, noise_amp, real) in enumerate(zip(Gs, Zs, NoiseAmp, reals)):
        pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
        nzx = Z_opt.shape[2]
        nzy = Z_opt.shape[3]
        m_image = nn.ZeroPad2d(int(pad_image))
        images_prev = images_cur
        images_cur = []

        def draw_noise():
            # The coarsest scale uses single-channel noise replicated over 3
            # channels; the others draw opt.nc_z channels.
            if count == 0:
                z_rand = functions.generate_noise([1, nzx, nzy], device=opt.device)
                return z_rand.expand(1, 3, nzx, nzy)
            return functions.generate_noise([opt.nc_z, nzx, nzy], device=opt.device)

        z_prev1 = 0.95*Z_opt + 0.05*draw_noise()
        z_prev2 = Z_opt

        for i in range(num_frames):
            diff_curr = beta*(z_prev1-z_prev2) + (1-beta)*draw_noise()
            z_curr = alpha*Z_opt + (1-alpha)*(z_prev1+diff_curr)
            z_prev2 = z_prev1
            z_prev1 = z_curr

            if not images_prev:
                I_prev = in_s
            else:
                I_prev = images_prev[i]
                I_prev = imresize(I_prev, 1 / opt.scale_factor, opt)
                I_prev = I_prev[:, :, 0:real.shape[2], 0:real.shape[3]]
                I_prev = m_image(I_prev)
            if count < start_scale:
                z_curr = Z_opt

            z_in = noise_amp*z_curr + I_prev
            I_curr = G(z_in.detach(), I_prev)

            # At the last scale convert the frame to a HxWxC uint8 image.
            if count == len(Gs)-1:
                I_curr = functions.denorm(I_curr).detach()
                I_curr = I_curr[0, :, :, :].cpu().numpy()
                I_curr = I_curr.transpose(1, 2, 0)*255
                I_curr = I_curr.astype(np.uint8)

            images_cur.append(I_curr)

    dir2save = functions.generate_dir2save(opt)
    try:
        os.makedirs('%s/start_scale=%d' % (dir2save,start_scale) )
    except OSError:
        pass  # directory already exists
    imageio.mimsave('%s/start_scale=%d/alpha=%f_beta=%f.gif' % (dir2save,start_scale,alpha,beta),images_cur,fps=fps)
    del images_cur
Ejemplo n.º 20
0
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train the SinGAN pyramid, optionally resuming from ``opt.scale_num``.

    When ``opt`` already carries a positive ``scale_num`` (EXPERIMENTAL
    "continue" mode), ``in_s`` starts as zeros shaped like ``reals[0]`` and
    training resumes at that scale. Otherwise ``opt.scale_num`` is reset to 0.
    The scale counter and previous channel count live on ``opt``
    (``opt.scale_num``, ``opt.nfc_prev``) so they survive across calls.
    """
    print('train() current parameters')
    print(opt)
    real_ = functions.read_image(opt)

    resuming = 'scale_num' in opt and opt.scale_num > 0
    if resuming:
        # EXPERIMENTAL: if we are in 'continue' mode
        in_s = torch.full(reals[0].shape, 0, device=opt.device)
    else:
        in_s = 0
        opt.scale_num = 0

    reals = functions.create_reals_pyramid(imresize(real_, opt.scale1, opt),
                                           reals, opt)
    if 'nfc_prev' not in opt:
        opt.nfc_prev = 0

    while opt.scale_num < opt.stop_scale + 1:
        width_mult = pow(2, math.floor(opt.scale_num / 4))
        opt.nfc = min(opt.nfc_init * width_mult, 128)
        opt.min_nfc = min(opt.min_nfc_init * width_mult, 128)

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, opt.scale_num)
        try:
            os.makedirs(opt.outf)
        except OSError:
            print('directory %s already exists' % opt.outf)

        plt.imsave('%s/real_scale.png' % opt.outf,
                   functions.convert_image_np(reals[opt.scale_num]),
                   vmin=0,
                   vmax=1)

        D_curr, G_curr = init_models(opt)
        if opt.nfc_prev == opt.nfc:
            prev_dir = '%s/%d' % (opt.out_, opt.scale_num - 1)
            G_curr.load_state_dict(torch.load(prev_dir + '/netG.pth'))
            D_curr.load_state_dict(torch.load(prev_dir + '/netD.pth'))

        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs,
                                                  Zs, in_s, NoiseAmp, opt)

        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        for payload, stem in ((Zs, 'Zs'), (Gs, 'Gs'), (reals, 'reals'),
                              (NoiseAmp, 'NoiseAmp')):
            torch.save(payload, '%s/%s.pth' % (opt.out_, stem))

        opt.scale_num += 1
        opt.nfc_prev = opt.nfc
        del D_curr, G_curr
    return
Ejemplo n.º 21
0
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train the SinGAN pyramid one scale at a time, saving per-scale previews.

    For each scale ``0..opt.stop_scale`` this builds a D/G pair (warm-started
    from the previous scale's checkpoint when ``opt.nfc`` is unchanged), trains
    it with ``train_single_scale``, freezes it and checkpoints the accumulated
    ``Gs``/``Zs``/``reals``/``NoiseAmp`` lists under ``opt.out_``. The
    full-resolution input is written to ``<opt.out_>/original.png`` and each
    scale's real image to ``<opt.outf>/real_scale.png``.

    ``opt`` is mutated in place (``nfc``, ``min_nfc``, ``out_``, ``outf`` and,
    when ``opt.fast_training`` is set, ``niter`` is halved every 4 scales).
    """
    real_ = functions.read_image(opt)
    in_s = 0
    scale_num = 0
    real = imresize(real_, opt.scale1, opt)

    # List of the real image resized to every scale of the pyramid.
    reals = functions.creat_reals_pyramid(real, reals, opt)

    nfc_prev = 0
    # original.png is the same image at every scale. Write it only when the
    # output root changes, instead of rewriting it on every iteration.
    original_saved_in = None

    # Loops over scales 0..opt.stop_scale inclusive.
    while scale_num < opt.stop_scale + 1:
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        print('opt.nfc', opt.nfc)
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(scale_num / 4)),
                          128)
        print('opt.min_nfc', opt.min_nfc)
        if opt.fast_training and scale_num > 0 and scale_num % 4 == 0:
            opt.niter = opt.niter // 2

        # out_ is the output root; outf is the per-scale sub-directory.
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale_num)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass  # directory already exists

        if original_saved_in != opt.out_:
            plt.imsave('%s/original.png' % (opt.out_),
                       functions.convert_image_np(real_),
                       vmin=0,
                       vmax=1)
            original_saved_in = opt.out_
        plt.imsave('%s/real_scale.png' % (opt.outf),
                   functions.convert_image_np(reals[scale_num]),
                   vmin=0,
                   vmax=1)

        D_curr, G_curr = init_models(opt)
        if nfc_prev == opt.nfc:
            G_curr.load_state_dict(
                torch.load('%s/%d/netG.pth' % (opt.out_, scale_num - 1)))
            D_curr.load_state_dict(
                torch.load('%s/%d/netD.pth' % (opt.out_, scale_num - 1)))

        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs,
                                                  Zs, in_s, NoiseAmp, opt)

        G_curr = functions.reset_grads(G_curr, False)
        print('G_curr', G_curr)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

        scale_num += 1
        nfc_prev = opt.nfc
        del D_curr, G_curr
    return
 # for random_samples_arbitrary_sizes:
 parser.add_argument('--scale_h',
                     type=float,
                     help='horizontal resize factor for random samples',
                     default=1.5)
 parser.add_argument('--scale_v',
                     type=float,
                     help='vertical resize factor for random samples',
                     default=1)
 opt = parser.parse_args()
 opt = functions.post_config(opt)
 Gs = []
 Zs = []
 reals = []
 NoiseAmp = []
 dir2save = functions.generate_dir2save(opt)  #내가 만들어야할 폴더 이름을 반환해줌
 if dir2save is None:  #opt.mode가 잘못된 경우
     print('task does not exist')
 elif (os.path.exists(dir2save)):  # 이미 폴더가 있는 경우
     if opt.mode == 'random_samples':
         print(
             'random samples for image %s, start scale=%d, already exist' %
             (opt.input_name, opt.gen_start_scale))
     elif opt.mode == 'random_samples_arbitrary_sizes':
         print(
             'random samples for image %s at size: scale_h=%f, scale_v=%f, already exist'
             % (opt.input_name, opt.scale_h, opt.scale_v))
 else:
     try:
         os.makedirs(dir2save)  # 폴더를 만들어줌
     except OSError:
Ejemplo n.º 23
0
def SinGAN_generate(Gs,Zs,reals,NoiseAmp,opt,in_s=None,scale_v=1,scale_h=1,n=0,gen_start_scale=0,num_samples=50):
    """Generate samples scale by scale (experimental denoising variant).

    Scales are processed in the order of ``Gs``/``Zs``/``NoiseAmp``. At each
    scale ``num_samples`` images are produced. The generator input is
    ``0*z_curr + I_prev``, so the noise term is zeroed and only the previous
    image drives generation. For scales with ``n < gen_start_scale``,
    ``Z_opt`` is used and ``I_prev`` is replaced by a hard-coded noisy image
    (padded and upsampled to the noise size).

    Args:
        Gs, Zs, NoiseAmp: per-scale generators, fixed noise maps and noise
            amplitudes (``NoiseAmp`` is iterated but unused by this variant).
        reals: per-scale real images; ``reals[0]`` sizes the default ``in_s``
            and ``reals[n]`` bounds the crop of the previous output.
        opt: options namespace (device, ker_size, num_layer, nc_z,
            scale_factor, mode, out, input_name).
        in_s: input to the first scale; defaults to zeros shaped like ``reals[0]``.
        scale_v, scale_h: vertical / horizontal scaling of the noise size.
        n: index of the first processed scale; incremented after each scale.
        gen_start_scale: see above.
        num_samples: samples generated per scale.

    Returns:
        The last generated image, detached.
    """
    if in_s is None:
        in_s = torch.full(reals[0].shape, 0, device=opt.device)
    # Hard-coded noisy image that seeds every scale below gen_start_scale.
    # The raw string has the same value as the original literal, without the
    # invalid-escape warnings for "\M", "\C", "\P", "\S" and "\I".
    noisy_path = r"D:\MVA\CompVision\Project\SinGAN-master\Input\Images/Noisy_Golden_Bridge_by_night.jpg"
    noisy_real = None  # loaded once on first use, not once per sample and scale
    images_cur = []
    for G,Z_opt,noise_amp in zip(Gs,Zs,NoiseAmp):
        pad1 = ((opt.ker_size-1)*opt.num_layer)/2
        m = nn.ZeroPad2d(int(pad1))
        nzx = (Z_opt.shape[2]-pad1*2)*scale_v
        nzy = (Z_opt.shape[3]-pad1*2)*scale_h

        images_prev = images_cur
        images_cur = []

        for i in range(num_samples):
            if n == 0:
                # Scale 0: single-channel noise replicated over 3 channels.
                z_curr = functions.generate_noise([1,nzx,nzy], device=opt.device)
                z_curr = z_curr.expand(1,3,z_curr.shape[2],z_curr.shape[3])
            else:
                z_curr = functions.generate_noise([opt.nc_z,nzx,nzy], device=opt.device)
            z_curr = m(z_curr)

            if not images_prev:
                print("in_s shape before padding with m", in_s.shape)
                I_prev = m(in_s)
                # m() is not in-place, so report the padded tensor (previously
                # this printed in_s.shape, which never changes).
                print("in_s shape after padding with m now I_prev", I_prev.shape)
                I_prev = functions.upsampling(I_prev,z_curr.shape[2],z_curr.shape[3])
                print("I_prev shape after upsampling using noise shape", I_prev.shape)
            else:
                I_prev = images_prev[i]
                I_prev = imresize(I_prev,1/opt.scale_factor, opt)
                if opt.mode != "SR":
                    I_prev = I_prev[:, :, 0:round(scale_v * reals[n].shape[2]), 0:round(scale_h * reals[n].shape[3])]
                    I_prev = m(I_prev)
                    I_prev = I_prev[:,:,0:z_curr.shape[2],0:z_curr.shape[3]]
                    I_prev = functions.upsampling(I_prev,z_curr.shape[2],z_curr.shape[3])
                else:
                    I_prev = m(I_prev)

            if n < gen_start_scale:
                z_curr = Z_opt
                if noisy_real is None:
                    # HxWxC image -> 1xCxHxW tensor on GPU, mapped to [-1, 1],
                    # first 3 channels only.
                    noisy_real = img.imread(noisy_path)
                    noisy_real = noisy_real[:,:,:,None]
                    noisy_real = noisy_real.transpose((3,2,0,1))/255
                    noisy_real = torch.from_numpy(noisy_real)
                    noisy_real = move_to_gpu(noisy_real)
                    noisy_real = noisy_real.type(torch.cuda.FloatTensor)
                    noisy_real = ((noisy_real - 0.5)*2).clamp(-1,1)
                    noisy_real = noisy_real[:,0:3,:,:]
                I_prev = functions.upsampling(m(noisy_real),z_curr.shape[2],z_curr.shape[3])

            print("I_prev",I_prev.shape)
            print("z_curr",z_curr.shape)
            print('---')
            # Noise deliberately zeroed (the commented-out form in the original
            # was noise_amp*z_curr + I_prev).
            z_in = 0*(z_curr)+I_prev
            I_curr = G(z_in.detach(),I_prev)

            if n == len(reals)-1:
                if opt.mode == 'train':
                    dir2save = '%s/RandomSamples/%s/gen_start_scale=%d' % (opt.out, opt.input_name[:-4], gen_start_scale)
                else:
                    dir2save = functions.generate_dir2save(opt)
                try:
                    os.makedirs(dir2save)
                except OSError:
                    pass  # directory already exists
                if opt.mode not in ("harmonization", "editing", "SR", "paint2image"):
                    plt.imsave('%s/%d.png' % (dir2save, i), functions.convert_image_np(I_curr.detach()), vmin=0,vmax=1)
            images_cur.append(I_curr)
        n+=1

    return I_curr.detach()
Ejemplo n.º 24
0
    '--scale_h', "1", '--gen_start_scale', '1', '--scale_factor', '0.75'
])
opt = functions.post_config(opt)
#%%
opt = parser.parse_args([
    '--mode', "random_samples", "--input_name", "mountains.jpg", '--scale_h',
    "1", '--gen_start_scale', '1', '--scale_factor', '0.75'
])
opt = functions.post_config(opt)
dirlab = ",sf_0.75"
for opt.gen_start_scale in range(0, 9):
    Gs = []
    Zs = []
    reals = []
    NoiseAmp = []
    dir2orig = functions.generate_dir2save(opt)
    dir2save = dir2orig + dirlab
    if dir2save is None:
        print('task does not exist')
    elif (os.path.exists(dir2save)):
        if opt.mode == 'random_samples':
            print(
                'random samples for image %s, start scale=%d, already exist' %
                (opt.input_name, opt.gen_start_scale))
        elif opt.mode == 'random_samples_arbitrary_sizes':
            print(
                'random samples for image %s at size: scale_h=%f, scale_v=%f, already exist'
                % (opt.input_name, opt.scale_h, opt.scale_v))
    else:
        try:
            os.makedirs(dir2orig)
                        default='Input/Images')
    parser.add_argument('--input_name',
                        help='training image name',
                        default="33039_LR.png")  #required=True)
    parser.add_argument('--sr_factor',
                        help='super resolution factor',
                        type=float,
                        default=4)
    parser.add_argument('--mode', help='task to be done', default='SR')
    opt = parser.parse_args()
    opt = functions.post_config(opt)
    Gs = []
    Zs = []
    reals = []
    NoiseAmp = []
    dir2save = functions.generate_dir2save(opt)
    if dir2save is None:
        print('task does not exist')
    #elif (os.path.exists(dir2save)):
    #    print("output already exist")
    else:
        try:
            os.makedirs(dir2save)
        except OSError:
            pass

        mode = opt.mode
        in_scale, iter_num = functions.calc_init_scale(opt)
        opt.scale_factor = 1 / in_scale
        opt.scale_factor_init = 1 / in_scale
        opt.mode = 'train'
Ejemplo n.º 26
0
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train the SinGAN pyramid with a notebook progress bar over the scales.

    The input image named by ``opt`` is read, resized by ``opt.scale1`` and
    expanded into a list of per-scale images. For each scale
    ``0..opt.stop_scale`` a D/G pair is built (warm-started from the previous
    scale's saved weights when the channel count is unchanged), trained with
    ``train_single_scale``, frozen, appended to ``Gs``/``Zs``/``NoiseAmp`` and
    checkpointed under ``opt.out_``.
    """
    real_ = functions.read_image(opt)
    in_s = 0
    # opt.scale1 is presumably set by adjust_scales2image (per the author's note).
    real = imresize(real_, opt.scale1, opt)
    reals = functions.creat_reals_pyramid(real, reals, opt)
    nfc_prev = 0

    progress = tqdm_notebook(range(opt.stop_scale + 1),
                             desc=opt.input_name,
                             leave=True)
    for scale_num in progress:
        # Channels double every 4 scales, capped at 128.
        factor = pow(2, math.floor(scale_num / 4))
        opt.nfc = min(opt.nfc_init * factor, 128)
        opt.min_nfc = min(opt.min_nfc_init * factor, 128)

        # Output root and per-scale sub-directory.
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = f'{opt.out_}/{scale_num}'
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass

        plt.imsave(f'{opt.outf}/real_scale.png',
                   functions.convert_image_np(reals[scale_num]),
                   vmin=0,
                   vmax=1)

        D_curr, G_curr = init_models(opt)
        if nfc_prev == opt.nfc:
            # Same width as the previous scale: start from its weights.
            G_curr.load_state_dict(
                torch.load(f'{opt.out_}/{scale_num - 1}/netG.pth'))
            D_curr.load_state_dict(
                torch.load(f'{opt.out_}/{scale_num - 1}/netD.pth'))

        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs,
                                                  Zs, in_s, NoiseAmp, opt)

        # Freeze both networks and switch them to eval mode.
        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        torch.save(Zs, f'{opt.out_}/Zs.pth')
        torch.save(Gs, f'{opt.out_}/Gs.pth')
        torch.save(reals, f'{opt.out_}/reals.pth')
        torch.save(NoiseAmp, f'{opt.out_}/NoiseAmp.pth')

        nfc_prev = opt.nfc
        # Drop the references before the next scale to free memory.
        del D_curr, G_curr
    return
Ejemplo n.º 27
0
import SinGAN.functions as functions

if __name__ == '__main__':
    # Command-line entry point: unless the output directory already exists,
    # create it, train the pyramid and draw random samples.
    parser = get_arguments()
    parser.add_argument('--input_dir',
                        help='input image dir',
                        default='Input/Images')
    parser.add_argument('--model_name',
                        help='input image name -1',
                        required=True)
    parser.add_argument('--mode', help='task to be done', default='train')
    opt = functions.post_config(parser.parse_args())
    Gs, Zs, reals, NoiseAmp = [], [], [], []
    dir2save = functions.generate_dir2save(opt)

    if not os.path.exists(dir2save):
        try:
            os.makedirs(dir2save)
        except OSError:
            pass
        # NOTE(review): other scripts call functions.read_image; confirm that
        # read_images exists in this project's SinGAN.functions.
        real = functions.read_images(opt)
        functions.adjust_scales2image(real, opt)
        train(opt, Gs, Zs, reals, NoiseAmp)
        SinGAN_generate(Gs, Zs, reals, NoiseAmp, opt)
    else:
        print('trained model already exist')
Ejemplo n.º 28
0
def train(opt,Gs,Zs,reals,NoiseAmp):
    """Train the SinGAN pyramid one scale at a time, printing progress info.

    Prints the resized input shape, every pyramid level's shape and the total
    number of scales. Then, for each scale ``0..opt.stop_scale``, it builds a
    D/G pair (warm-started from the previous scale when ``opt.nfc`` is
    unchanged), trains it, freezes it and checkpoints the accumulated
    ``Gs``/``Zs``/``reals``/``NoiseAmp`` lists under ``opt.out_``.

    ``opt`` is mutated in place (``nfc``, ``min_nfc``, ``out_``, ``outf`` and,
    with ``opt.fast_training``, ``niter`` halved every 4 scales).
    """
    real_ = functions.read_image(opt)
    in_s = 0
    scale_num = 0
    real = imresize(real_,opt.scale1,opt)
    print("real.shape:",real.shape)  # e.g. torch.Size([1, 3, 186, 248])
    reals = functions.creat_reals_pyramid(real,reals,opt)
    for level in reals:
        print(level.shape)
    # Example pyramid for a 186x248 input, coarsest to finest:
    #   [1,3,20,27] [1,3,27,36] [1,3,35,47] [1,3,47,62] [1,3,61,82]
    #   [1,3,81,108] [1,3,107,143] [1,3,141,188] [1,3,186,248]
    nfc_prev = 0
    # Scales 0..stop_scale inclusive are trained, i.e. stop_scale + 1 of them.
    print("total %d scales.." % (opt.stop_scale + 1))
    while scale_num<opt.stop_scale+1:
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        if opt.fast_training and scale_num > 0 and scale_num % 4 == 0:
            opt.niter = opt.niter//2

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_,scale_num)
        print("out dir:",opt.outf)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass  # directory already exists

        plt.imsave('%s/real_scale.png' %  (opt.outf), functions.convert_image_np(reals[scale_num]), vmin=0, vmax=1)
        # Build D and G. Per the original author, they are fully convolutional,
        # so the input size does not matter.
        D_curr,G_curr = init_models(opt)
        # If the filter count matches the previous scale, fine-tune from that
        # scale's saved weights.
        if nfc_prev == opt.nfc:
            G_curr.load_state_dict(torch.load('%s/%d/netG.pth' % (opt.out_,scale_num-1)))
            D_curr.load_state_dict(torch.load('%s/%d/netD.pth' % (opt.out_,scale_num-1)))

        # Train this scale. A fresh pair is used per scale; the original author
        # reports this kept GPU memory under ~3 GB.
        z_curr,in_s,G_curr = train_single_scale(D_curr,G_curr,reals,Gs,Zs,in_s,NoiseAmp,opt)

        G_curr = functions.reset_grads(G_curr,False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr,False)
        D_curr.eval()

        # Keep the trained (frozen) generator.
        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

        scale_num+=1
        nfc_prev = opt.nfc

        # Drop D and G references, presumably to reduce GPU memory use.
        del D_curr,G_curr
    return
Ejemplo n.º 29
0
def train(opt, Gs, Zs, reals, crops, masks, NoiseAmp):
    """Train one (D, G) pair per pyramid scale, from scale 0 up to ``opt.stop_scale``.

    The input image (``functions.read_image``) is resized by ``opt.scale1``.
    That image, the mask (``functions.read_mask``) and a zero tensor of size
    ``opt.crop_size`` are each turned into a pyramid with
    ``functions.create_pyramid``. The crop tensor is only used for its size.

    Per scale, ``opt.nfc`` / ``opt.min_nfc`` are set and a fresh model pair is
    built by ``init_models``. When ``opt.nfc`` equals the previous scale's
    value, the pair is warm-started from that scale's saved ``netG.pth`` /
    ``netD.pth``. The pair is then trained with ``train_single_scale`` and
    frozen. Finally the running ``Zs``, ``Gs``, ``reals`` and ``NoiseAmp``
    are saved to ``opt.out_``.

    Args:
        opt: Options object. Besides reading it, this function sets
            ``opt.eye_color``, ``opt.nfc``, ``opt.min_nfc``, ``opt.out_`` and
            ``opt.outf``.
        Gs, Zs, NoiseAmp: Lists that receive, per scale, the frozen
            generator, the ``z_curr`` returned by ``train_single_scale`` and
            ``opt.noise_amp``.
        reals, crops, masks: Passed to ``functions.create_pyramid`` and then
            rebound to its return value. NOTE(review): whether the caller's
            lists are filled in place depends on ``create_pyramid`` (not
            visible here).

    Returns:
        None. Results are delivered through the appended lists and the files
        written under ``opt.out_``.
    """
    def _freeze(net):
        # Turn off gradients and switch to eval mode once a scale is trained.
        net = functions.reset_grads(net, False)
        net.eval()
        return net

    source_img = functions.read_image(opt)
    base_img = imresize(source_img, opt.scale1, opt)
    mask_img = functions.read_mask(opt)
    # Only the spatial size of this tensor matters: its pyramid gives the crop
    # size that belongs to each scale.
    crop_ref = torch.zeros((1, 1, opt.crop_size, opt.crop_size))
    opt.eye_color = eye_color = functions.get_eye_color(base_img)

    prev_state = 0  # value handed to train_single_scale as `in_s`
    level = 0

    reals = functions.create_pyramid(base_img, reals, opt)
    masks = functions.create_pyramid(mask_img, masks, opt, mode="mask")
    crops = functions.create_pyramid(crop_ref, crops, opt, mode="mask")

    last_nfc = 0
    while level < opt.stop_scale + 1:
        # Width doubles every 4 scales, capped at 128.
        growth = 2 ** (level // 4)
        opt.nfc = min(opt.nfc_init * growth, 128)
        opt.min_nfc = min(opt.min_nfc_init * growth, 128)

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, level)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass  # best effort; the directory may already exist

        plt.imsave('%s/real_scale.png' % (opt.outf),
                   functions.convert_image_np(reals[level]),
                   vmin=0, vmax=1)

        D_curr, G_curr = init_models(opt)
        if last_nfc == opt.nfc:
            # Same opt.nfc as the previous scale: warm-start from its weights.
            G_curr.load_state_dict(
                torch.load('%s/%d/netG.pth' % (opt.out_, level - 1)))
            D_curr.load_state_dict(
                torch.load('%s/%d/netD.pth' % (opt.out_, level - 1)))

        z_curr, prev_state, G_curr = train_single_scale(
            D_curr, G_curr, reals, crops, masks, eye_color,
            Gs, Zs, prev_state, NoiseAmp, opt)

        G_curr = _freeze(G_curr)
        D_curr = _freeze(D_curr)

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        # Checkpoint everything accumulated so far after every scale.
        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

        level += 1
        last_nfc = opt.nfc
        del D_curr, G_curr

    return
Ejemplo n.º 30
0
def SinGAN_generate(Gs,
                    Zs,
                    reals,
                    NoiseAmp,
                    opt,
                    in_s=None,
                    scale_v=1,
                    scale_h=1,
                    n=0,
                    gen_start_scale=0,
                    num_samples=50,
                    output_image=False):
    """Sample images by running the generator pyramid from coarse to fine.

    For every scale ``(G, Z_opt, noise_amp)`` in ``zip(Gs, Zs, NoiseAmp)``,
    ``num_samples`` images are produced. The noise is fresh, or the stored
    ``Z_opt`` for scale indices below ``gen_start_scale``. It is multiplied
    by ``noise_amp`` and added to the previous image, and the result is fed
    to ``G`` together with that previous image. The previous image is the
    previous scale's output, resized by ``1 / opt.scale_factor``, cropped
    and upsampled; at the first scale it is the padded ``in_s``.

    Args:
        Gs: Generators, one per scale, coarsest first.
        Zs: Stored (padded) noise maps per scale; their spatial size minus
            the padding sets the noise size generated at that scale.
        reals: Real-image pyramid; used for per-scale crop sizes and to
            detect the finest scale (``n == len(reals) - 1``).
        NoiseAmp: Noise amplitude per scale.
        opt: Options object (``device``, ``ker_size``, ``num_layer``,
            ``nc_z``, ``scale_factor``, plus whatever
            ``functions.generate_dir2save`` reads).
        in_s: Input to the coarsest scale; defaults to float zeros with
            ``reals[0].shape`` on ``opt.device``.
        scale_v, scale_h: Vertical / horizontal multipliers on the output size.
        n: Scale index of ``Gs[0]``; incremented once per scale.
        gen_start_scale: Scales with index ``n < gen_start_scale`` use the
            stored ``Z_opt`` instead of freshly sampled noise.
        num_samples: Number of images generated at each scale.
        output_image: If True, save each finest-scale sample as
            ``<dir2save>/<i>.png``.

    Returns:
        ``(last_image, output_list)``. ``last_image`` is the detached last
        sample of the last processed scale. ``output_list`` holds the
        finest-scale samples converted with ``functions.convert_image_np``;
        it is empty if the finest scale was not reached.
    """
    if in_s is None:
        # Use float zeros: with an integer fill value, torch.full returns an
        # int64 tensor on current PyTorch (some older releases reject it).
        in_s = torch.zeros(reals[0].shape, device=opt.device)

    # Width of border that is zero-padded before each generator call; it is
    # the same for every scale, so build the padding module once.
    pad1 = ((opt.ker_size - 1) * opt.num_layer) / 2
    m = nn.ZeroPad2d(int(pad1))

    images_cur = []
    output_list = []
    for G, Z_opt, noise_amp in zip(Gs, Zs, NoiseAmp):
        # Unpadded noise size at this scale, stretched by scale_v / scale_h.
        nzx = (Z_opt.shape[2] - pad1 * 2) * scale_v
        nzy = (Z_opt.shape[3] - pad1 * 2) * scale_h

        images_prev = images_cur
        images_cur = []
        output_list = []

        is_last_scale = n == len(reals) - 1
        if is_last_scale:
            dir2save = functions.generate_dir2save(opt)
            try:
                os.makedirs(dir2save)
            except OSError:
                pass  # best effort; the directory may already exist

        for i in range(num_samples):
            if n == 0:
                z_curr = functions.generate_noise([1, nzx, nzy],
                                                  device=opt.device)
                # Repeat the single noise channel across 3 channels.
                z_curr = z_curr.expand(1, 3, z_curr.shape[2], z_curr.shape[3])
            else:
                z_curr = functions.generate_noise([opt.nc_z, nzx, nzy],
                                                  device=opt.device)
            z_curr = m(z_curr)

            if not images_prev:
                # First processed scale: start from the given input.
                I_prev = m(in_s)
            else:
                # Bring the matching sample of the previous scale up to this
                # scale, crop it to the target size, pad it, and match the
                # padded noise shape.
                I_prev = imresize(images_prev[i], 1 / opt.scale_factor, opt)
                I_prev = I_prev[:, :, 0:round(scale_v * reals[n].shape[2]),
                                0:round(scale_h * reals[n].shape[3])]
                I_prev = m(I_prev)
                I_prev = I_prev[:, :, 0:z_curr.shape[2], 0:z_curr.shape[3]]
                I_prev = functions.upsampling(I_prev, z_curr.shape[2],
                                              z_curr.shape[3])

            if n < gen_start_scale:
                # Coarse scales below gen_start_scale reuse the stored noise.
                z_curr = Z_opt

            z_in = noise_amp * z_curr + I_prev
            I_curr = G(z_in.detach(), I_prev)

            if is_last_scale:
                image_np = functions.convert_image_np(I_curr.detach())
                if output_image:
                    plt.imsave(f'{dir2save}/{i}.png', image_np, vmin=0, vmax=1)
                output_list.append(image_np)
            images_cur.append(I_curr)
        n += 1
    return I_curr.detach(), output_list