def TuiGAN_transfer(Gs, Zs, reals, NoiseAmp, Gs2, opt, in_s=None,
                    gen_start_scale=0):
    """Run the trained generator pyramid scale by scale and save the images.

    At every scale the generator ``G`` receives the scale's real image plus
    fresh noise (``noise_amp * z + real_curr``) and the previous scale's
    output. When a next scale exists, the output is resized by
    ``1 / opt.scale_factor`` and cropped to the next real image's spatial
    size so it can seed that scale. Every scale's ``x_ab`` is written to
    ``<dir2save>/x_ab_<k>.png`` (k starting at 1) and the final output to
    ``<dir2save>.png``.

    Args:
        Gs: per-scale generators producing ``x_ab``, in pyramid order.
        Zs: per-scale tensors; only their spatial dims (2 and 3) are used,
            to size the fresh noise drawn at each scale.
        reals: per-scale input images, in the same order as ``Gs``.
        NoiseAmp: per-scale noise amplitudes.
        Gs2: per-scale generators applied as ``G2(x_ab, x_aba)``; the
            resulting ``x_aba`` chain is computed but neither saved nor
            returned here.
        opt: option namespace; ``device`` and ``scale_factor`` are read and
            it is forwarded to the ``functions``/``imresize`` helpers.
        in_s: optional starting canvas for the first scale; defaults to zeros
            with the shape and dtype of ``reals[0]``.
        gen_start_scale: unused; kept for interface compatibility.

    Returns:
        The last computed ``x_ab``, detached from the autograd graph.
    """
    if in_s is None:
        # torch.full(shape, 0) infers an *integer* dtype from the int fill
        # value on recent PyTorch, feeding a long tensor into float
        # generators. Match the images' dtype explicitly instead.
        in_s = torch.zeros(reals[0].shape, dtype=reals[0].dtype,
                           device=opt.device)
    x_ab = in_s
    x_aba = in_s

    dir2save = functions.generate_dir2save(opt)
    os.makedirs(dir2save, exist_ok=True)

    scales = zip(Gs, Gs2, Zs, reals, NoiseAmp)
    for count, (G, G2, Z_opt, real_curr, noise_amp) in enumerate(scales, start=1):
        z = functions.generate_noise([3, Z_opt.shape[2], Z_opt.shape[3]],
                                     device=opt.device)
        z = z.expand(real_curr.shape[0], 3, z.shape[2], z.shape[3])
        x_ab = x_ab[:, :, 0:real_curr.shape[2], 0:real_curr.shape[3]]
        z_in = noise_amp * z + real_curr
        x_ab = G(z_in.detach(), x_ab)
        x_aba = G2(x_ab, x_aba)

        # Resize only when a next scale exists. Zipping against reals[1:]
        # (the previous approach) stopped one scale early and never ran the
        # finest generator when len(reals) == len(Gs).
        if count < len(reals):
            real_next = reals[count]
            x_ab = imresize(x_ab.detach(), 1 / opt.scale_factor, opt)
            x_ab = x_ab[:, :, 0:real_next.shape[2], 0:real_next.shape[3]]
            x_aba = imresize(x_aba.detach(), 1 / opt.scale_factor, opt)
            x_aba = x_aba[:, :, 0:real_next.shape[2], 0:real_next.shape[3]]

        plt.imsave('%s/x_ab_%d.png' % (dir2save, count),
                   functions.convert_image_np(x_ab.detach()), vmin=0, vmax=1)

    plt.imsave('%s.png' % (dir2save),
               functions.convert_image_np(x_ab.detach()), vmin=0, vmax=1)
    return x_ab.detach()
def train(opt, Gs, Zs, reals, NoiseAmp, Gs2, Zs2, reals2, NoiseAmp2):
    """Train the two-domain generator/discriminator pyramid scale by scale.

    For each scale ``0 .. opt.stop_scale`` this:
      * sets ``opt.nfc`` / ``opt.min_nfc`` to the initial value doubled every
        four scales, capped at 128;
      * creates ``opt.outf`` (``<opt.out_>/<scale>``);
      * builds fresh models via ``init_models`` and, when ``opt.nfc`` equals
        the previous scale's value, loads ``netG/netD/netG2/netD2.pth`` from
        the previous scale's directory;
      * trains them with ``train_single_scale``, freezes them
        (``reset_grads(..., False)`` + ``eval()``), appends the results to
        the caller's lists and saves every list as ``<opt.out_>/<name>.pth``.
    After the last scale the loss histories are plotted.

    Args:
        opt: option namespace; mutated in place (``nfc``, ``min_nfc``,
            ``out_``, ``outf``). ``opt.noise_amp`` / ``opt.noise_amp2`` are
            read after each scale (presumably set by ``train_single_scale``).
        Gs, Zs, NoiseAmp: first-domain lists, appended to in place.
        reals: passed to ``functions.creat_reals_pyramid``; its return value
            is used (and saved) as the first-domain pyramid.
        Gs2, Zs2, reals2, NoiseAmp2: the same for the second domain.

    Returns:
        None.
    """

    def load_previous(net, tag, scale):
        # Warm-start ``net`` from the checkpoint saved at ``scale - 1``.
        net.load_state_dict(
            torch.load('%s/%d/%s.pth' % (opt.out_, scale - 1, tag)))

    def freeze(net):
        # Turn off gradients and switch to eval mode once a scale is trained.
        frozen = functions.reset_grads(net, False)
        frozen.eval()
        return frozen

    image_a, image_b = functions.read_two_domains(opt)
    prev_in = 0
    prev_in2 = 0
    resized_a = imresize(image_a, opt.scale1, opt)
    resized_b = imresize(image_b, opt.scale1, opt)
    reals = functions.creat_reals_pyramid(resized_a, reals, opt)
    reals2 = functions.creat_reals_pyramid(resized_b, reals2, opt)

    # Loss histories filled by train_single_scale, plotted at the end.
    curves = {name: [] for name in ('errD', 'errG', 'rec', 'cyc',
                                    'errD2', 'errG2', 'rec2', 'cyc2')}
    previous_width = 0
    scale = 0
    while scale < opt.stop_scale + 1:
        widen = pow(2, math.floor(scale / 4))
        opt.nfc = min(opt.nfc_init * widen, 128)
        opt.min_nfc = min(opt.min_nfc_init * widen, 128)

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass

        D_curr, G_curr, D_curr2, G_curr2 = init_models(opt)
        if previous_width == opt.nfc:
            load_previous(G_curr, 'netG', scale)
            load_previous(D_curr, 'netD', scale)
            load_previous(G_curr2, 'netG2', scale)
            load_previous(D_curr2, 'netD2', scale)

        (z_curr, prev_in, G_curr,
         z_curr2, prev_in2, G_curr2) = train_single_scale(
            D_curr, G_curr, reals, Gs, Zs, prev_in, NoiseAmp,
            curves['errD'], curves['errG'], curves['rec'], curves['cyc'],
            D_curr2, G_curr2, reals2, Gs2, Zs2, prev_in2, NoiseAmp2,
            curves['errD2'], curves['errG2'], curves['rec2'], curves['cyc2'],
            opt, scale)

        G_curr = freeze(G_curr)
        D_curr = freeze(D_curr)
        G_curr2 = freeze(G_curr2)
        D_curr2 = freeze(D_curr2)

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)
        Gs2.append(G_curr2)
        Zs2.append(z_curr2)
        NoiseAmp2.append(opt.noise_amp2)

        checkpoints = (('Zs', Zs), ('Gs', Gs), ('reals', reals),
                       ('NoiseAmp', NoiseAmp), ('Zs2', Zs2), ('Gs2', Gs2),
                       ('reals2', reals2), ('NoiseAmp2', NoiseAmp2))
        for tag, payload in checkpoints:
            torch.save(payload, '%s/%s.pth' % (opt.out_, tag))

        scale += 1
        previous_width = opt.nfc
        del D_curr, G_curr, D_curr2, G_curr2

    functions.my_plot(curves['errD'], curves['errG'],
                      curves['rec'], curves['cyc'], opt)
    functions.my_plot2(curves['errD2'], curves['errG2'],
                       curves['rec2'], curves['cyc2'], opt)
) parser.add_argument('--input_dir', help='input image dir', required=True) parser.add_argument('--input_name', help='input image name', required=True) parser.add_argument('--mode', help='task to be done', default='transfer') parser.add_argument('--start_scale', help='injection scale', type=int, default='0') opt = parser.parse_args() opt = functions.post_config(opt) Gs = [] Zs = [] reals = [] NoiseAmp = [] Gs2 = [] dir2save = functions.generate_dir2save(opt) if dir2save is None: print('task does not exist') else: try: os.makedirs(dir2save) except OSError: pass real_in = functions.read_image(opt) functions.adjust_scales2image(real_in, opt) real_ = functions.read_image(opt) real = imresize(real_, opt.scale1, opt) reals = functions.creat_reals_pyramid(real, reals, opt) Gs, Zs, NoiseAmp, Gs2 = functions.load_model(opt)
def train(opt):
    """Train the stage-wise generator on the NWPU train/val/test splits.

    Builds ``DataLoader``s over ``datasets.NWPU`` and records the shape of
    every tensor in the first training batch. For each stage
    ``0 .. opt.stop_scale - 1`` it then:
      * creates ``opt.outf`` and ``opt.logs_out``;
      * builds a fresh discriminator (loading ``netD.pth`` from the previous
        stage's directory when ``scale_num > 0``);
      * calls ``generator.init_next_stage()``;
      * trains with ``train_single_scale``, logging to a per-stage
        ``SummaryWriter`` that is always closed afterwards;
      * saves the generator and the noise amplitudes under ``opt.out_``.

    Args:
        opt: option namespace; mutated in place (``out_``, ``outf``,
            ``logs_out``) and forwarded to the helpers.

    Returns:
        None. Results are written to disk.
    """
    print("Training model with the following parameters:")
    print("\t number of stages: {}".format(opt.train_stages))
    print("\t number of concurrently trained stages: {}".format(
        opt.train_depth))
    print("\t learning rate scaling: {}".format(opt.lr_scale))
    print("\t non-linearity: {}".format(opt.activation))

    # Load the datasets.
    train_loader = DataLoader(datasets.NWPU(opt.train_images, opt),
                              batch_size=opt.batch_size, shuffle=True,
                              num_workers=2)
    val_loader = DataLoader(datasets.NWPU(opt.val_images, opt),
                            batch_size=1, shuffle=True, num_workers=2)
    test_loader = DataLoader(datasets.NWPU(opt.test_images, opt),
                             batch_size=opt.batch_size, shuffle=False,
                             num_workers=2)

    # Peek at one training batch to record the shape of each pyramid level.
    temp, _ = next(iter(train_loader))
    shapes = [level.shape for level in temp]
    print("Training on image pyramid: {}".format(shapes))
    del temp

    generator = init_G(opt)
    noise_amp = []
    # NOTE(review): this trains opt.stop_scale stages (0 .. stop_scale - 1);
    # an earlier version iterated over range(opt.stop_scale + 1). Confirm
    # which count is intended.
    for scale_num in range(opt.stop_scale):
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale_num)
        opt.logs_out = opt.out_ + '/logs'
        # logs_out is shared by all stages and already exists from stage 1
        # on. exist_ok avoids the spurious error (the old handler printed the
        # OSError *class*) and still creates logs_out if outf pre-exists.
        os.makedirs(opt.outf, exist_ok=True)
        os.makedirs(opt.logs_out, exist_ok=True)

        d_curr = init_D(opt)
        if scale_num > 0:
            # Warm-start the discriminator from the previous stage.
            d_curr.load_state_dict(
                torch.load('%s/%d/netD.pth' % (opt.out_, scale_num - 1)))
        generator.init_next_stage()

        # A new writer is opened per stage, so close each one here. Closing
        # only the last writer after the loop leaked the earlier ones.
        writer = SummaryWriter(log_dir=opt.logs_out)
        try:
            noise_amp, generator, d_curr = train_single_scale(
                d_curr, generator, shapes, train_loader, val_loader,
                test_loader, noise_amp, opt, scale_num, writer)
            torch.save(generator, '%s/G.pth' % opt.out_)
            torch.save(noise_amp, '%s/noise_amp.pth' % opt.out_)
        finally:
            writer.close()
        del d_curr