def build_model(self):
    # self.G = net.Generator()
    # Build the two discriminators (binary classifiers) D_X, D_Y
    self.D_A = net.Discriminator(self.img_size, self.d_conv_dim, self.d_repeat_num, self.norm)
    self.D_B = net.Discriminator(self.img_size, self.d_conv_dim, self.d_repeat_num, self.norm)

    # Initialize network parameters; apply() is inherited from nn.Module
    self.G.apply(self.weights_init_xavier)
    self.D_A.apply(self.weights_init_xavier)
    self.D_B.apply(self.weights_init_xavier)

    # Load network parameters from the checkpoint files
    self.load_checkpoint()

    # Cycle consistency loss
    self.criterionL1 = torch.nn.L1Loss()
    # Perceptual loss
    self.criterionL2 = torch.nn.MSELoss()
    if self.device == 'cuda':
        self.criterionGAN = GANLoss(use_lsgan=True, tensor=torch.cuda.FloatTensor)
    else:
        self.criterionGAN = GANLoss(use_lsgan=True, tensor=torch.FloatTensor)
    self.vgg = net.vgg16(pretrained=True)
    # Makeup loss
    self.criterionHis = HistogramLoss()

    # Optimizers: iteratively update the generator and discriminator parameters
    self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])
    self.d_A_optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, self.D_A.parameters()),
        self.d_lr, [self.beta1, self.beta2])
    self.d_B_optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, self.D_B.parameters()),
        self.d_lr, [self.beta1, self.beta2])

    # Print networks
    self.print_network(self.G, 'G')
    self.print_network(self.D_A, 'D_A')
    self.print_network(self.D_B, 'D_B')

    if torch.cuda.is_available():
        self.device = "cuda"
        if torch.cuda.device_count() > 1:
            self.G = nn.DataParallel(self.G)
            self.D_A = nn.DataParallel(self.D_A)
            self.D_B = nn.DataParallel(self.D_B)
            self.vgg = nn.DataParallel(self.vgg)
            self.criterionHis = nn.DataParallel(self.criterionHis)
            self.criterionGAN = nn.DataParallel(self.criterionGAN)
            self.criterionL1 = nn.DataParallel(self.criterionL1)
            self.criterionL2 = nn.DataParallel(self.criterionL2)
        self.G.cuda()
        self.vgg.cuda()
        self.criterionHis.cuda()
        self.criterionGAN.cuda()
        self.criterionL1.cuda()
        self.criterionL2.cuda()
        self.D_A.cuda()
        self.D_B.cuda()

def build_model(self):
    # self.G = net.Generator()
    self.D_A = net.Discriminator(self.img_size, self.d_conv_dim, self.d_repeat_num, self.norm)
    self.D_B = net.Discriminator(self.img_size, self.d_conv_dim, self.d_repeat_num, self.norm)

    self.load_checkpoint()

    self.criterionL1 = torch.nn.L1Loss()
    self.criterionL2 = torch.nn.MSELoss()
    self.criterionGAN = GANLoss(use_lsgan=True, tensor=torch.cuda.FloatTensor)
    self.vgg = net.vgg16(pretrained=True)
    self.criterionHis = HistogramLoss()

    # Optimizers
    self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])
    self.d_A_optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, self.D_A.parameters()),
        self.d_lr, [self.beta1, self.beta2])
    self.d_B_optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, self.D_B.parameters()),
        self.d_lr, [self.beta1, self.beta2])

    # Print networks
    self.print_network(self.G, 'G')
    self.print_network(self.D_A, 'D_A')
    self.print_network(self.D_B, 'D_B')

    if torch.cuda.is_available():
        self.device = "cuda"
        if torch.cuda.device_count() > 1:
            self.G = nn.DataParallel(self.G)
            self.D_A = nn.DataParallel(self.D_A)
            self.D_B = nn.DataParallel(self.D_B)
            self.vgg = nn.DataParallel(self.vgg)
            self.criterionHis = nn.DataParallel(self.criterionHis)
            self.criterionGAN = nn.DataParallel(self.criterionGAN)
            self.criterionL1 = nn.DataParallel(self.criterionL1)
            self.criterionL2 = nn.DataParallel(self.criterionL2)
        self.G.cuda()
        self.vgg.cuda()
        self.criterionHis.cuda()
        self.criterionGAN.cuda()
        self.criterionL1.cuda()
        self.criterionL2.cuda()
        self.D_A.cuda()
        self.D_B.cuda()

def build_model(self):
    # Define generators and discriminators
    if self.whichG == 'normal':
        self.G = net.Generator_makeup(self.g_conv_dim, self.g_repeat_num)
    if self.whichG == 'branch':
        self.G = net.Generator_branch(self.g_conv_dim, self.g_repeat_num)
    for i in self.cls:
        setattr(self, "D_" + i,
                net.Discriminator(self.img_size, self.d_conv_dim, self.d_repeat_num, self.norm))

    self.criterionL1 = torch.nn.L1Loss()
    self.criterionL2 = torch.nn.MSELoss()
    self.criterionGAN = GANLoss(use_lsgan=True, tensor=torch.cuda.FloatTensor)
    self.vgg = models.vgg16(pretrained=True)

    # Optimizers
    self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])
    for i in self.cls:
        setattr(self, "d_" + i + "_optimizer",
                torch.optim.Adam(filter(lambda p: p.requires_grad, getattr(self, "D_" + i).parameters()),
                                 self.d_lr, [self.beta1, self.beta2]))

    # Weights initialization
    self.G.apply(self.weights_init_xavier)
    for i in self.cls:
        getattr(self, "D_" + i).apply(self.weights_init_xavier)

    if torch.cuda.is_available():
        self.G.cuda()
        self.vgg.cuda()
        for i in self.cls:
            getattr(self, "D_" + i).cuda()

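# Note (illustrative, not part of the original code): the setattr/getattr loop above builds
# one discriminator and one optimizer per domain label in self.cls, so e.g. self.cls = ['A', 'B']
# yields self.D_A / self.D_B and self.d_A_optimizer / self.d_B_optimizer. Below is a minimal,
# self-contained sketch of the same dynamic-attribute pattern; the class and attribute values
# here are made up purely for demonstration.
class MultiDomainHolder:
    def __init__(self, cls=('A', 'B')):
        self.cls = cls
        for i in self.cls:
            # Create self.D_A, self.D_B, ... dynamically
            setattr(self, "D_" + i, "discriminator for domain " + i)

    def describe(self):
        for i in self.cls:
            # Retrieve the dynamically created attributes by name
            print(getattr(self, "D_" + i))

# MultiDomainHolder().describe()
# -> discriminator for domain A
# -> discriminator for domain B
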
def build_model(self):
    # Define generators and discriminators
    if self.whichG == 'normal':
        self.G = net.Generator_makeup(self.g_conv_dim, self.g_repeat_num)
    if self.whichG == 'branch':
        self.G = net.Generator_branch(self.g_conv_dim, self.g_repeat_num)
    for i in self.cls:
        setattr(self, "D_" + i,
                net.Discriminator(self.img_size, self.d_conv_dim, self.d_repeat_num, self.norm))

    self.criterionL1 = torch.nn.L1Loss()
    self.criterionL2 = torch.nn.MSELoss()
    self.criterionGAN = GANLoss(use_lsgan=True, tensor=torch.cuda.FloatTensor)
    self.vgg = net.VGG()
    self.vgg.load_state_dict(torch.load('addings/vgg_conv.pth'))
    # self.vgg = models.vgg19_bn(pretrained=True)

    # Optimizers
    self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])
    for i in self.cls:
        setattr(self, "d_" + i + "_optimizer",
                torch.optim.Adam(filter(lambda p: p.requires_grad, getattr(self, "D_" + i).parameters()),
                                 self.d_lr, [self.beta1, self.beta2]))

    # Weights initialization
    self.G.apply(self.weights_init_xavier)
    for i in self.cls:
        getattr(self, "D_" + i).apply(self.weights_init_xavier)

    # Print networks
    self.print_network(self.G, 'G')
    for i in self.cls:
        self.print_network(getattr(self, "D_" + i), "D_" + i)

    """
    if torch.cuda.device_count() > 1:
        self.G = torch.nn.DataParallel(self.G)
        self.vgg = torch.nn.DataParallel(self.vgg)
        for i in self.cls:
            setattr(self, "D_" + i, torch.nn.DataParallel(getattr(self, "D_" + i)))
    self.G.to(self.device)
    self.vgg.to(self.device)
    for i in self.cls:
        getattr(self, "D_" + i).to(self.device)
    """

    if torch.cuda.is_available():
        self.G.cuda()
        self.vgg.cuda()
        for i in self.cls:
            getattr(self, "D_" + i).cuda()

def build_model(self):
    # Define generators and discriminators
    self.E = network.Encoder(self.e_conv_dim)
    self.G = network.Generator(self.g_conv_dim)
    for i in self.cls:
        setattr(self, "D_" + i,
                net.Discriminator(self.img_size, self.d_conv_dim, self.d_repeat_num, self.norm))

    # Define vgg for perceptual loss
    self.vgg = net.VGG()
    self.vgg.load_state_dict(torch.load('addings/vgg_conv.pth'))

    # Define loss
    self.criterionL1 = torch.nn.L1Loss()
    self.criterionL2 = torch.nn.MSELoss()
    self.criterionGAN = GANLoss(use_lsgan=True, tensor=torch.cuda.FloatTensor)

    # Optimizers
    self.e_optimizer = torch.optim.Adam(self.E.parameters(), self.e_lr, [self.beta1, self.beta2])
    self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])
    for i in self.cls:
        setattr(self, "d_" + i + "_optimizer",
                torch.optim.Adam(filter(lambda p: p.requires_grad, getattr(self, "D_" + i).parameters()),
                                 self.d_lr, [self.beta1, self.beta2]))

    # Weights initialization
    self.E.apply(self.weights_init_xavier)
    self.G.apply(self.weights_init_xavier)
    for i in self.cls:
        getattr(self, "D_" + i).apply(self.weights_init_xavier)

    # Print networks
    self.print_network(self.E, 'E')
    self.print_network(self.G, 'G')
    for i in self.cls:
        self.print_network(getattr(self, "D_" + i), "D_" + i)

    if torch.cuda.is_available():
        self.E.cuda()
        self.G.cuda()
        self.vgg.cuda()
        for i in self.cls:
            getattr(self, "D_" + i).cuda()

def build_model(self):
    # Define generators and discriminators
    self.G_A = net.Generator(self.g_conv_dim, self.g_repeat_num)
    self.G_B = net.Generator(self.g_conv_dim, self.g_repeat_num)
    self.D_A = net.Discriminator(self.img_size, self.d_conv_dim, self.d_repeat_num)
    self.D_B = net.Discriminator(self.img_size, self.d_conv_dim, self.d_repeat_num)

    self.criterionL1 = torch.nn.L1Loss()
    self.criterionGAN = GANLoss(use_lsgan=True, tensor=torch.cuda.FloatTensor)

    # Optimizers
    self.g_optimizer = torch.optim.Adam(
        itertools.chain(self.G_A.parameters(), self.G_B.parameters()),
        self.g_lr, [self.beta1, self.beta2])
    self.d_A_optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, self.D_A.parameters()),
        self.d_lr, [self.beta1, self.beta2])
    self.d_B_optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, self.D_B.parameters()),
        self.d_lr, [self.beta1, self.beta2])

    self.G_A.apply(self.weights_init_xavier)
    self.D_A.apply(self.weights_init_xavier)
    self.G_B.apply(self.weights_init_xavier)
    self.D_B.apply(self.weights_init_xavier)

    # Print networks
    # self.print_network(self.E, 'E')
    self.print_network(self.G_A, 'G_A')
    self.print_network(self.D_A, 'D_A')
    self.print_network(self.G_B, 'G_B')
    self.print_network(self.D_B, 'D_B')

    if torch.cuda.is_available():
        self.G_A.cuda()
        self.G_B.cuda()
        self.D_A.cuda()
        self.D_B.cuda()

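# NOTE (illustrative sketch, not taken from this repository): every build_model variant above
# calls GANLoss(use_lsgan=True, tensor=...), but the class itself is not defined in this file.
# The sketch below follows the common pix2pix/CycleGAN-style least-squares GAN loss with the
# same interface; the project's own implementation may differ in details.
import torch
import torch.nn as nn

class GANLoss(nn.Module):
    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
                 tensor=torch.FloatTensor):
        super(GANLoss, self).__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        self.Tensor = tensor
        # LSGAN swaps the cross-entropy objective for a least-squares (MSE) objective
        self.loss = nn.MSELoss() if use_lsgan else nn.BCEWithLogitsLoss()

    def get_target_tensor(self, input, target_is_real):
        # Build a label tensor with the same shape as the discriminator output
        value = self.real_label if target_is_real else self.fake_label
        return self.Tensor(input.size()).fill_(value).to(input.device)

    def __call__(self, input, target_is_real):
        target_tensor = self.get_target_tensor(input, target_is_real)
        return self.loss(input, target_tensor)
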
class Solver(Track):
    def __init__(self, config, device="cpu", data_loader=None, inference=False):
        self.G = net.Generator()
        if inference:
            self.G.load_state_dict(torch.load(inference, map_location=torch.device(device)))
            self.G = self.G.to(device).eval()
            return

        self.start_time = time.time()
        self.checkpoint = config.MODEL.WEIGHTS
        self.log_path = config.LOG.LOG_PATH
        self.result_path = os.path.join(self.log_path, config.LOG.VIS_PATH)
        self.snapshot_path = os.path.join(self.log_path, config.LOG.SNAPSHOT_PATH)
        self.log_step = config.LOG.LOG_STEP
        self.vis_step = config.LOG.VIS_STEP
        if device == 'cuda':
            self.snapshot_step = config.LOG.SNAPSHOT_STEP // torch.cuda.device_count()
        else:
            self.snapshot_step = config.LOG.SNAPSHOT_STEP // 1

        # Data loader
        self.data_loader_train = data_loader
        self.img_size = config.DATA.IMG_SIZE

        self.num_epochs = config.TRAINING.NUM_EPOCHS
        self.num_epochs_decay = config.TRAINING.NUM_EPOCHS_DECAY
        self.g_lr = config.TRAINING.G_LR
        self.d_lr = config.TRAINING.D_LR
        self.g_step = config.TRAINING.G_STEP
        self.beta1 = config.TRAINING.BETA1
        self.beta2 = config.TRAINING.BETA2

        self.lambda_idt = config.LOSS.LAMBDA_IDT
        self.lambda_A = config.LOSS.LAMBDA_A
        self.lambda_B = config.LOSS.LAMBDA_B
        self.lambda_his_lip = config.LOSS.LAMBDA_HIS_LIP
        self.lambda_his_skin = config.LOSS.LAMBDA_HIS_SKIN
        self.lambda_his_eye = config.LOSS.LAMBDA_HIS_EYE
        self.lambda_vgg = config.LOSS.LAMBDA_VGG

        # Hyper-parameters
        self.d_conv_dim = config.MODEL.D_CONV_DIM
        self.d_repeat_num = config.MODEL.D_REPEAT_NUM
        self.norm = config.MODEL.NORM

        self.device = device
        self.build_model()
        super(Solver, self).__init__()

    # For generator
    def weights_init_xavier(self, m):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            init.xavier_normal(m.weight.data, gain=1.0)
        elif classname.find('Linear') != -1:
            init.xavier_normal(m.weight.data, gain=1.0)

    def print_network(self, model, name):
        num_params = 0
        for p in model.parameters():
            num_params += p.numel()
        print(name)
        print(model)
        print("The number of parameters: {}".format(num_params))

    def de_norm(self, x):
        out = (x + 1) / 2
        return out.clamp(0, 1)

    def build_model(self):
        # self.G = net.Generator()
        # Build the two discriminators (binary classifiers) D_X, D_Y
        self.D_A = net.Discriminator(self.img_size, self.d_conv_dim, self.d_repeat_num, self.norm)
        self.D_B = net.Discriminator(self.img_size, self.d_conv_dim, self.d_repeat_num, self.norm)

        # Initialize network parameters; apply() is inherited from nn.Module
        self.G.apply(self.weights_init_xavier)
        self.D_A.apply(self.weights_init_xavier)
        self.D_B.apply(self.weights_init_xavier)

        # Load network parameters from the checkpoint files
        self.load_checkpoint()

        # Cycle consistency loss
        self.criterionL1 = torch.nn.L1Loss()
        # Perceptual loss
        self.criterionL2 = torch.nn.MSELoss()
        if self.device == 'cuda':
            self.criterionGAN = GANLoss(use_lsgan=True, tensor=torch.cuda.FloatTensor)
        else:
            self.criterionGAN = GANLoss(use_lsgan=True, tensor=torch.FloatTensor)
        self.vgg = net.vgg16(pretrained=True)
        # Makeup loss
        self.criterionHis = HistogramLoss()

        # Optimizers: iteratively update the generator and discriminator parameters
        self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])
        self.d_A_optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.D_A.parameters()),
            self.d_lr, [self.beta1, self.beta2])
        self.d_B_optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.D_B.parameters()),
            self.d_lr, [self.beta1, self.beta2])

        # Print networks
        self.print_network(self.G, 'G')
        self.print_network(self.D_A, 'D_A')
        self.print_network(self.D_B, 'D_B')

        if torch.cuda.is_available():
            self.device = "cuda"
            if torch.cuda.device_count() > 1:
                self.G = nn.DataParallel(self.G)
                self.D_A = nn.DataParallel(self.D_A)
                self.D_B = nn.DataParallel(self.D_B)
                self.vgg = nn.DataParallel(self.vgg)
                self.criterionHis = nn.DataParallel(self.criterionHis)
                self.criterionGAN = nn.DataParallel(self.criterionGAN)
                self.criterionL1 = nn.DataParallel(self.criterionL1)
                self.criterionL2 = nn.DataParallel(self.criterionL2)
            self.G.cuda()
            self.vgg.cuda()
            self.criterionHis.cuda()
            self.criterionGAN.cuda()
            self.criterionL1.cuda()
            self.criterionL2.cuda()
            self.D_A.cuda()
            self.D_B.cuda()

    def load_checkpoint(self):
        G_path = os.path.join(self.checkpoint, 'G.pth')
        if os.path.exists(G_path):
            self.G.load_state_dict(torch.load(G_path, map_location=torch.device(self.device)))
            print('loaded trained generator {}..!'.format(G_path))
        D_A_path = os.path.join(self.checkpoint, 'D_A.pth')
        if os.path.exists(D_A_path):
            self.D_A.load_state_dict(torch.load(D_A_path, map_location=torch.device(self.device)))
            print('loaded trained discriminator A {}..!'.format(D_A_path))
        D_B_path = os.path.join(self.checkpoint, 'D_B.pth')
        if os.path.exists(D_B_path):
            self.D_B.load_state_dict(torch.load(D_B_path, map_location=torch.device(self.device)))
            print('loaded trained discriminator B {}..!'.format(D_B_path))

    def generate(self, org_A, ref_B, lms_A=None, lms_B=None, mask_A=None, mask_B=None,
                 diff_A=None, diff_B=None, gamma=None, beta=None, ret=False):
        """org_A is content, ref_B is style"""
        res = self.G(org_A, ref_B, mask_A, mask_B, diff_A, diff_B, gamma, beta, ret)
        return res

    # mask attributes: 0:background 1:face 2:left-eyebrow 3:right-eyebrow 4:left-eye 5:right-eye
    # 6:nose 7:upper-lip 8:teeth 9:under-lip 10:hair 11:left-ear 12:right-ear 13:neck
    def test(self, real_A, mask_A, diff_A, real_B, mask_B, diff_B):
        cur_prama = None
        with torch.no_grad():
            cur_prama = self.generate(real_A, real_B, None, None, mask_A, mask_B,
                                      diff_A, diff_B, ret=True)
            fake_A = self.generate(real_A, real_B, None, None, mask_A, mask_B,
                                   diff_A, diff_B, gamma=cur_prama[0], beta=cur_prama[1])
        fake_A = fake_A.squeeze(0)

        # normalize
        min_, max_ = fake_A.min(), fake_A.max()
        fake_A.add_(-min_).div_(max_ - min_ + 1e-5)
        return ToPILImage()(fake_A.cpu())

    def train(self):
        # The number of iterations per epoch
        self.iters_per_epoch = len(self.data_loader_train)
        # Start with trained model if exists
        g_lr = self.g_lr
        d_lr = self.d_lr
        start = 0
        for self.e in range(start, self.num_epochs):  # epoch
            for self.i, (source_input, reference_input) in enumerate(self.data_loader_train):  # batch
                # image, mask, dist
                image_s, image_r = source_input[0].to(self.device), reference_input[0].to(self.device)
                mask_s, mask_r = source_input[1].to(self.device), reference_input[1].to(self.device)
                dist_s, dist_r = source_input[2].to(self.device), reference_input[2].to(self.device)
                self.track("data")

                # ================== Train D ================== #
                # Training D_A; D_A aims to distinguish class B, i.e. whether the input is a real reference y
                # Real
                out = self.D_A(image_r)
                self.track("D_A")
                d_loss_real = self.criterionGAN(out, True)
                self.track("D_A_loss")

                # Fake: use the generator to produce fake_y
                fake_A = self.G(image_s, image_r, mask_s, mask_r, dist_s, dist_r)
                self.track("G")
                # Discriminator input and discriminator loss; requires_grad=False
                fake_A = Variable(fake_A.data).detach()
                out = self.D_A(fake_A)
                self.track("D_A_2")
                d_loss_fake = self.criterionGAN(out, False)
                self.track("D_A_loss_2")

                # Backward + Optimize: backpropagate through the discriminator and update its parameters
                d_loss = (d_loss_real.mean() + d_loss_fake.mean()) * 0.5
                self.d_A_optimizer.zero_grad()
                d_loss.backward(retain_graph=False)  # retain_graph=False releases the computation graph
                self.d_A_optimizer.step()

                # Logging
                self.loss = {}
                self.loss['D-A-loss_real'] = d_loss_real.mean().item()

                # Training D_B; D_B aims to distinguish class A, i.e. whether the input is a real source x
                # Real
                out = self.D_B(image_s)
                d_loss_real = self.criterionGAN(out, True)
                # Fake: use the generator to produce fake_x
                self.track("G-before")
                fake_B = self.G(image_r, image_s, mask_r, mask_s, dist_r, dist_s)
                self.track("G-2")
                fake_B = Variable(fake_B.data).detach()
                out = self.D_B(fake_B)
                d_loss_fake = self.criterionGAN(out, False)

                # Backward + Optimize
                d_loss = (d_loss_real.mean() + d_loss_fake.mean()) * 0.5
                self.d_B_optimizer.zero_grad()
                d_loss.backward(retain_graph=False)
                self.d_B_optimizer.step()

                # Logging
                self.loss['D-B-loss_real'] = d_loss_real.mean().item()
                # self.track("Discriminator backward")

                # ================== Train G ================== #
                if (self.i + 1) % self.g_step == 0:
                    # identity loss
                    assert self.lambda_idt > 0
                    # G should be identity if ref_B or org_A is fed
                    idt_A = self.G(image_s, image_s, mask_s, mask_s, dist_s, dist_s)
                    idt_B = self.G(image_r, image_r, mask_r, mask_r, dist_r, dist_r)
                    loss_idt_A = self.criterionL1(idt_A, image_s) * self.lambda_A * self.lambda_idt
                    loss_idt_B = self.criterionL1(idt_B, image_r) * self.lambda_B * self.lambda_idt
                    # loss_idt
                    loss_idt = (loss_idt_A + loss_idt_B) * 0.5
                    # loss_idt = loss_idt_A * 0.5
                    # self.track("Identical")

                    # GAN loss D_A(G_A(A))
                    # fake_A lies in class B; generator adversarial loss L_G^adv
                    fake_A = self.G(image_s, image_r, mask_s, mask_r, dist_s, dist_r)
                    pred_fake = self.D_A(fake_A)
                    g_A_loss_adv = self.criterionGAN(pred_fake, True)

                    # GAN loss D_B(G_B(B))
                    fake_B = self.G(image_r, image_s, mask_r, mask_s, dist_r, dist_s)
                    pred_fake = self.D_B(fake_B)
                    g_B_loss_adv = self.criterionGAN(pred_fake, True)
                    # self.track("Generator forward")

                    # color_histogram loss: per-region color-histogram (makeup) loss
                    g_A_loss_his = 0
                    g_B_loss_his = 0
                    g_A_lip_loss_his = self.criterionHis(
                        fake_A, image_r, mask_s[:, 0], mask_r[:, 0]) * self.lambda_his_lip
                    g_B_lip_loss_his = self.criterionHis(
                        fake_B, image_s, mask_r[:, 0], mask_s[:, 0]) * self.lambda_his_lip
                    g_A_loss_his += g_A_lip_loss_his
                    g_B_loss_his += g_B_lip_loss_his

                    g_A_skin_loss_his = self.criterionHis(
                        fake_A, image_r, mask_s[:, 1], mask_r[:, 1]) * self.lambda_his_skin
                    g_B_skin_loss_his = self.criterionHis(
                        fake_B, image_s, mask_r[:, 1], mask_s[:, 1]) * self.lambda_his_skin
                    g_A_loss_his += g_A_skin_loss_his
                    g_B_loss_his += g_B_skin_loss_his

                    g_A_eye_loss_his = self.criterionHis(
                        fake_A, image_r, mask_s[:, 2], mask_r[:, 2]) * self.lambda_his_eye
                    g_B_eye_loss_his = self.criterionHis(
                        fake_B, image_s, mask_r[:, 2], mask_s[:, 2]) * self.lambda_his_eye
                    g_A_loss_his += g_A_eye_loss_his
                    g_B_loss_his += g_B_eye_loss_his
                    # self.track("Generator histogram")

                    # cycle loss
                    # fake_A: fake_x/source
                    rec_A = self.G(fake_A, image_s, mask_s, mask_s, dist_s, dist_s)
                    rec_B = self.G(fake_B, image_r, mask_r, mask_r, dist_r, dist_r)
                    g_loss_rec_A = self.criterionL1(rec_A, image_s) * self.lambda_A
                    g_loss_rec_B = self.criterionL1(rec_B, image_r) * self.lambda_B
                    # self.track("Generator recover")

                    # vgg loss (perceptual loss)
                    vgg_s = self.vgg(image_s)
                    vgg_s = Variable(vgg_s.data).detach()
                    vgg_fake_A = self.vgg(fake_A)
                    g_loss_A_vgg = self.criterionL2(vgg_fake_A, vgg_s) * self.lambda_A * self.lambda_vgg
                    # self.track("Generator vgg")

                    vgg_r = self.vgg(image_r)
                    vgg_r = Variable(vgg_r.data).detach()
                    vgg_fake_B = self.vgg(fake_B)
                    g_loss_B_vgg = self.criterionL2(vgg_fake_B, vgg_r) * self.lambda_B * self.lambda_vgg

                    loss_rec = (g_loss_rec_A + g_loss_rec_B + g_loss_A_vgg + g_loss_B_vgg) * 0.5
                    # loss_rec = (g_loss_rec_A + g_loss_A_vgg) * 0.5
                    # Combined loss
                    g_loss = (g_A_loss_adv + g_B_loss_adv + loss_rec + loss_idt +
                              g_A_loss_his + g_B_loss_his).mean()
                    # g_loss = (g_A_loss_adv + loss_rec + loss_idt + g_A_loss_his).mean()

                    self.g_optimizer.zero_grad()
                    g_loss.backward(retain_graph=False)
                    self.g_optimizer.step()
                    # self.track("Generator backward")

                    # Logging
                    self.loss['G-A-loss-adv'] = g_A_loss_adv.mean().item()
                    self.loss['G-B-loss-adv'] = g_B_loss_adv.mean().item()
                    self.loss['G-loss-org'] = g_loss_rec_A.mean().item()
                    self.loss['G-loss-ref'] = g_loss_rec_B.mean().item()
                    self.loss['G-loss-idt'] = loss_idt.mean().item()
                    self.loss['G-loss-img-rec'] = g_loss_rec_A.mean().item()
                    self.loss['G-loss-vgg-rec'] = g_loss_A_vgg.mean().item()
                    self.loss['G-A-loss-his'] = g_A_loss_his.mean().item()

                # Print out log info
                if (self.i + 1) % self.log_step == 0:
                    self.log_terminal()

                # plot the figures
                for key_now in self.loss.keys():
                    plot_fig.plot(key_now, self.loss[key_now])

                # save the images
                if (self.i) % self.vis_step == 0:
                    print("Saving middle output...")
                    self.vis_train([image_s, image_r, fake_A, rec_A, mask_s[:, :, 0], mask_r[:, :, 0]])

                # Save model checkpoints
                if (self.i) % self.snapshot_step == 0:
                    self.save_models()

                if (self.i % 100 == 99):
                    plot_fig.flush(self.log_path)

                plot_fig.tick()

            # Decay learning rate
            if (self.e + 1) > (self.num_epochs - self.num_epochs_decay):
                g_lr -= (self.g_lr / float(self.num_epochs_decay))
                d_lr -= (self.d_lr / float(self.num_epochs_decay))
                self.update_lr(g_lr, d_lr)
                print('Decay learning rate to g_lr: {}, d_lr:{}.'.format(g_lr, d_lr))

    def update_lr(self, g_lr, d_lr):
        for param_group in self.g_optimizer.param_groups:
            param_group['lr'] = g_lr
        for param_group in self.d_A_optimizer.param_groups:
            param_group['lr'] = d_lr
        for param_group in self.d_B_optimizer.param_groups:
            param_group['lr'] = d_lr

    def save_models(self):
        if not osp.exists(self.snapshot_path):
            os.makedirs(self.snapshot_path)
        torch.save(
            self.G.state_dict(),
            os.path.join(self.snapshot_path, '{}_{}_G.pth'.format(self.e + 1, self.i + 1)))
        torch.save(
            self.D_A.state_dict(),
            os.path.join(self.snapshot_path, '{}_{}_D_A.pth'.format(self.e + 1, self.i + 1)))
        torch.save(
            self.D_B.state_dict(),
            os.path.join(self.snapshot_path, '{}_{}_D_B.pth'.format(self.e + 1, self.i + 1)))

    def vis_train(self, img_train_list):
        # saving training results
        mode = "train_vis"
        img_train_list = torch.cat(img_train_list, dim=3)
        result_path_train = osp.join(self.result_path, mode)
        if not osp.exists(result_path_train):
            os.makedirs(result_path_train)
        save_path = os.path.join(result_path_train, '{}_{}_fake.jpg'.format(self.e, self.i))
        save_image(self.de_norm(img_train_list.data), save_path, normalize=True)

    def log_terminal(self):
        elapsed = time.time() - self.start_time
        elapsed = str(datetime.timedelta(seconds=elapsed))

        log = "Elapsed [{}], Epoch [{}/{}], Iter [{}/{}]".format(
            elapsed, self.e + 1, self.num_epochs, self.i + 1, self.iters_per_epoch)

        for tag, value in self.loss.items():
            log += ", {}: {:.4f}".format(tag, value)
        print(log)

    def to_var(self, x, requires_grad=True):
        if torch.cuda.is_available():
            x = x.cuda()
        if not requires_grad:
            return Variable(x, requires_grad=requires_grad)
        else:
            return Variable(x)
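
# Illustrative usage sketch (not part of the original file): loading a trained generator and
# running one makeup transfer in inference mode. The checkpoint path "checkpoints/G.pth" and the
# preprocessing that produces real_A/real_B, mask_A/mask_B and diff_A/diff_B are assumptions;
# they must match whatever the project's data pipeline produces. config is unused when the
# inference argument is given, so None is acceptable there.
#
#   solver = Solver(config=None, device="cpu", inference="checkpoints/G.pth")
#   with torch.no_grad():
#       result = solver.test(real_A, mask_A, diff_A, real_B, mask_B, diff_B)  # returns a PIL image
#   result.save("transfer_result.png")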