def test(self):
    """Run the trained Generator over the whole test set.

    Every test batch is translated once per style class; the per-class
    results are concatenated along the batch dimension and saved as one
    image grid (``nrow=n_class``) under ``config["testSamples"]``.
    """
    test_img = self.config["testImgRoot"]
    save_dir = self.config["testSamples"]
    batch_size = self.config["batchSize"]
    n_class = len(self.config["selectedStyleDir"])
    print("%d classes" % n_class)

    # Data
    test_data = TestDataset(test_img, self.config["imCropSize"], batch_size)
    total = len(test_data)

    # Model: import the Generator class from the checkpointed script copy.
    package = __import__(self.config["com_base"] + self.config["gScriptName"],
                         fromlist=True)
    GClass = getattr(package, 'Generator')
    Gen = GClass(self.config["GConvDim"], self.config["GKS"],
                 self.config["resNum"])
    use_cuda = self.config["cuda"] >= 0
    if use_cuda:
        Gen = Gen.cuda()
    # map_location keeps CPU-only runs working with GPU-saved checkpoints.
    Gen.load_state_dict(torch.load(self.config["ckp_name"],
                                   map_location=None if use_cuda else "cpu"))
    print('loaded trained models {}...!'.format(self.config["ckp_name"]))

    # condition_labels[i] is a (batch_size, 1) tensor filled with class id i.
    condition_labels = (torch.arange(n_class, dtype=torch.long)
                        .view(n_class, 1, 1).repeat(1, batch_size, 1))
    if use_cuda:
        condition_labels = condition_labels.cuda()

    start_time = time.time()
    Gen.eval()
    with torch.no_grad():
        for _ in tqdm(range(total // batch_size)):
            content, img_name = test_data()
            if use_cuda:
                # Move once per batch instead of once per class iteration.
                content = content.cuda()
            final_res = None
            for i in range(n_class):
                res, _ = Gen(content, condition_labels[i, 0, :])
                if i == 0:
                    final_res = res
                else:
                    final_res = torch.cat([final_res, res], 0)
            save_image(
                denorm(final_res.data),
                os.path.join(
                    save_dir,
                    '{}_step{}_v_{}.png'.format(
                        img_name, self.config["checkpointStep"],
                        self.config["version"])),
                nrow=n_class)
    elapsed = time.time() - start_time
    elapsed = str(datetime.timedelta(seconds=elapsed))
    print("Elapsed [{}]".format(elapsed))
def train(self):
    """Train the (unconditioned) style-transfer GAN.

    Each outer step performs ``dStep`` discriminator updates (hinge loss
    on real style images, real photos, and generator fakes) followed by
    ``gStep`` generator updates (adversarial + feature-reconstruction +
    transform losses), with periodic logging, sampling, and checkpointing.
    """
    ckpt_dir = self.config["projectCheckpoints"]
    sample_dir = self.config["projectSamples"]
    total_step = self.config["totalStep"]
    log_frep = self.config["logStep"]
    sample_freq = self.config["sampleStep"]
    model_freq = self.config["modelSaveStep"]
    lr_base = self.config["gLr"]
    beta1 = self.config["beta1"]
    beta2 = self.config["beta2"]
    batch_size = self.config["batchSize"]
    feature_w = self.config["featureWeight"]
    transform_w = self.config["transformWeight"]
    workers = self.config["dataloader_workers"]
    dStep = self.config["dStep"]
    gStep = self.config["gStep"]

    if self.config["useTensorboard"]:
        from utilities.utilities import build_tensorboard
        tensorboard_writer = build_tensorboard(self.config["projectSummary"])

    print("prepare the dataloader...")
    content_loader = getLoader(self.config["content"],
                               self.config["selectedContentDir"],
                               self.config["imCropSize"], batch_size,
                               "Content", workers)
    style_loader = getLoader(self.config["style"],
                             self.config["selectedStyleDir"],
                             self.config["imCropSize"], batch_size,
                             "Style", workers)

    print("build models...")
    # "train" imports the component scripts from the source tree; "finetune"
    # imports the project-checkpoint copies so architecture matches weights.
    if self.config["mode"] == "train":
        package = __import__("components." + self.config["gScriptName"],
                             fromlist=True)
        GClass = getattr(package, 'Generator')
        package = __import__("components." + self.config["dScriptName"],
                             fromlist=True)
        DClass = getattr(package, 'Discriminator')
    elif self.config["mode"] == "finetune":
        package = __import__(self.config["com_base"] +
                             self.config["gScriptName"], fromlist=True)
        GClass = getattr(package, 'Generator')
        package = __import__(self.config["com_base"] +
                             self.config["dScriptName"], fromlist=True)
        DClass = getattr(package, 'Discriminator')

    Gen = GClass(self.config["GConvDim"], self.config["GKS"],
                 self.config["resNum"])
    Dis = DClass(self.config["DConvDim"], self.config["DKS"])
    self.reporter.writeInfo("Generator structure:")
    self.reporter.writeModel(Gen.__str__())
    self.reporter.writeInfo("Discriminator structure:")
    self.reporter.writeModel(Dis.__str__())

    Transform = Transform_block().cuda()
    Gen = Gen.cuda()
    Dis = Dis.cuda()

    if self.config["mode"] == "finetune":
        model_path = os.path.join(
            self.config["projectCheckpoints"],
            "%d_Generator.pth" % self.config["checkpointStep"])
        Gen.load_state_dict(torch.load(model_path))
        print('loaded trained Generator model step {}...!'.format(
            self.config["checkpointStep"]))
        model_path = os.path.join(
            self.config["projectCheckpoints"],
            "%d_Discriminator.pth" % self.config["checkpointStep"])
        Dis.load_state_dict(torch.load(model_path))
        print('loaded trained Discriminator model step {}...!'.format(
            self.config["checkpointStep"]))

    print("build the optimizer...")
    # Loss and optimizer
    g_optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, Gen.parameters()), lr_base,
        [beta1, beta2])
    d_optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, Dis.parameters()), lr_base,
        [beta1, beta2])
    L1_loss = torch.nn.L1Loss()
    MSE_loss = torch.nn.MSELoss()
    # ReLU(1 -/+ x) implements the hinge GAN loss.
    Hinge_loss = torch.nn.ReLU()

    # Start with trained model when finetuning.
    if self.config["mode"] == "finetune":
        start = self.config["checkpointStep"]
    else:
        start = 0
    # One outer step below performs (gStep + dStep) optimizer updates.
    total_step = total_step // (gStep + dStep)

    # Data iterators
    print("prepare the dataloaders...")
    content_iter = iter(content_loader)
    style_iter = iter(style_loader)

    print('Start ====== training...')
    start_time = time.time()
    for step in range(start, total_step):
        Dis.train()
        Gen.train()

        # ================== Train D ================== #
        for _ in range(dStep):
            try:
                content_images = next(content_iter)
                style_images = next(style_iter)
            except StopIteration:
                # Either loader exhausted its epoch: restart both.
                style_iter = iter(style_loader)
                content_iter = iter(content_loader)
                style_images = next(style_iter)
                content_images = next(content_iter)
            style_images = style_images.cuda()
            content_images = content_images.cuda()

            d_out = Dis(style_images)
            d_loss_real = Hinge_loss(1 - d_out).mean()
            d_out = Dis(content_images)
            d_loss_photo = Hinge_loss(1 + d_out).mean()
            fake_image, _ = Gen(content_images)
            d_out = Dis(fake_image.detach())
            d_loss_fake = Hinge_loss(1 + d_out).mean()

            # Backward + Optimize
            d_loss = d_loss_real + d_loss_photo + d_loss_fake
            d_optimizer.zero_grad()
            d_loss.backward()
            d_optimizer.step()

        # ================== Train G ================== #
        for _ in range(gStep):
            try:
                content_images = next(content_iter)
            except StopIteration:
                content_iter = iter(content_loader)
                content_images = next(content_iter)
            content_images = content_images.cuda()

            fake_image, real_feature = Gen(content_images)
            # Second pass re-encodes the fake for the feature loss.
            fake_feature = Gen(fake_image, get_feature=True)
            fake_out = Dis(fake_image)
            g_feature_loss = L1_loss(fake_feature, real_feature)
            g_transform_loss = MSE_loss(Transform(content_images),
                                        Transform(fake_image))
            g_loss_fake = -fake_out.mean()
            g_loss_fake = (g_loss_fake + g_feature_loss * feature_w +
                           g_transform_loss * transform_w)
            g_optimizer.zero_grad()
            g_loss_fake.backward()
            g_optimizer.step()

        # Print out log info
        if (step + 1) % log_frep == 0:
            elapsed = time.time() - start_time
            elapsed = str(datetime.timedelta(seconds=elapsed))
            print(
                "Elapsed [{}], G_step [{}/{}], D_step[{}/{}], d_out_real: {:.4f}, d_out_fake: {:.4f}, g_loss_fake: {:.4f}"
                .format(elapsed, step + 1, total_step, (step + 1), total_step,
                        d_loss_real.item(), d_loss_fake.item(),
                        g_loss_fake.item()))
            if self.config["useTensorboard"]:
                tensorboard_writer.add_scalar('data/d_loss_real',
                                              d_loss_real.item(), (step + 1))
                tensorboard_writer.add_scalar('data/d_loss_fake',
                                              d_loss_fake.item(), (step + 1))
                tensorboard_writer.add_scalar('data/d_loss',
                                              d_loss.item(), (step + 1))
                tensorboard_writer.add_scalar('data/g_loss',
                                              g_loss_fake.item(), (step + 1))
                tensorboard_writer.add_scalar('data/g_feature_loss',
                                              g_feature_loss.item(), (step + 1))
                tensorboard_writer.add_scalar('data/g_transform_loss',
                                              g_transform_loss.item(), (step + 1))

        # Sample images: content | fake | style side by side.
        if (step + 1) % sample_freq == 0:
            print('Sample images {}_fake.jpg'.format(step + 1))
            fake_images = fake_image
            saved_image1 = torch.cat([
                denorm(content_images),
                denorm(fake_images.data),
                denorm(style_images)
            ], 3)
            save_image(
                saved_image1,
                os.path.join(sample_dir, '{}_fake.jpg'.format(step + 1)))

        if (step + 1) % model_freq == 0:
            torch.save(
                Gen.state_dict(),
                os.path.join(ckpt_dir, '{}_Generator.pth'.format(step + 1)))
            torch.save(
                Dis.state_dict(),
                os.path.join(ckpt_dir,
                             '{}_Discriminator.pth'.format(step + 1)))
def train(self):
    """Train the class-conditional style-transfer GAN.

    Like the unconditional trainer, but the Generator and the
    multi-scale Discriminator both take a style-class label; hinge
    losses are summed over the discriminator's output scales.  Sampling
    renders every content image in the batch under every style class.
    """
    ckpt_dir = self.config["projectCheckpoints"]
    sample_dir = self.config["projectSamples"]
    total_step = self.config["totalStep"]
    log_frep = self.config["logStep"]
    sample_freq = self.config["sampleStep"]
    model_freq = self.config["modelSaveStep"]
    lr_base = self.config["gLr"]
    beta1 = self.config["beta1"]
    beta2 = self.config["beta2"]
    batch_size = self.config["batchSize"]
    n_class = len(self.config["selectedStyleDir"])
    feature_w = self.config["featureWeight"]
    transform_w = self.config["transformWeight"]
    dStep = self.config["dStep"]
    gStep = self.config["gStep"]
    # Prefetching loader built by the caller; yields
    # (content_images, style_images, label) already on the GPU.
    total_loader = self.dataloaders

    if self.config["useTensorboard"]:
        from utilities.utilities import build_tensorboard
        tensorboard_writer = build_tensorboard(self.config["projectSummary"])

    print("build models...")
    # "train" imports the component scripts from the source tree; "finetune"
    # imports the project-checkpoint copies so architecture matches weights.
    if self.config["mode"] == "train":
        package = __import__("components." + self.config["gScriptName"],
                             fromlist=True)
        GClass = getattr(package, 'Generator')
        package = __import__("components." + self.config["dScriptName"],
                             fromlist=True)
        DClass = getattr(package, 'Discriminator')
    elif self.config["mode"] == "finetune":
        print("finetune load scripts from %s" % self.config["com_base"])
        package = __import__(self.config["com_base"] +
                             self.config["gScriptName"], fromlist=True)
        GClass = getattr(package, 'Generator')
        package = __import__(self.config["com_base"] +
                             self.config["dScriptName"], fromlist=True)
        DClass = getattr(package, 'Discriminator')

    Gen = GClass(self.config["GConvDim"], self.config["GKS"],
                 self.config["resNum"])
    Dis = DClass(self.config["DConvDim"], self.config["DKS"])
    self.reporter.writeInfo("Generator structure:")
    self.reporter.writeModel(Gen.__str__())
    self.reporter.writeInfo("Discriminator structure:")
    self.reporter.writeModel(Dis.__str__())

    Transform = Transform_block().cuda()
    Gen = Gen.cuda()
    Dis = Dis.cuda()

    if self.config["mode"] == "finetune":
        model_path = os.path.join(
            self.config["projectCheckpoints"],
            "%d_Generator.pth" % self.config["checkpointStep"])
        Gen.load_state_dict(torch.load(model_path))
        print('loaded trained Generator model step {}...!'.format(
            self.config["checkpointStep"]))
        model_path = os.path.join(
            self.config["projectCheckpoints"],
            "%d_Discriminator.pth" % self.config["checkpointStep"])
        Dis.load_state_dict(torch.load(model_path))
        print('loaded trained Discriminator model step {}...!'.format(
            self.config["checkpointStep"]))

    print("build the optimizer...")
    # Loss and optimizer
    g_optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, Gen.parameters()), lr_base,
        [beta1, beta2])
    d_optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, Dis.parameters()), lr_base,
        [beta1, beta2])
    L1_loss = torch.nn.L1Loss()
    MSE_loss = torch.nn.MSELoss()
    # ReLU(1 -/+ x) implements the hinge GAN loss.
    Hinge_loss = torch.nn.ReLU().cuda()

    # Start with trained model when finetuning.
    if self.config["mode"] == "finetune":
        start = self.config["checkpointStep"]
    else:
        start = 0

    # Number of discriminator output scales the losses are summed over.
    output_size = Dis.get_outputs_len()

    print("prepare the dataloaders...")
    print("prepare the fixed labels...")
    # One label per class, used to render the sampling grid.
    fix_label = torch.tensor([i for i in range(n_class)]).long().cuda()

    import datetime
    print("Start to train at %s" %
          (datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
    print('Start ====== training...')
    start_time = time.time()
    for step in range(start, total_step):
        Dis.train()
        Gen.train()

        # ================== Train D ================== #
        for _ in range(dStep):
            content_images, style_images, label = total_loader.next()
            label = label.long()

            d_out = Dis(style_images, label)
            d_loss_real = 0
            for i in range(output_size):
                d_loss_real += Hinge_loss(1 - d_out[i]).mean()

            d_loss_photo = 0
            d_out = Dis(content_images, label)
            for i in range(output_size):
                d_loss_photo += Hinge_loss(1 + d_out[i]).mean()

            fake_image, _ = Gen(content_images, label)
            d_out = Dis(fake_image.detach(), label)
            d_loss_fake = 0
            for i in range(output_size):
                d_loss_fake += Hinge_loss(1 + d_out[i]).mean()

            # Backward + Optimize
            d_loss = d_loss_real + d_loss_photo + d_loss_fake
            d_optimizer.zero_grad()
            d_loss.backward()
            d_optimizer.step()

        # ================== Train G ================== #
        for _ in range(gStep):
            content_images, _, _ = total_loader.next()

            fake_image, real_feature = Gen(content_images, label)
            # Second pass re-encodes the fake for the feature loss.
            fake_feature = Gen(fake_image, get_feature=True)
            d_out = Dis(fake_image, label)
            g_feature_loss = L1_loss(fake_feature, real_feature)
            g_transform_loss = MSE_loss(Transform(content_images),
                                        Transform(fake_image))
            g_loss_fake = 0
            for i in range(output_size):
                g_loss_fake += -d_out[i].mean()

            # backward & optimize
            g_loss_fake = (g_loss_fake + g_feature_loss * feature_w +
                           g_transform_loss * transform_w)
            g_optimizer.zero_grad()
            g_loss_fake.backward()
            g_optimizer.step()

        # Print out log info
        if (step + 1) % log_frep == 0:
            elapsed = time.time() - start_time
            elapsed = str(datetime.timedelta(seconds=elapsed))
            print(
                "[{}], Elapsed [{}], Step [{}/{}], d_out_real: {:.4f}, d_out_fake: {:.4f}, g_loss_fake: {:.4f}"
                .format(self.config["version"], elapsed, step + 1, total_step,
                        d_loss_real.item(), d_loss_fake.item(),
                        g_loss_fake.item()))
            if self.config["useTensorboard"]:
                tensorboard_writer.add_scalar('data/d_loss_real',
                                              d_loss_real.item(), (step + 1))
                tensorboard_writer.add_scalar('data/d_loss_fake',
                                              d_loss_fake.item(), (step + 1))
                tensorboard_writer.add_scalar('data/d_loss',
                                              d_loss.item(), (step + 1))
                tensorboard_writer.add_scalar('data/g_loss',
                                              g_loss_fake.item(), (step + 1))
                tensorboard_writer.add_scalar('data/g_feature_loss',
                                              g_feature_loss.item(), (step + 1))
                tensorboard_writer.add_scalar('data/g_transform_loss',
                                              g_transform_loss.item(), (step + 1))

        # Sample images: for each content image, render every style class
        # and tile the results into one (batch rows x class columns) grid.
        if (step + 1) % sample_freq == 0:
            print('Sample images {}_fake.jpg'.format(step + 1))
            Gen.eval()
            with torch.no_grad():
                for batch_item in range(batch_size):
                    for class_item in range(n_class):
                        # fix_label is 1-D; [[class_item]] keeps a batch
                        # dimension of size 1 to match the single content
                        # image ([:, class_item] would raise IndexError).
                        fake_images, _ = Gen(
                            content_images[[batch_item], :, :, :],
                            fix_label[[class_item]])
                        if class_item == 0:
                            oneimg_res = torch.cat([
                                denorm(content_images[[batch_item], :, :, :]),
                                denorm(fake_images.data)
                            ], 3)
                        else:
                            oneimg_res = torch.cat(
                                [oneimg_res, denorm(fake_images.data)], 3)
                    if batch_item == 0:
                        batch_image = oneimg_res
                    else:
                        batch_image = torch.cat([batch_image, oneimg_res], 2)
                save_image(batch_image,
                           os.path.join(sample_dir,
                                        '{}_fake.jpg'.format(step + 1)),
                           nrow=1)
                # Free the sampling tensors before the next training step.
                del batch_image
                del oneimg_res
                del fake_images

        if (step + 1) % model_freq == 0:
            print("Save step %d model checkpoints!" % (step + 1))
            torch.save(
                Gen.state_dict(),
                os.path.join(ckpt_dir, '{}_Generator.pth'.format(step + 1)))
            torch.save(
                Dis.state_dict(),
                os.path.join(ckpt_dir,
                             '{}_Discriminator.pth'.format(step + 1)))