def main():
    """Split the two-domain dataset, then preview one image of the held-out part.

    A first ImageDataset is constructed with ``split_ratio=0.1`` and
    ``save_file=True`` and dropped immediately; the test dataset is then
    rebuilt from ``.split.pkl`` (presumably the file the first dataset
    wrote — confirm) and the first image of a single batch is displayed.
    """
    # Only built for its side effect of splitting / saving the record file
    split_maker = sunnerData.ImageDataset(
        root=[
            ['/home/sunner/Music/waiting_for_you_dataset/wait'],
            ['/home/sunner/Music/waiting_for_you_dataset/real_world'],
        ],
        transform=None,
        split_ratio=0.1,
        save_file=True,
    )
    del split_maker

    preprocess = transforms.Compose([
        sunnertransforms.Resize((160, 320)),
        sunnertransforms.ToTensor(),
        sunnertransforms.Transpose(sunnertransforms.BHWC2BCHW),
        sunnertransforms.Normalize(),
    ])
    held_out = sunnerData.ImageDataset(file_name='.split.pkl', transform=preprocess)

    # A single iteration is enough for a preview
    one_shot = sunnerData.IterationLoader(
        sunnerData.DataLoader(held_out, batch_size=32, shuffle=False, num_workers=2),
        max_iter=1,
    )
    for tensors, _ in one_shot:
        frames = sunnertransforms.asImg(tensors, size=(160, 320))
        # Reverse the channel axis before handing the image to OpenCV
        cv2.imshow('show_window', frames[0][:, :, ::-1])
        cv2.waitKey(0)
def main():
    """Show one Ear-Pen tag image after a colour -> index -> colour round trip.

    The pallete comes from ``ear-pen-pallete.json``. Tags are encoded with
    COLOR2INDEX inside the dataset pipeline and decoded with INDEX2COLOR
    before display. Only the first batch is shown.
    """
    pallete = sunnertransforms.getCategoricalMapping(
        path="ear-pen-pallete.json")[0]

    def shared_steps():
        # A fresh set of transform objects per dataset, as two independent
        # pipelines are needed.
        return [
            sunnertransforms.Resize((260, 195)),
            sunnertransforms.ToTensor(),
            sunnertransforms.ToFloat(),
            sunnertransforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ]

    img_dataset = sunnerData.ImageDataset(
        root=[['/home/sunner/Music/Ear-Pen-master/generate/train/img']],
        transforms=transforms.Compose(shared_steps()),
    )
    # NOTE(review): tags are normalised before COLOR2INDEX; confirm the
    # pallete colours are expressed in that normalised range.
    tag_dataset = sunnerData.ImageDataset(
        root=[['/home/sunner/Music/Ear-Pen-master/generate/train/tag']],
        transforms=transforms.Compose(shared_steps() + [
            sunnertransforms.CategoricalTranspose(
                pallete=pallete,
                direction=sunnertransforms.COLOR2INDEX,
                index_default=0),
        ]),
    )
    loader = sunnerData.MultiLoader(
        [img_dataset, tag_dataset], batch_size=1, shuffle=False, num_workers=2)

    # Inverse mapping: class index back to pallete colour
    index_to_color = sunnertransforms.CategoricalTranspose(
        pallete=pallete, direction=sunnertransforms.INDEX2COLOR, index_default=0)

    for _, tag_batch in loader:
        decoded = sunnertransforms.asImg(index_to_color(tag_batch), size=(260, 195))
        cv2.imshow('show_window', decoded[0][:, :, ::-1])
        cv2.waitKey(0)
        break
def main():
    """Preview one batch of the two-domain dataset, singly and as a grid.

    Images are resized to 160x320 and normalised with mean/std 0.5. The first
    image is shown with OpenCV, then the first ten tensors are shown through
    ``sunnertransforms.show`` in a 2x5 layout.
    """
    pipeline = transforms.Compose([
        sunnertransforms.Resize((160, 320)),
        sunnertransforms.ToTensor(),
        sunnertransforms.ToFloat(),
        sunnertransforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    image_set = sunnerData.ImageDataset(
        root=[
            ['/home/sunner/Music/waiting_for_you_dataset/wait'],
            ['/home/sunner/Music/waiting_for_you_dataset/real_world'],
        ],
        transforms=pipeline,
    )
    base_loader = sunnerData.DataLoader(
        image_set, batch_size=32, shuffle=False, num_workers=2)

    # The wrapper caps the preview at exactly one iteration
    for tensors, _ in sunnerData.IterationLoader(base_loader, max_iter=1):
        frames = sunnertransforms.asImg(tensors, size=(160, 320))
        cv2.imshow('show_window', frames[0][:, :, ::-1])
        cv2.waitKey(0)
        # Alternative view: several images laid out in one window
        sunnertransforms.show(tensors[:10], row=2, column=5)
def demo(args):
    """Visualise the outputs of a trained GANomaly2D model on ``args.demo``.

    Arg:
        args (namespace) - parsed arguments; reads ``demo``, ``H``, ``W``,
            ``batch_size``, ``r``, ``device`` and ``resume``.
    """
    preprocess = transforms.Compose([
        sunnerTransforms.Resize(output_size=(args.H, args.W)),
        sunnerTransforms.ToTensor(),
        sunnerTransforms.ToFloat(),
        sunnerTransforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    loader = sunnerData.DataLoader(
        dataset=sunnerData.ImageDataset(root=[[args.demo]], transforms=preprocess),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=2,
    )

    model = GANomaly2D(r=args.r, device=args.device)
    model.IO(args.resume, direction='load')

    progress = tqdm(loader)
    model.eval()
    # Inference only: no gradients are needed
    with torch.no_grad():
        for (batch, ) in progress:
            code, code_rec = model.forward(batch)
            img_in, img_out = model.getImg()
            visualizeAnomalyImage(img_in, img_out, code, code_rec)
def main(args):
    """
        Train the 2nd step for LaDo

        Arg:    args    (Namespace) - The argument object; reads the pair image
                                      folders, ``img_size``, ``batch_size``,
                                      ``content_dims``, ``appearance_dims``,
                                      ``model_path_1st``, ``total_epoch`` and
                                      ``root_folder``.
        Raise:  Exception if the 1st-step model file does not exist.
    """
    # Create the data loader and the pre-trained model
    loader = Data.DataLoader(dataset=sunnerData.ImageDataset(
        root=[[args.src_pair_image_folder], [args.tar_pair_image_folder]],
        transform=transforms.Compose([
            transforms.Resize((args.img_size, args.img_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ]),
        sample_method=sunnerData.OVER_SAMPLING),
        batch_size=args.batch_size, shuffle=True, num_workers=8)
    model = LaDo(args.content_dims, args.appearance_dims, args.batch_size)
    if os.path.exists(args.model_path_1st):
        model.load(args.model_path_1st, stage=1)
        model.copy()
        model.fix()
    else:
        raise Exception(
            "Pre-trained model didn't exist, please train for the 1st step first!"
        )

    save_dir = os.path.join(args.root_folder, 'models_2nd')

    # Loop
    for ep in range(args.total_epoch + 1):
        bar = tqdm(loader)
        for src_img, tar_img in bar:
            model.setInput_2nd(src_img, tar_img)
            # Epoch 0 only feeds inputs without updating the model
            if ep != 0:
                model.backward_2nd()

        # Print average loss, two non-zero entries per line. Grouping the
        # filtered entries (instead of using the enumerate index) keeps the
        # pairing correct when zero losses are skipped.
        if ep != 0:
            print("=====" * 20)
            print("<< Epoch {} average >>".format(ep) + " " * 20)
            entries = [
                "{:>20}: {:>15} \t".format(key[14:], str(round(loss, 10)))
                for key, loss in model.getLoss(stage=2, normalize=True).items()
                if loss != 0.0
            ]
            for start in range(0, len(entries), 2):
                print("".join(entries[start:start + 2]))
            # NOTE(review): the collapsed original did not show whether this
            # call sat inside the ``ep != 0`` branch; kept here since no
            # losses are accumulated during epoch 0.
            model.finishEpoch()

        # Save (makedirs also creates a missing root folder)
        os.makedirs(save_dir, exist_ok=True)
        model.save(path=os.path.join(save_dir, num2Str(ep) + '.pth'))
def main():
    """Show one tag image after a colour -> one-hot -> colour round trip.

    The pallete is derived by ``getCategoricalMapping`` from a throw-away
    loader over ``tag_folder`` (with ``path='pallete.json'``). ``tag_folder``
    and ``img_folder`` are names defined outside this function. Only the
    first batch is displayed.
    """
    # Throw-away loader used only to derive the pallete; save_file=False
    # prevents the dataset from saving its record file.
    scan_loader = sunnerData.DataLoader(
        sunnerData.ImageDataset(
            root=[tag_folder],
            transform=transforms.Compose([sunnertransforms.ToTensor()]),
            save_file=False,
        ),
        batch_size=2,
        shuffle=False,
        num_workers=2,
    )
    pallete = sunnertransforms.getCategoricalMapping(
        scan_loader, path='pallete.json')[0]
    del scan_loader

    preprocess = transforms.Compose([
        sunnertransforms.Resize((512, 1024)),
        sunnertransforms.ToTensor(),
        sunnertransforms.Transpose(sunnertransforms.BHWC2BCHW),
        sunnertransforms.Normalize(),
    ])
    view_loader = sunnerData.DataLoader(
        sunnerData.ImageDataset(root=[img_folder, tag_folder], transform=preprocess),
        batch_size=32,
        shuffle=False,
        num_workers=2,
    )

    # Forward and inverse categorical mappings
    to_onehot = sunnertransforms.CategoricalTranspose(
        pallete=pallete, direction=sunnertransforms.COLOR2ONEHOT)
    to_color = sunnertransforms.CategoricalTranspose(
        pallete=pallete, direction=sunnertransforms.ONEHOT2COLOR)

    for _, tags in view_loader:
        restored = sunnertransforms.asImg(to_color(to_onehot(tags)), size=(512, 1024))
        cv2.imshow('show_window', restored[0][:, :, ::-1])
        cv2.waitKey(0)
        break
def main():
    """Show one tag image after a colour -> one-hot -> colour round trip.

    A throw-away loader over ``tag_folder`` feeds ``getCategoricalMapping``
    (``path='pallete.json'``). Its pipeline rescales back to 0..255 and
    transposes to BHWC before the pallete is generated. A second loader over
    ``img_folder`` and ``tag_folder`` is then used for the preview.
    ``tag_folder`` / ``img_folder`` are defined outside this function.
    """
    scan_steps = transforms.Compose([
        sunnertransforms.ToTensor(),
        sunnertransforms.ToFloat(),
        # Back to the [0~255] range before generating the pallete
        sunnertransforms.UnNormalize(mean=[0, 0, 0], std=[255, 255, 255]),
        # Back to BHWC before generating the pallete
        sunnertransforms.Transpose(sunnertransforms.BCHW2BHWC),
    ])
    scan_loader = sunnerData.DataLoader(
        sunnerData.ImageDataset(root=[tag_folder], transforms=scan_steps),
        batch_size=2,
        shuffle=False,
        num_workers=2,
    )
    pallete = sunnertransforms.getCategoricalMapping(
        scan_loader, path='pallete.json')[0]
    del scan_loader

    view_steps = transforms.Compose([
        sunnertransforms.Resize((512, 1024)),
        sunnertransforms.ToTensor(),
        sunnertransforms.ToFloat(),
        sunnertransforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    view_loader = sunnerData.DataLoader(
        sunnerData.ImageDataset(root=[img_folder, tag_folder], transforms=view_steps),
        batch_size=32,
        shuffle=False,
        num_workers=2,
    )

    to_onehot = sunnertransforms.CategoricalTranspose(
        pallete=pallete, direction=sunnertransforms.COLOR2ONEHOT)
    to_color = sunnertransforms.CategoricalTranspose(
        pallete=pallete, direction=sunnertransforms.ONEHOT2COLOR)

    for _, tags in view_loader:
        restored = sunnertransforms.asImg(to_color(to_onehot(tags)), size=(512, 1024))
        cv2.imshow('show_window', restored[0][:, :, ::-1])
        cv2.waitKey(0)
        break
def train(args):
    """Train GANomaly2D on the images found in ``args.train``.

    Arg:
        args (namespace) - parsed arguments; reads ``train``, ``H``, ``W``,
            ``batch_size``, ``n_iter``, ``r``, ``device``, ``resume``,
            ``record_iter`` and ``det``.

    Every ``record_iter`` iterations (including the first) the current batch
    is re-run in eval mode for visualisation and the model is saved to
    ``args.det``. It is saved once more after the last iteration.
    """
    preprocess = transforms.Compose([
        sunnerTransforms.Resize(output_size=(args.H, args.W)),
        sunnerTransforms.ToTensor(),
        sunnerTransforms.ToFloat(),
        sunnerTransforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    batches = sunnerData.DataLoader(
        dataset=sunnerData.ImageDataset(root=[[args.train]], transforms=preprocess),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=2,
    )
    # Training length is bounded by iteration count rather than epochs
    batches = sunnerData.IterationLoader(batches, max_iter=args.n_iter)

    model = GANomaly2D(r=args.r, device=args.device)
    model.IO(args.resume, direction='load')
    model.train()

    progress = tqdm(batches)
    for it, (normal_img, ) in enumerate(progress):
        model.forward(normal_img)
        model.backward()
        loss_G, loss_D = model.getLoss()
        progress.set_description(
            " ".join(["Loss_G:", str(loss_G), "loss_D:", str(loss_D)]))
        progress.refresh()

        if it % args.record_iter != 0:
            continue
        # Periodic snapshot: visualise in eval mode, then switch back
        model.eval()
        with torch.no_grad():
            code, code_rec = model.forward(normal_img)
            img_in, img_out = model.getImg()
            visualizeEncoderDecoder(img_in, img_out, code, code_rec, it)
        model.train()
        model.IO(args.det, direction='save')

    model.IO(args.det, direction='save')
def main():
    """Preview the first image of one 32-image batch from the two-domain dataset."""
    preprocess = transforms.Compose([
        sunnertransforms.Resize((160, 320)),
        sunnertransforms.ToTensor(),
        sunnertransforms.Transpose(sunnertransforms.BHWC2BCHW),
        sunnertransforms.Normalize(),
    ])
    image_set = sunnerData.ImageDataset(
        root=[
            ['/home/sunner/Music/waiting_for_you_dataset/wait'],
            ['/home/sunner/Music/waiting_for_you_dataset/real_world'],
        ],
        transform=preprocess,
    )
    # Wrap the loader so that exactly one iteration is produced
    single_pass = sunnerData.IterationLoader(
        sunnerData.DataLoader(image_set, batch_size=32, shuffle=False, num_workers=2),
        max_iter=1,
    )
    for tensors, _ in single_pass:
        frames = sunnertransforms.asImg(tensors, size=(160, 320))
        cv2.imshow('show_window', frames[0][:, :, ::-1])
        cv2.waitKey(0)
def evalModel(args, model):
    """Evaluate an inpainting ``model`` on image/mask pairs and report mean PSNR.

    Args:
        args: namespace; reads ``folder_path``, ``mask_path`` and ``size``.
        model: object exposing ``setInput(target=, mask=)``, ``forward()`` and
            ``getOutput()``; the second element of ``getOutput()`` is compared
            against the input image.

    Returns:
        The mean PSNR over all evaluated samples, or ``nan`` if the loader
        produced nothing. The value is also printed.
    """
    # Create data loader
    loader = sunnerData.ImageLoader(sunnerData.ImageDataset(
        root_list=[args.folder_path, args.mask_path],
        transform=transforms.Compose([
            sunnertransforms.Rescale((args.size, args.size)),
            sunnertransforms.ToTensor(),
            sunnertransforms.Transpose(sunnertransforms.BHWC2BCHW),
            sunnertransforms.Normalize(),
        ]),
        sample_method=sunnerData.OVER_SAMPLING),
        batch_size=1, shuffle=False, num_workers=2)

    # Compute the PSNR and record
    psnr_list = []
    # Evaluation never back-propagates (results are detached below), so skip
    # building autograd graphs for every sample.
    with torch.no_grad():
        for image, mask in tqdm(loader):
            # Double the tensor to adapt with BN (presumably so BatchNorm never
            # sees a batch of one)
            image = torch.cat([image, image], 0)
            mask = torch.cat([mask, mask], 0)

            # Shift the mask from [-1, 1] to [0, 1] (assumes Normalize()
            # produced [-1, 1] — confirm)
            mask = (mask + 1) / 2
            model.setInput(target=image, mask=mask)
            model.forward()
            _, recon_img, _ = model.getOutput()
            psnr_list.append(compare_psnr(image[0].detach().cpu().numpy(),
                                          recon_img[0].detach().cpu().numpy()))

    # Guard against an empty loader (np.mean([]) warns and yields nan)
    mean_psnr = float(np.mean(psnr_list)) if psnr_list else float('nan')

    # Show the result
    print('\n\n')
    print('-' * 20, 'Complete evaluation', '-' * 20)
    print('Testing average psnr: %.4f' % mean_psnr)
    return mean_psnr
def main(opts):
    """Train StyleGAN (StyleGenerator / StyleDiscriminator) on ``opts.path``.

    Uses the non-saturating logistic GAN loss with optional R1/R2 penalties.
    Reads ``path``, ``batch_size``, ``resume``, ``device``, ``epoch``,
    ``r1_gamma``, ``r2_gamma`` and ``det`` from ``opts``. After every epoch
    it writes ``det/images/<ep>.png`` (samples from a fixed latent batch) and
    ``det/models/latest.pth``. A loss curve is plotted at the end.
    """
    # Create the data loader
    loader = sunnerData.DataLoader(
        sunnerData.ImageDataset(
            root=[[opts.path]],
            transform=transforms.Compose([
                sunnertransforms.Resize((1024, 1024)),
                sunnertransforms.ToTensor(),
                sunnertransforms.ToFloat(),
                sunnertransforms.Transpose(sunnertransforms.BHWC2BCHW),
                sunnertransforms.Normalize(),
            ])),
        batch_size=opts.batch_size,
        shuffle=True,
    )

    # Create the model
    start_epoch = 0
    G = StyleGenerator()
    D = StyleDiscriminator()

    # Load the pre-trained weight. This happens before any DataParallel
    # wrapping, so checkpoint keys must match the bare networks.
    if os.path.exists(opts.resume):
        INFO("Load the pre-trained weight!")
        # map_location lets a checkpoint saved on one device resume on another
        state = torch.load(opts.resume, map_location=opts.device)
        G.load_state_dict(state['G'])
        D.load_state_dict(state['D'])
        start_epoch = state['start_epoch']
    else:
        INFO("Pre-trained weight cannot load successfully, train from scratch!")

    # Multi-GPU support
    if torch.cuda.device_count() > 1:
        INFO("Multiple GPU:" + str(torch.cuda.device_count()) + "\t GPUs")
        G = nn.DataParallel(G)
        D = nn.DataParallel(D)
    G.to(opts.device)
    D.to(opts.device)

    def _bare(net):
        # nn.DataParallel prefixes state_dict keys with "module."; saving the
        # inner module keeps the checkpoint loadable by the code above.
        return net.module if isinstance(net, nn.DataParallel) else net

    # Create the optimizer and scheduler
    optim_D = optim.Adam(D.parameters(), lr=0.00001, betas=(0.5, 0.999))
    optim_G = optim.Adam(G.parameters(), lr=0.00001, betas=(0.5, 0.999))
    scheduler_D = optim.lr_scheduler.ExponentialLR(optim_D, gamma=0.99)
    scheduler_G = optim.lr_scheduler.ExponentialLR(optim_G, gamma=0.99)

    # Train
    fix_z = torch.randn([opts.batch_size, 512]).to(opts.device)
    softplus = nn.Softplus()
    # Per-epoch mean losses; the leading 0.0 placeholder is dropped before plotting
    Loss_D_list = [0.0]
    Loss_G_list = [0.0]
    for ep in range(start_epoch, opts.epoch):
        bar = tqdm(loader)
        loss_D_list = []
        loss_G_list = []
        for i, (real_img, ) in enumerate(bar):
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            D.zero_grad()
            real_img = real_img.to(opts.device)
            real_logit = D(real_img)
            fake_img = G(torch.randn([real_img.size(0), 512]).to(opts.device))
            fake_logit = D(fake_img.detach())
            d_loss = softplus(fake_logit).mean()
            d_loss = d_loss + softplus(-real_logit).mean()

            if opts.r1_gamma != 0.0:
                r1_penalty = R1Penalty(real_img.detach(), D)
                d_loss = d_loss + r1_penalty * (opts.r1_gamma * 0.5)

            if opts.r2_gamma != 0.0:
                r2_penalty = R2Penalty(fake_img.detach(), D)
                d_loss = d_loss + r2_penalty * (opts.r2_gamma * 0.5)

            loss_D_list.append(d_loss.item())
            d_loss.backward()
            optim_D.step()

            # (2) Update G network: maximize log(D(G(z)))
            if i % CRITIC_ITER == 0:
                G.zero_grad()
                fake_logit = D(fake_img)
                g_loss = softplus(-fake_logit).mean()
                loss_G_list.append(g_loss.item())
                g_loss.backward()
                optim_G.step()

            # Output training stats (i == 0 always updates G, so the list is non-empty)
            bar.set_description("Epoch {} [{}, {}] [G]: {} [D]: {}".format(
                ep, i + 1, len(loader), loss_G_list[-1], loss_D_list[-1]))

        # Save the result
        Loss_G_list.append(np.mean(loss_G_list))
        Loss_D_list.append(np.mean(loss_D_list))

        # Check how the generator is doing by saving G's output on fixed noise
        with torch.no_grad():
            fake_img = G(fix_z).detach().cpu()
            save_image(fake_img, os.path.join(opts.det, 'images', str(ep) + '.png'),
                       nrow=4, normalize=True)

        # Save model. 'start_epoch' is the next epoch to run on resume, so a
        # finished epoch is not repeated.
        state = {
            'G': _bare(G).state_dict(),
            'D': _bare(D).state_dict(),
            'Loss_G': Loss_G_list,
            'Loss_D': Loss_D_list,
            'start_epoch': ep + 1,
        }
        torch.save(state, os.path.join(opts.det, 'models', 'latest.pth'))

        scheduler_D.step()
        scheduler_G.step()

    # Plot the total loss curve
    Loss_D_list = Loss_D_list[1:]
    Loss_G_list = Loss_G_list[1:]
    plotLossCurve(opts, Loss_D_list, Loss_G_list)
def train(opts):
    """Train StyleGAN (``opts.G`` / ``opts.D``) with TensorBoard logging.

    Reads from ``opts``: ``output``, ``input``, ``imsize``, ``batch_size``,
    ``G``, ``D``, ``start_epoch``, ``resume``, ``device``, ``g_lr``, ``d_lr``,
    ``betas``, ``g_lrdecay``, ``d_lrdecay``, ``epochs``, ``r1gamma``,
    ``r2gamma``, ``critic_iters`` and ``show_interval``.

    Sample grids go to ``output/images/{normal,fixed}`` and TensorBoard. The
    latest checkpoint goes to ``output/models/latest.pth``, with an extra
    snapshot every 10 epochs.
    """
    writer = SummaryWriter(str(opts.output))
    loader = dataset.DataLoader(
        dataset=dataset.ImageDataset(
            [[opts.input]],
            transform=transforms.Compose([
                trans.Resize((opts.imsize, opts.imsize)),
                trans.ToTensor(),
                trans.ToFloat(),
                trans.Transpose(trans.BHWC2BCHW),
                trans.Normalize(),
            ])),
        batch_size=opts.batch_size,
        shuffle=True)

    G = opts.G
    D = opts.D
    step = 0
    start_epoch = opts.start_epoch

    if opts.resume:
        try:
            # A missing file raises FileNotFoundError here, handled below
            state = torch.load(opts.resume, map_location=opts.device)
            G.load_state_dict(state['G'])
            D.load_state_dict(state['D'])
            start_epoch = state['start_epoch']
            logger.info("Load Pretrained Weight")
        except Exception:
            # Deliberate best-effort resume: fall back to scratch, but log the
            # reason. Unlike a bare except, KeyboardInterrupt/SystemExit pass through.
            logger.warning("Resume Files cannot Load", exc_info=True)
            logger.info("Train from Scratch")
    else:
        logger.info("Train from Scratch")

    if torch.cuda.device_count() > 1 and opts.device == 'cuda':
        logger.info(f"{torch.cuda.device_count()} GPUs found.")
        G = nn.DataParallel(G)
        D = nn.DataParallel(D)

    if opts.device == 'cuda':
        torch.backends.cudnn.benchmark = True

    G.to(opts.device)
    D.to(opts.device)

    def _bare(net):
        # nn.DataParallel prefixes state_dict keys with "module."; save the
        # inner module so the checkpoint loads into the bare networks above.
        return net.module if isinstance(net, nn.DataParallel) else net

    optimG = Adam(G.parameters(), lr=opts.g_lr, betas=opts.betas)
    optimD = Adam(D.parameters(), lr=opts.d_lr, betas=opts.betas)
    schedulerG = lr_scheduler.ExponentialLR(optimG, gamma=opts.g_lrdecay)
    schedulerD = lr_scheduler.ExponentialLR(optimD, gamma=opts.d_lrdecay)

    fixed_z = torch.randn([opts.batch_size, 512]).to(opts.device)
    sp = nn.Softplus()
    g_store = [0.0]
    d_store = [0.0]

    def _log_samples(img, subdir, tag, epoch, i, nr):
        # Save one sample grid to disk and to TensorBoard (rescaled from [-1, 1] to [0, 1])
        save_image(img.detach().cpu(),
                   os.path.join(opts.output, 'images', subdir, f'{epoch:04}_{i:06}.png'),
                   nrow=nr, normalize=True)
        grid = utils.make_grid(img, nr, padding=0)
        grid = grid.to(torch.float32).cpu().numpy()
        grid = np.clip((grid / 2) + 0.5, 0, 1)
        writer.add_image(f"EPOCH{epoch}/{tag}", torch.from_numpy(grid), i)

    for epoch in range(start_epoch, opts.epochs + 1):
        bar = tqdm(loader)
        glosses = []
        dlosses = []
        for i, (real, ) in enumerate(bar):
            step += 1

            # Discriminator update
            D.zero_grad()
            real = real.to(opts.device)
            Dr = D(real)
            # Trace each graph once; re-tracing on every iteration is slow
            # and bloats the event file.
            if step == 1:
                writer.add_graph(D, real)
            z = torch.randn([real.size(0), 512]).to(opts.device)
            fake = G(z)
            if step == 1:
                writer.add_graph(G, z)
            Df = D(fake.detach())
            Dloss = sp(Df).mean() + sp(-Dr).mean()
            if opts.r1gamma > 0:
                r1 = r1_penalty(real.detach(), D)
                Dloss = Dloss + r1 * (opts.r1gamma * .5)
            if opts.r2gamma > 0:
                r2 = r2_penalty(fake.detach(), D)
                Dloss = Dloss + r2 * (opts.r2gamma * .5)
            dlosses.append(Dloss.item())
            Dloss.backward()
            optimD.step()

            # Generator update every ``critic_iters`` steps (always at i == 0)
            if i % opts.critic_iters == 0:
                G.zero_grad()
                Df = D(fake)
                Gloss = sp(-Df).mean()
                glosses.append(Gloss.item())
                Gloss.backward()
                optimG.step()

            if i % opts.show_interval == 0:
                with torch.no_grad():
                    nr = int(math.ceil(math.sqrt(opts.batch_size)))
                    z = torch.randn([real.size(0), 512]).to(opts.device)
                    _log_samples(G(z), 'normal', 'Random', epoch, i, nr)
                    _log_samples(G(fixed_z), 'fixed', 'Fixed', epoch, i, nr)

            # NOTE(review): the collapsed original did not show whether these
            # scalars were gated by show_interval; logged every step here.
            writer.add_scalar("LOSS/Generator", Gloss.item(), global_step=step)
            writer.add_scalar("LOSS/Discriminator", Dloss.item(), global_step=step)
            bar.set_description(
                f"Epoch {epoch}/{opts.epochs} G: {glosses[-1]:.6f} D: {dlosses[-1]:.6f}"
            )

        g_store.append(np.mean(glosses))
        d_store.append(np.mean(dlosses))
        # 'start_epoch' is the next epoch to run on resume, so a finished
        # epoch is not repeated. NOTE(review): 'opts' also pickles opts.G/opts.D.
        state = {
            'G': _bare(G).state_dict(),
            'D': _bare(D).state_dict(),
            'Loss_G': g_store,
            'Loss_D': d_store,
            'start_epoch': epoch + 1,
            'opts': opts
        }
        torch.save(state, os.path.join(opts.output, 'models', 'latest.pth'))
        if epoch % 10 == 0:
            torch.save(state, os.path.join(opts.output, 'models', f'{epoch:04}.pth'))
        schedulerD.step()
        schedulerG.step()

    # Flush pending TensorBoard events
    writer.close()
def main(opts):
    """Train a WGAN-GP: ``Generator`` (or ``StyleGenerator``) against ``Discriminator``.

    ``opts.type == 'style'`` selects the StyleGenerator. Reads ``path``,
    ``batch_size``, ``type``, ``device``, ``resume``, ``epoch`` and ``det``
    from ``opts``. After every epoch it writes ``det/images/<ep>.png`` and
    ``det/models/latest.pth``. The loss curve is plotted at the end.
    """
    # Create the data loader
    loader = sunnerData.DataLoader(sunnerData.ImageDataset(
        root=[[opts.path]],
        transform=transforms.Compose([
            sunnertransforms.Resize((128, 128)),
            sunnertransforms.ToTensor(),
            sunnertransforms.ToFloat(),
            sunnertransforms.Transpose(sunnertransforms.BHWC2BCHW),
            sunnertransforms.Normalize(),
        ])),
        batch_size=opts.batch_size, shuffle=True, num_workers=4)

    # Create the model
    if opts.type == 'style':
        G = StyleGenerator().to(opts.device)
    else:
        G = Generator().to(opts.device)
    D = Discriminator().to(opts.device)

    # Load the pre-trained weight
    if os.path.exists(opts.resume):
        INFO("Load the pre-trained weight!")
        # map_location lets a checkpoint saved on one device resume on another
        state = torch.load(opts.resume, map_location=opts.device)
        G.load_state_dict(state['G'])
        D.load_state_dict(state['D'])
    else:
        INFO("Pre-trained weight cannot load successfully, train from scratch!")

    # Create the optimizer and scheduler
    optim_D = optim.Adam(D.parameters(), lr=0.0001, betas=(0.5, 0.999))
    optim_G = optim.Adam(G.parameters(), lr=0.0001, betas=(0.5, 0.999))
    scheduler_D = optim.lr_scheduler.ExponentialLR(optim_D, gamma=0.99)
    scheduler_G = optim.lr_scheduler.ExponentialLR(optim_G, gamma=0.99)

    # Train
    fix_z = torch.randn([opts.batch_size, 512]).to(opts.device)
    # Per-epoch mean losses; the leading 0.0 placeholder is dropped before plotting
    Loss_D_list = [0.0]
    Loss_G_list = [0.0]
    for ep in range(opts.epoch):
        bar = tqdm(loader)
        loss_D_list = []
        loss_G_list = []
        for i, (real_img, ) in enumerate(bar):
            # Update discriminator: WGAN critic loss + gradient penalty (weight 10)
            real_img = real_img.to(opts.device)
            real_logit = D(real_img)
            fake_img = G(torch.randn([real_img.size(0), 512]).to(opts.device))
            fake_logit = D(fake_img.detach())
            d_loss = -(real_logit.mean() - fake_logit.mean()) + gradient_penalty(
                real_img.detach(), fake_img.detach(), D) * 10.0
            loss_D_list.append(d_loss.item())

            optim_D.zero_grad()
            d_loss.backward()
            optim_D.step()

            # Update generator every CRITIC_ITER steps (always at i == 0)
            if i % CRITIC_ITER == 0:
                fake_img = G(torch.randn([opts.batch_size, 512]).to(opts.device))
                fake_logit = D(fake_img)
                g_loss = -fake_logit.mean()
                loss_G_list.append(g_loss.item())

                D.zero_grad()
                optim_G.zero_grad()
                g_loss.backward()
                optim_G.step()

            bar.set_description(" {} [G]: {} [D]: {}".format(
                ep, loss_G_list[-1], loss_D_list[-1]))

        # Save the result
        Loss_G_list.append(np.mean(loss_G_list))
        Loss_D_list.append(np.mean(loss_D_list))

        # Sample from the fixed latent batch without building an autograd graph
        with torch.no_grad():
            fake_img = G(fix_z).detach().cpu()
            save_image(fake_img, os.path.join(opts.det, 'images', str(ep) + '.png'),
                       nrow=4, normalize=True)

        state = {
            'G': G.state_dict(),
            'D': D.state_dict(),
            'Loss_G': Loss_G_list,
            'Loss_D': Loss_D_list,
        }
        torch.save(state, os.path.join(opts.det, 'models', 'latest.pth'))

        scheduler_D.step()
        scheduler_G.step()

    # Plot the total loss curve
    Loss_D_list = Loss_D_list[1:]
    Loss_G_list = Loss_G_list[1:]
    plotLossCurve(opts, Loss_D_list, Loss_G_list)
discriminator.module.load_state_dict(ckpt["discriminator"]) g_running.load_state_dict(ckpt["g_running"]) g_optimizer.load_state_dict(ckpt["g_optimizer"]) d_optimizer.load_state_dict(ckpt["d_optimizer"]) transform = transforms.Compose([ transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) ]) dataset = sunnerData.ImageDataset( root=[[args.path]], transform=transforms.Compose([ sunnertransforms.Resize((128, 128)), sunnertransforms.ToTensor(), sunnertransforms.ToFloat(), sunnertransforms.Transpose(sunnertransforms.BHWC2BCHW), sunnertransforms.Normalize() ]), ) if args.sched: args.lr = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003} args.batch = { 4: 512, 8: 256, 16: 128, 32: 64, 64: 32, 128: 32, 256: 32
def main(opts):
    """Train StyleGAN2 (``G_stylegan2`` / ``D_stylegan2``) in its 'lite' form.

    D uses the logistic loss with R1 (``D_logistic_r1``). G uses the
    non-saturating logistic loss without path-length regularisation. Reads
    ``path``, ``resolution``, ``batch_size``, ``fmap_base``,
    ``mapping_layers``, ``resume``, ``device``, ``epoch`` and ``det`` from
    ``opts``. After every epoch it writes ``det/images/<ep>.png`` and
    ``det/models/latest.pth``.
    """
    # Create the data loader
    loader = sunnerData.DataLoader(sunnerData.ImageDataset(
        root=[[opts.path]],
        transform=transforms.Compose([
            sunnertransforms.Resize((opts.resolution, opts.resolution)),
            sunnertransforms.ToTensor(),
            sunnertransforms.ToFloat(),
            sunnertransforms.Transpose(sunnertransforms.BHWC2BCHW),
            sunnertransforms.Normalize(),
        ])),
        batch_size=opts.batch_size,
        shuffle=True,
        drop_last=True
    )

    # Create the model
    start_epoch = 0
    G = G_stylegan2(fmap_base=opts.fmap_base,
                    resolution=opts.resolution,
                    mapping_layers=opts.mapping_layers,
                    opts=opts,
                    return_dlatents=True)
    D = D_stylegan2(fmap_base=opts.fmap_base,
                    resolution=opts.resolution,
                    structure='resnet')

    # Load the pre-trained weight (into the bare, un-wrapped networks)
    if os.path.exists(opts.resume):
        INFO("Load the pre-trained weight!")
        state = torch.load(opts.resume, map_location=opts.device)
        G.load_state_dict(state['G'])
        D.load_state_dict(state['D'])
        start_epoch = state['start_epoch']
    else:
        INFO("Pre-trained weight cannot load successfully, train from scratch!")

    # Multi-GPU support
    if torch.cuda.device_count() > 1:
        INFO("Multiple GPU:" + str(torch.cuda.device_count()) + "\t GPUs")
        G = torch.nn.DataParallel(G)
        D = torch.nn.DataParallel(D)
    G.to(opts.device)
    D.to(opts.device)

    def _bare(net):
        # The wrapped network's sub-modules and un-prefixed state_dict keys
        # live on DataParallel's ``.module``.
        return net.module if isinstance(net, torch.nn.DataParallel) else net

    # Create the optimizer and scheduler
    lr_D = 0.0015
    lr_G = 0.0015
    optim_D = torch.optim.Adam(D.parameters(), lr=lr_D, betas=(0.9, 0.999))
    # g_mapping has 100x lower learning rate. Reach the sub-networks through
    # _bare(): a DataParallel wrapper has no g_synthesis/g_mapping attribute.
    G_core = _bare(G)
    params_G = [{"params": G_core.g_synthesis.parameters()},
                {"params": G_core.g_mapping.parameters(), "lr": lr_G * 0.01}]
    optim_G = torch.optim.Adam(params_G, lr=lr_G, betas=(0.9, 0.999))
    scheduler_D = optim.lr_scheduler.ExponentialLR(optim_D, gamma=0.99)
    scheduler_G = optim.lr_scheduler.ExponentialLR(optim_G, gamma=0.99)

    # Train
    fix_z = torch.randn([opts.batch_size, 512]).to(opts.device)
    softplus = torch.nn.Softplus()
    # Per-epoch mean losses; the leading 0.0 placeholder is dropped before plotting
    Loss_D_list = [0.0]
    Loss_G_list = [0.0]
    for ep in range(start_epoch, opts.epoch):
        bar = tqdm(loader)
        loss_D_list = []
        loss_G_list = []
        for i, (real_img,) in enumerate(bar):
            real_img = real_img.to(opts.device)
            latents = torch.randn([real_img.size(0), 512]).to(opts.device)

            # (1) Update D network: D_logistic_r1 (default)
            real_logit = D(real_img)
            fake_img, _ = G(latents)
            fake_logit = D(fake_img.detach())
            d_loss = softplus(fake_logit)
            d_loss = d_loss + softplus(-real_logit)
            r1_penalty = D_logistic_r1(real_img.detach(), D)
            d_loss = (d_loss + r1_penalty).mean()
            loss_D_list.append(d_loss.item())

            optim_D.zero_grad()
            d_loss.backward()
            optim_D.step()

            # (2) Update G network. Only the non-saturating logistic loss is
            # used; the path-length regulariser of G_logistic_ns_pathreg
            # (PL_DECAY / PL_WEIGHT) is disabled in this 'lite' variant.
            G.zero_grad()
            fake_scores_out = D(fake_img)
            g_loss = softplus(-fake_scores_out).mean()
            loss_G_list.append(g_loss.item())
            # No retain_graph: nothing back-propagates through this graph again
            g_loss.backward()
            optim_G.step()

            # Output training stats
            bar.set_description("Epoch {} [{}, {}] [G]: {} [D]: {}".format(
                ep, i + 1, len(loader), loss_G_list[-1], loss_D_list[-1]))

        # Save the result
        Loss_G_list.append(np.mean(loss_G_list))
        Loss_D_list.append(np.mean(loss_D_list))

        # Check how the generator is doing by saving G's output on fixed noise
        with torch.no_grad():
            fake_img = G(fix_z)[0].detach().cpu()
            save_image(fake_img, os.path.join(opts.det, 'images', str(ep) + '.png'),
                       nrow=4, normalize=True)

        # Save model; 'start_epoch' is the next epoch to run on resume
        state = {
            'G': _bare(G).state_dict(),
            'D': _bare(D).state_dict(),
            'Loss_G': Loss_G_list,
            'Loss_D': Loss_D_list,
            'start_epoch': ep + 1,
        }
        torch.save(state, os.path.join(opts.det, 'models', 'latest.pth'))

        scheduler_D.step()
        scheduler_G.step()

    # Plot the total loss curve
    Loss_D_list = Loss_D_list[1:]
    Loss_G_list = Loss_G_list[1:]
    plotLossCurve(opts, Loss_D_list, Loss_G_list)
return args if __name__ == '__main__': args = parse() # Create data loader sunnerData.quiet() sunnertransforms.quiet() loader = sunnerData.ImageLoader( sunnerData.ImageDataset( root_list=[args.image_folder, args.mask_folder], transform=transforms.Compose([ # sunnertransforms.Rescale((360, 640)), sunnertransforms.Rescale((256, 256)), sunnertransforms.ToTensor(), # BHWC -> BCHW sunnertransforms.Transpose(sunnertransforms.BHWC2BCHW), sunnertransforms.Normalize() ]), sample_method=sunnerData.OVER_SAMPLING), batch_size=args.batch_size, shuffle=False, num_workers=2) # Load model if args.model_type == 'pconv': model = PartialUNet(style_list=args.style, base=64, style_weight=args.lambda_style, freeze=args.freeze)
def main(args):
    """
        Train the 1st step for LaDo

        Arg:    args    (Namespace) - The argument object; reads the source and
                                      pair image folders, ``img_size``,
                                      ``batch_size``, ``content_dims``,
                                      ``appearance_dims``, ``total_epoch`` and
                                      ``root_folder``.
    """
    # Create the data loader and the model
    loader = Data.DataLoader(dataset=sunnerData.ImageDataset(
        root=[[args.src_image_folder], [args.src_pair_image_folder], [args.tar_pair_image_folder]],
        transform=transforms.Compose([
            transforms.Resize((args.img_size, args.img_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ]),
        sample_method=sunnerData.OVER_SAMPLING),
        batch_size=args.batch_size, shuffle=True, num_workers=8)
    model = LaDo(args.content_dims, args.appearance_dims, args.batch_size)
    save_dir = os.path.join(args.root_folder, 'models_1st')

    # Loop
    for ep in range(args.total_epoch + 1):
        bar = tqdm_table(loader)
        bar.set_table_setting(150)
        for src_img, src_img_pair, tar_img_pair in bar:
            # Skip single-sample batches (presumably to keep BatchNorm happy — confirm)
            if len(src_img) == 1:
                continue
            model.setInput_1st(src_img, src_img_pair, tar_img_pair)
            if ep != 0:
                # Print total loss and update
                model.backward_1st()
                bar.set_table_info(
                    {k[14:]: v for k, v in model.getLoss(stage=1).items()})
            else:
                # Epoch 0: a single forward pass only, then move on
                model.forward_1st()
                break

        # Print average loss, two non-zero entries per line. Grouping the
        # filtered entries (rather than relying on the enumerate index) stops
        # a skipped zero loss from overwriting a pending, unprinted entry.
        if ep != 0:
            print("=====" * 20)
            print("<< Epoch {} average >>".format(ep) + " " * 20)
            entries = [
                "{:>20}: {:>15} \t".format(key[14:], str(round(loss, 10)))
                for key, loss in model.getLoss(stage=1, normalize=True).items()
                if loss != 0.0
            ]
            for start in range(0, len(entries), 2):
                print("".join(entries[start:start + 2]))
            # NOTE(review): the collapsed original did not show whether this
            # call sat inside the ``ep != 0`` branch; kept here since no
            # training losses are accumulated during epoch 0.
            model.finishEpoch()

        # Save (makedirs creates root_folder and models_1st in one go)
        os.makedirs(save_dir, exist_ok=True)
        model.save(path=os.path.join(save_dir, str(ep) + '.pth'))
def main(opts):
    """Train StyleGAN2 (``Generator_stylegan2`` / ``Discriminator_stylegan2``).

    D uses the logistic loss with R1 (``D_logistic_r1``). G uses the
    non-saturating logistic loss. Reads ``path``, ``resolution``,
    ``batch_size``, ``fmap_base``, ``mapping_layers``, ``resume``, ``device``,
    ``epoch`` and ``det`` from ``opts``. After every epoch it writes
    ``det/images/<ep>.png`` and ``det/models/latest.pth``.
    """
    # Data load
    loader = sunnerData.DataLoader(sunnerData.ImageDataset(
        root=[[opts.path]],
        transform=transforms.Compose([
            sunnertransforms.Resize((opts.resolution, opts.resolution)),
            sunnertransforms.ToTensor(),
            sunnertransforms.ToFloat(),
            sunnertransforms.Transpose(sunnertransforms.BHWC2BCHW),
            sunnertransforms.Normalize()
        ])),
        batch_size=opts.batch_size,
        shuffle=True,
        drop_last=True)

    # Model generation
    start_epoch = 0
    G = Generator_stylegan2(fmap_base=opts.fmap_base,
                            resol=opts.resolution,
                            mapping_layers=opts.mapping_layers,
                            opts=opts,
                            return_dlatents=True)
    D = Discriminator_stylegan2(fmap_base=opts.fmap_base,
                                resol=opts.resolution,
                                structure='resnet')

    # Pre-trained weight loading (into the bare, un-wrapped networks)
    if os.path.exists(opts.resume):
        INFO("Load the pre-trained weight!")
        state = torch.load(opts.resume, map_location=opts.device)
        G.load_state_dict(state['G'])
        D.load_state_dict(state['D'])
        start_epoch = state['start_epoch']
    else:
        INFO("pre-trained weight error")

    # Multiple GPU support (the original misspelled DataParallel as
    # "DataParrlel", which raised AttributeError on multi-GPU hosts)
    if torch.cuda.device_count() > 1:
        INFO("multiple GPU detected! Total " + str(torch.cuda.device_count()) + '\t GPUs!')
        G = torch.nn.DataParallel(G)
        D = torch.nn.DataParallel(D)
    G.to(opts.device)
    D.to(opts.device)

    def _bare(net):
        # The wrapped network's sub-modules and un-prefixed state_dict keys
        # live on DataParallel's ``.module``.
        return net.module if isinstance(net, torch.nn.DataParallel) else net

    # Optimizer, scheduler
    lr_D = 0.0015
    lr_G = 0.0015
    optim_D = torch.optim.Adam(D.parameters(), lr=lr_D, betas=(0.9, 0.999))
    # g_mapping trains at a 100x lower learning rate; reach the sub-networks
    # through _bare() because a DataParallel wrapper does not expose them.
    G_core = _bare(G)
    params_G = [{
        "params": G_core.g_synthesis.parameters()
    }, {
        "params": G_core.g_mapping.parameters(),
        "lr": lr_G * 0.01
    }]
    optim_G = torch.optim.Adam(params_G, lr=lr_G, betas=(0.9, 0.999))
    scheduler_D = optim.lr_scheduler.ExponentialLR(optim_D, gamma=0.99)
    scheduler_G = optim.lr_scheduler.ExponentialLR(optim_G, gamma=0.99)

    # Start training
    fix_z = torch.randn([opts.batch_size, 512]).to(opts.device)
    softplus = torch.nn.Softplus()
    # Per-epoch mean losses; the leading 0.0 placeholder is dropped before plotting
    Loss_D_list = [0.0]
    Loss_G_list = [0.0]
    for ep in range(start_epoch, opts.epoch):
        bar = tqdm(loader)
        loss_D_list = []
        loss_G_list = []
        for i, (real_img, ) in enumerate(bar):
            real_img = real_img.to(opts.device)
            latents = torch.randn([real_img.size(0), 512]).to(opts.device)

            # Discriminator network: logistic loss + R1 penalty
            real_logit = D(real_img)
            fake_img, _ = G(latents)
            fake_logit = D(fake_img.detach())
            d_loss = softplus(fake_logit)
            d_loss = d_loss + softplus(-real_logit)
            r1_penalty = D_logistic_r1(real_img.detach(), D)
            d_loss = (d_loss + r1_penalty).mean()
            loss_D_list.append(d_loss.item())

            optim_D.zero_grad()
            d_loss.backward()
            optim_D.step()

            # Generator network: non-saturating logistic loss
            G.zero_grad()
            fake_scores_out = D(fake_img)
            g_loss = softplus(-fake_scores_out).mean()
            loss_G_list.append(g_loss.item())

            g_loss.backward()
            optim_G.step()

            bar.set_description("Epoch {} [{}, {}] [G]: {} [D]: {}".format(
                ep, i + 1, len(loader), loss_G_list[-1], loss_D_list[-1]))

        # Save result
        Loss_G_list.append(np.mean(loss_G_list))
        Loss_D_list.append(np.mean(loss_D_list))

        with torch.no_grad():
            fake_img = G(fix_z)[0].detach().cpu()
            save_image(fake_img, os.path.join(opts.det, 'images', str(ep) + '.png'),
                       nrow=4, normalize=True)

        # Save model; 'start_epoch' is the next epoch to run on resume
        state = {
            'G': _bare(G).state_dict(),
            'D': _bare(D).state_dict(),
            'Loss_G': Loss_G_list,
            'Loss_D': Loss_D_list,
            'start_epoch': ep + 1,
        }
        torch.save(state, os.path.join(opts.det, 'models', 'latest.pth'))

        scheduler_D.step()
        scheduler_G.step()

    Loss_D_list = Loss_D_list[1:]
    Loss_G_list = Loss_G_list[1:]
    plotLossCurve(opts, Loss_D_list, Loss_G_list)