class Pix2Pix:
    """Trainer for a Pix2Pix conditional GAN built on OneFlow.

    Owns a generator/discriminator pair, their Adam optimizers, the
    adversarial (BCE-with-logits) and L1 criteria, and the bookkeeping needed
    to log losses, dump sample images and write checkpoints under
    ``args.path``.
    """

    def __init__(self, args) -> None:
        """Build networks, optimizers, losses and output directories.

        Args:
            args: Object exposing ``learning_rate``, ``LAMBDA``, ``save``,
                ``batch_size``, ``path`` and ``epoch_num`` attributes.
        """
        self.lr = args.learning_rate
        self.LAMBDA = args.LAMBDA
        self.save = args.save
        self.batch_size = args.batch_size
        self.path = args.path
        self.n_epochs = args.epoch_num
        # Used both as the batch logging period and (doubled) as the epoch
        # evaluation period.
        self.eval_interval = 10

        # Per-batch loss histories, appended to during train().
        self.G_image_loss = []
        self.G_GAN_loss = []
        self.G_total_loss = []
        self.D_loss = []

        self.netG = Generator().to("cuda")
        self.netD = Discriminator().to("cuda")

        self.optimizerG = flow.optim.Adam(
            self.netG.parameters(), lr=self.lr, betas=(0.5, 0.999))
        self.optimizerD = flow.optim.Adam(
            self.netD.parameters(), lr=self.lr, betas=(0.5, 0.999))

        self.criterionGAN = flow.nn.BCEWithLogitsLoss()
        self.criterionL1 = flow.nn.L1Loss()

        self.checkpoint_path = os.path.join(self.path, "checkpoint")
        self.test_images_path = os.path.join(self.path, "test_images")
        mkdirs(self.checkpoint_path, self.test_images_path)
        self.logger = init_logger(os.path.join(self.path, "log.txt"))

    def _is_eval_epoch(self, epoch_idx):
        """Return True when sample images should be written after this epoch.

        Evaluation happens every ``2 * eval_interval`` epochs (1-based). The
        parentheses matter: ``%`` and ``*`` share precedence and associate
        left-to-right, so ``a % 2 * k`` would be ``(a % 2) * k``.
        """
        return (epoch_idx + 1) % (2 * self.eval_interval) == 0

    def train(self):
        """Run the full training loop, then optionally save weights and losses."""
        # init dataset
        x, y = load_facades()
        # flow.Tensor() bug in here
        x, y = np.ascontiguousarray(x), np.ascontiguousarray(y)
        # A fixed first batch is kept for the periodic qualitative evaluation.
        self.fixed_inp = to_tensor(x[:self.batch_size].astype(np.float32))
        self.fixed_target = to_tensor(y[:self.batch_size].astype(np.float32))

        # Integer division: a trailing partial batch is never visited.
        batch_num = len(x) // self.batch_size
        # Discriminator targets; the (1, 30, 30) map size is hard-coded here
        # and must match what netD produces.
        label1 = to_tensor(np.ones((self.batch_size, 1, 30, 30)),
                           dtype=flow.float32)
        label0 = to_tensor(np.zeros((self.batch_size, 1, 30, 30)),
                           dtype=flow.float32)

        for epoch_idx in range(self.n_epochs):
            # _eval_generator() switches netG to eval mode, so reset each epoch.
            self.netG.train()
            self.netD.train()
            start = time.time()

            # NOTE: no shuffling is performed; batches are visited in the
            # same order every epoch.
            for batch_idx in range(batch_num):
                lo = batch_idx * self.batch_size
                hi = lo + self.batch_size
                inp = to_tensor(x[lo:hi].astype(np.float32))
                target = to_tensor(y[lo:hi].astype(np.float32))

                # update D
                d_fake_loss, d_real_loss, d_loss = self.train_discriminator(
                    inp, target, label0, label1)

                # update G
                g_gan_loss, g_image_loss, g_total_loss, g_out = \
                    self.train_generator(inp, target, label1)

                self.G_GAN_loss.append(g_gan_loss)
                self.G_image_loss.append(g_image_loss)
                self.G_total_loss.append(g_total_loss)
                self.D_loss.append(d_loss)

                if (batch_idx + 1) % self.eval_interval == 0:
                    self.logger.info(
                        "{}th epoch, {}th batch, d_fakeloss:{:>8.4f}, d_realloss:{:>8.4f},  ggan_loss:{:>8.4f}, gl1_loss:{:>8.4f}"
                        .format(
                            epoch_idx + 1,
                            batch_idx + 1,
                            d_fake_loss,
                            d_real_loss,
                            g_gan_loss,
                            g_image_loss,
                        ))

            self.logger.info("Time for epoch {} is {} sec.".format(
                epoch_idx + 1, time.time() - start))

            if self._is_eval_epoch(epoch_idx):
                # save .eval() images generated from the fixed batch
                self._eval_generator_and_save_images(epoch_idx)

        if self.save:
            # After the loop, epoch_idx + 1 == n_epochs; using n_epochs also
            # keeps this well-defined when no epoch was run.
            flow.save(
                self.netG.state_dict(),
                os.path.join(self.checkpoint_path,
                             "pix2pix_g_{}".format(self.n_epochs)),
            )
            flow.save(
                self.netD.state_dict(),
                os.path.join(self.checkpoint_path,
                             "pix2pix_d_{}".format(self.n_epochs)),
            )

            # save train loss and val error to plot
            np.save(
                os.path.join(self.path,
                             "G_image_loss_{}.npy".format(self.n_epochs)),
                self.G_image_loss,
            )
            np.save(
                os.path.join(self.path,
                             "G_GAN_loss_{}.npy".format(self.n_epochs)),
                self.G_GAN_loss,
            )
            np.save(
                os.path.join(self.path,
                             "G_total_loss_{}.npy".format(self.n_epochs)),
                self.G_total_loss,
            )
            np.save(
                os.path.join(self.path,
                             "D_loss_{}.npy".format(self.n_epochs)),
                self.D_loss,
            )
            self.logger.info("*************** Train done ***************** ")

    def train_generator(self, input, target, label1):
        """Run one generator optimisation step.

        Returns:
            Tuple of ``(gan_loss, LAMBDA * l1_loss, total_loss, g_out)``, each
            converted with ``to_numpy``.
        """
        g_out = self.netG(input)

        # First, G(A) should fake the discriminator
        fake_AB = flow.cat([input, g_out], 1)
        pred_fake = self.netD(fake_AB)
        gan_loss = self.criterionGAN(pred_fake, label1)

        # Second, G(A) = B
        l1_loss = self.criterionL1(g_out, target)

        # combine loss and calculate gradients
        g_loss = gan_loss + self.LAMBDA * l1_loss
        g_loss.backward()

        self.optimizerG.step()
        self.optimizerG.zero_grad()
        return (
            to_numpy(gan_loss),
            to_numpy(self.LAMBDA * l1_loss),
            to_numpy(g_loss),
            to_numpy(g_out, False),
        )

    def train_discriminator(self, input, target, label0, label1):
        """Run one discriminator optimisation step.

        Returns:
            Tuple of ``(d_fake_loss, d_real_loss, d_loss)`` converted with
            ``to_numpy``; ``d_loss`` is the average of the two.
        """
        g_out = self.netG(input)

        # Fake; stop backprop to the generator by detaching fake_B
        fake_AB = flow.cat([input, g_out.detach()], 1)
        pred_fake = self.netD(fake_AB)
        d_fake_loss = self.criterionGAN(pred_fake, label0)

        # Real
        real_AB = flow.cat([input, target], 1)
        pred_real = self.netD(real_AB)
        d_real_loss = self.criterionGAN(pred_real, label1)

        # combine loss and calculate gradients
        d_loss = (d_fake_loss + d_real_loss) * 0.5
        d_loss.backward()
        self.optimizerD.step()
        self.optimizerD.zero_grad()
        return to_numpy(d_fake_loss), to_numpy(d_real_loss), to_numpy(d_loss)

    def _eval_generator_and_save_images(self, epoch_idx):
        """Generate from the fixed batch and write a comparison image to disk."""
        results = self._eval_generator()
        save_images(
            results,
            to_numpy(self.fixed_inp, False),
            to_numpy(self.fixed_target, False),
            path=os.path.join(self.test_images_path,
                              "testimage_{:02d}.png".format(epoch_idx + 1)),
        )

    def _eval_generator(self):
        """Run the generator on the fixed batch in eval mode without gradients.

        Leaves ``netG`` in eval mode; ``train()`` switches it back at the
        start of the next epoch.
        """
        self.netG.eval()
        with flow.no_grad():
            g_out = self.netG(self.fixed_inp)
        return to_numpy(g_out, False)
class DeformablePose_GAN(nn.Module):
    """GAN wrapper around a deformable pose-transfer generator and a discriminator.

    ``__init__`` reads its options as attributes (``opt.batch_size``), whereas
    ``gen_update`` / ``dis_update`` read theirs by key (``opt['gen_type']``);
    callers pass different option objects to each.

    The adversarial losses take ``log`` of the discriminator output directly,
    so that output is presumably a probability in (0, 1) -- TODO confirm the
    Discriminator ends in a sigmoid.
    """

    # Checkpoint used when ``opt`` does not provide ``pretrained_disc_path``.
    _DEFAULT_PRETRAINED_DISC = "/home/linlilang/pose-transfer/exp/baseline_market/models/disc_020.pkl"

    def __init__(self, opt):
        """Build generator, discriminator, optimizers and loss helpers.

        Args:
            opt: Attribute-style options. Reads ``image_size``,
                ``use_input_pose``, ``pose_dim``, ``batch_size``,
                ``num_stacks``, ``gen_type``, ``warp_skip``, ``dataset``
                (stacked only), ``learning_rate``, ``content_loss_layer`` and
                ``nn_loss_area_size``. An optional ``pretrained_disc_path``
                overrides the built-in discriminator checkpoint path.

        Raises:
            Exception: If ``opt.gen_type`` is neither ``'stacked'`` nor
                ``'baseline'``.
        """
        super(DeformablePose_GAN, self).__init__()
        # load generator and discriminator models
        # adding extra layers for larger image size
        if max(opt.image_size) < 256:
            nfilters_decoder = (512, 512, 512, 256, 128, 3)
            nfilters_encoder = (64, 128, 256, 512, 512, 512)
        else:
            nfilters_decoder = (512, 512, 512, 512, 256, 128, 3)
            nfilters_encoder = (64, 128, 256, 512, 512, 512, 512)

        if opt.use_input_pose:
            input_nc = 3 + 2 * opt.pose_dim
        else:
            input_nc = 3 + opt.pose_dim

        self.batch_size = opt.batch_size
        self.num_stacks = opt.num_stacks
        self.pose_dim = opt.pose_dim

        if opt.gen_type == 'stacked':
            self.gen = Stacked_Generator(input_nc,
                                         opt.num_stacks,
                                         opt.image_size,
                                         opt.pose_dim,
                                         nfilters_encoder,
                                         nfilters_decoder,
                                         opt.warp_skip,
                                         use_input_pose=opt.use_input_pose)
            # hack to get better results
            pretrained_gen_path = '../exp/' + 'full_' + opt.dataset + '/models/gen_090.pkl'
            self.gen.generator.load_state_dict(torch.load(pretrained_gen_path))
            print("Loaded generator from pretrained model ")
        elif opt.gen_type == 'baseline':
            self.gen = Deformable_Generator(input_nc,
                                            self.pose_dim,
                                            opt.image_size,
                                            nfilters_encoder,
                                            nfilters_decoder,
                                            opt.warp_skip,
                                            use_input_pose=opt.use_input_pose)
        else:
            raise Exception('Invalid gen_type')

        # discriminator also sees the output image for the target pose
        self.disc = Discriminator(input_nc + 3,
                                  use_input_pose=opt.use_input_pose)
        pretrained_disc_path = getattr(opt, 'pretrained_disc_path',
                                       self._DEFAULT_PRETRAINED_DISC)
        self.disc.load_state_dict(torch.load(pretrained_disc_path))
        # Report success only once the weights are actually loaded.
        print("Loaded discriminator from pretrained model ")

        print('---------- Networks initialized -------------')
        print('-----------------------------------------------')

        # Setup the optimizers
        lr = opt.learning_rate
        self.disc_opt = torch.optim.Adam(self.disc.parameters(),
                                         lr=lr,
                                         betas=(0.5, 0.999))
        self.gen_opt = torch.optim.Adam(self.gen.parameters(),
                                        lr=lr,
                                        betas=(0.5, 0.999))

        self.content_loss_layer = opt.content_loss_layer
        self.nn_loss_area_size = opt.nn_loss_area_size
        if self.content_loss_layer != 'none':
            self.content_model = resnet101(pretrained=True)

        # Xavier initialisation is skipped because both models are pretrained.
        self.gen.cuda()
        self.disc.cuda()

        self._nn_loss_area_size = opt.nn_loss_area_size

        # Setup the loss function for training
        self.ll_loss_criterion = torch.nn.L1Loss()

    @staticmethod
    def _summed_log_loss(scores, eps=1e-7):
        """Return the sum over samples of the per-sample mean of ``-log(scores + eps)``.

        ``scores`` must have at least two dimensions, the first being the
        sample axis. An empty batch yields zero.
        """
        per_sample = -torch.log(scores + eps).flatten(start_dim=1).mean(dim=1)
        return per_sample.sum()

    def _run_generator(self, input, other_inputs, opt):
        """Run the generator and return ``(out_gen, outputs_gen, supervised)``.

        ``supervised`` lists every generator output that exists for the
        current ``gen_type`` and should receive adversarial / L1 supervision:
        the final stack output for ``'stacked'``, all three outputs otherwise.
        """
        if opt['gen_type'] == 'stacked':
            outputs_gen = self.gen(input, other_inputs['interpol_pose'],
                                   other_inputs['interpol_warps'],
                                   other_inputs['interpol_masks'])
            out_gen = outputs_gen[-1]
            return out_gen, outputs_gen, [out_gen]
        out_gen, out_gen_2, out_gen_3 = self.gen(input, other_inputs['warps'],
                                                 other_inputs['masks'])
        return out_gen, [], [out_gen, out_gen_2, out_gen_3]

    # add code for intermediate supervision for the interpolated poses using pretrained pose-estimator
    def gen_update(self, input, target, other_inputs, opt):
        """Run one generator optimisation step.

        Args:
            input: Generator input batch, also split by
                ``pose_utils.get_imgpose`` into image / poses.
            target: Ground-truth image batch.
            other_inputs: Mapping with ``interpol_pose`` / ``interpol_warps`` /
                ``interpol_masks`` (stacked) or ``warps`` / ``masks``.
            opt: Mapping with ``gen_type``, ``use_input_pose``, ``pose_dim``,
                ``gan_penalty_weight`` and ``l1_penalty_weight``.

        Returns:
            ``(out_gen, outputs_gen, [total_loss, ll_loss, ad_loss])`` where
            the losses are Python floats and ``outputs_gen`` is the per-stack
            output list (empty for the non-stacked generator).
        """
        self.gen.zero_grad()

        out_gen, outputs_gen, supervised = self._run_generator(
            input, other_inputs, opt)

        inp_img, inp_pose, out_pose = pose_utils.get_imgpose(
            input, opt['use_input_pose'], opt['pose_dim'])

        # computing adversarial loss: every supervised output should be
        # scored as real by the discriminator.
        disc_scores = [
            self.disc(torch.cat([inp_img, inp_pose, fake, out_pose], dim=1))
            for fake in supervised
        ]
        ad_loss = 0.0
        for scores in disc_scores:
            ad_loss = ad_loss + self._summed_log_loss(scores)

        if self.content_loss_layer != 'none':
            content_out_gen = pose_utils.Feature_Extractor(
                self.content_model,
                input=out_gen,
                layer_name=self.content_loss_layer)
            content_target = pose_utils.Feature_Extractor(
                self.content_model,
                input=target,
                layer_name=self.content_loss_layer)
            ll_loss = self.nn_loss(content_out_gen, content_target,
                                   self.nn_loss_area_size,
                                   self.nn_loss_area_size)
        else:
            ll_loss = self.ll_loss_criterion(out_gen, target)
        # NOTE(review): pixel L1 is added for every supervised output on top
        # of the term above, so without a content loss ``out_gen`` is counted
        # twice. This mirrors the existing behaviour -- confirm it is intended.
        for fake in supervised:
            ll_loss = ll_loss + self.ll_loss_criterion(fake, target)

        ad_loss = ad_loss * opt['gan_penalty_weight'] / self.batch_size
        ll_loss = ll_loss * opt['l1_penalty_weight']
        total_loss = ad_loss + ll_loss
        total_loss.backward()
        self.gen_opt.step()

        self.gen_ll_loss = ll_loss.item()
        self.gen_ad_loss = ad_loss.item()
        self.gen_total_loss = total_loss.item()
        return out_gen, outputs_gen, [
            self.gen_total_loss, self.gen_ll_loss, self.gen_ad_loss
        ]

    def dis_update(self, input, target, other_inputs, real_inp, real_target,
                   opt):
        """Run one discriminator optimisation step.

        Each supervised generator output is concatenated after the real batch
        and scored in a single discriminator pass; the real half is pushed
        towards 1 and the fake half towards 0.

        Args:
            input: Generator input batch used to produce the fakes.
            target: Unused; kept for interface compatibility.
            other_inputs: Same mapping as in ``gen_update``.
            real_inp: Input batch paired with ``real_target``.
            real_target: Real image batch.
            opt: Mapping with ``gen_type``, ``use_input_pose``, ``pose_dim``
                and ``gan_penalty_weight``.

        Returns:
            ``[total_loss, true_loss, fake_loss]`` as Python floats.
        """
        self.disc.zero_grad()

        # The generator is not trained here, so skip building its graph.
        with torch.no_grad():
            _, _, supervised = self._run_generator(input, other_inputs, opt)

        inp_img, inp_pose, out_pose = pose_utils.get_imgpose(
            input, opt['use_input_pose'], opt['pose_dim'])
        r_inp_img, r_inp_pose, r_out_pose = pose_utils.get_imgpose(
            real_inp, opt['use_input_pose'], opt['pose_dim'])
        real_disc_inp = torch.cat(
            [r_inp_img, r_inp_pose, real_target, r_out_pose], dim=1)
        # Split on the actual number of real samples so a short batch does
        # not leak fakes into the "real" half.
        n_real = real_disc_inp.shape[0]

        all_scores = []
        for fake in supervised:
            fake_disc_inp = torch.cat([inp_img, inp_pose, fake, out_pose],
                                      dim=1)
            data_dis = torch.cat((real_disc_inp, fake_disc_inp), 0)
            all_scores.append(self.disc(data_dis))

        ad_true_loss = 0.0
        ad_fake_loss = 0.0
        for res_dis in all_scores:
            # real inputs should be 1
            ad_true_loss = ad_true_loss + self._summed_log_loss(
                res_dis[:n_real])
            # fake inputs should be 0
            ad_fake_loss = ad_fake_loss + self._summed_log_loss(
                1 - res_dis[n_real:])

        ad_true_loss = ad_true_loss * opt[
            'gan_penalty_weight'] / self.batch_size
        ad_fake_loss = ad_fake_loss * opt[
            'gan_penalty_weight'] / self.batch_size
        loss = ad_true_loss + ad_fake_loss
        loss.backward()
        self.disc_opt.step()

        self.dis_total_loss = loss.item()
        self.dis_true_loss = ad_true_loss.item()
        self.dis_fake_loss = ad_fake_loss.item()
        return [self.dis_total_loss, self.dis_true_loss, self.dis_fake_loss]

    def nn_loss(self, predicted, ground_truth, nh=3, nw=3):
        """Nearest-neighbour L1 loss over an ``nh`` x ``nw`` window.

        For every spatial position the channel-summed absolute difference
        between ``predicted`` and each ground-truth value in the surrounding
        window is computed, the minimum over the window is kept, and the
        result is averaged. Out-of-image neighbours are padded with -10000 so
        they are never the minimum for ordinary feature magnitudes.

        Args:
            predicted: 4-D tensor; indexed as (batch, channel, row, col).
            ground_truth: Tensor of the same shape.
            nh: Window height. Only odd sizes yield matching shapes.
            nw: Window width. Only odd sizes yield matching shapes.

        Returns:
            Scalar tensor with the mean loss.
        """
        v_pad = nh // 2
        h_pad = nw // 2
        # ConstantPad2d expects (left, right, top, bottom): horizontal first.
        val_pad = nn.ConstantPad2d((h_pad, h_pad, v_pad, v_pad),
                                   -10000)(ground_truth)

        reference_tensors = []
        for i_begin in range(nh):
            # A stop of 0 would produce an empty slice, so use None instead.
            i_end = i_begin - nh + 1
            i_end = None if i_end == 0 else i_end
            for j_begin in range(nw):
                j_end = j_begin - nw + 1
                j_end = None if j_end == 0 else j_end
                sub_tensor = val_pad[:, :, i_begin:i_end, j_begin:j_end]
                reference_tensors.append(sub_tensor.unsqueeze(-1))
        reference = torch.cat(reference_tensors, dim=-1)

        abs_diff = torch.abs(reference - predicted.unsqueeze(-1))
        # sum along channels
        norms = torch.sum(abs_diff, dim=1)
        # min over neighbourhood
        loss, _ = torch.min(norms, dim=-1)
        return torch.mean(loss)

    def resume(self, save_dir):
        """Load the latest generator and discriminator checkpoints.

        Returns:
            The epoch parsed from the discriminator checkpoint name (the three
            characters before the 4-character extension), or ``1`` if either
            checkpoint is missing.
        """
        last_model_name = pose_utils.get_model_list(save_dir, "gen")
        if last_model_name is None:
            return 1
        self.gen.load_state_dict(torch.load(last_model_name))
        epoch = int(last_model_name[-7:-4])
        print('Resume gen from epoch %d' % epoch)

        last_model_name = pose_utils.get_model_list(save_dir, "dis")
        if last_model_name is None:
            return 1
        epoch = int(last_model_name[-7:-4])
        self.disc.load_state_dict(torch.load(last_model_name))
        print('Resume disc from epoch %d' % epoch)
        return epoch

    def save(self, save_dir, epoch):
        """Write generator and discriminator state dicts for ``epoch``."""
        gen_filename = os.path.join(save_dir, 'gen_{0:03d}.pkl'.format(epoch))
        disc_filename = os.path.join(save_dir,
                                     'disc_{0:03d}.pkl'.format(epoch))
        torch.save(self.gen.state_dict(), gen_filename)
        torch.save(self.disc.state_dict(), disc_filename)

    def normalize_image(self, x):
        """Return the first three channels of ``x``."""
        return x[:, 0:3, :, :]
class Pose_GAN(nn.Module):
    """GAN wrapper around a pose-transfer generator and a discriminator.

    ``__init__`` reads its options as attributes (``opt.batch_size``), whereas
    ``gen_update`` / ``dis_update`` read theirs by key (``opt['gen_type']``);
    callers pass different option objects to each.

    The adversarial losses take ``log`` of the discriminator output directly,
    so that output is presumably a probability in (0, 1) -- TODO confirm the
    Discriminator ends in a sigmoid.
    """

    def __init__(self, opt):
        """Build generator, discriminator, optimizers and the L1 criterion.

        Args:
            opt: Attribute-style options. Reads ``checkMode``, ``image_size``,
                ``use_input_pose``, ``pose_dim``, ``num_stacks``,
                ``batch_size``, ``gen_type`` and ``learning_rate``.

        Raises:
            Exception: If ``opt.gen_type`` is neither ``'stacked'`` nor
                ``'baseline'``.
        """
        super(Pose_GAN, self).__init__()
        # load generator and discriminator models
        # adding extra layers for larger image size
        small_image = max(opt.image_size) < 256
        if opt.checkMode == 0:
            if small_image:
                nfilters_decoder = (512, 512, 512, 256, 128, 3)
                nfilters_encoder = (64, 128, 256, 512, 512, 512)
            else:
                nfilters_decoder = (512, 512, 512, 512, 256, 128, 3)
                nfilters_encoder = (64, 128, 256, 512, 512, 512, 512)
        else:
            # Any non-zero checkMode selects a much shallower network.
            if small_image:
                nfilters_decoder = (128, 3)
                nfilters_encoder = (64, 128)
            else:
                nfilters_decoder = (256, 128, 3)
                nfilters_encoder = (64, 128, 256)

        if opt.use_input_pose:
            input_nc = 3 + 2 * opt.pose_dim
        else:
            input_nc = 3 + opt.pose_dim

        self.num_stacks = opt.num_stacks
        self.batch_size = opt.batch_size
        self.pose_dim = opt.pose_dim

        if opt.gen_type == 'stacked':
            self.gen = Stacked_Generator(input_nc, opt.num_stacks,
                                         opt.pose_dim, nfilters_encoder,
                                         nfilters_decoder,
                                         use_input_pose=opt.use_input_pose)
        elif opt.gen_type == 'baseline':
            self.gen = Generator(input_nc, nfilters_encoder, nfilters_decoder,
                                 use_input_pose=opt.use_input_pose)
        else:
            raise Exception('Invalid gen_type')

        # discriminator also sees the output image for the target pose
        self.disc = Discriminator(input_nc + 3,
                                  use_input_pose=opt.use_input_pose,
                                  checkMode=opt.checkMode)
        print('---------- Networks initialized -------------')
        print_network(self.gen)
        print_network(self.disc)
        print('-----------------------------------------------')

        # Setup the optimizers
        lr = opt.learning_rate
        self.disc_opt = torch.optim.Adam(self.disc.parameters(), lr=lr,
                                         betas=(0.5, 0.999))
        self.gen_opt = torch.optim.Adam(self.gen.parameters(), lr=lr,
                                        betas=(0.5, 0.999))

        # Network weight initialization
        self.gen.cuda()
        self.disc.cuda()
        self.disc.apply(xavier_weights_init)
        self.gen.apply(xavier_weights_init)

        # Setup the loss function for training
        self.ll_loss_criterion = torch.nn.L1Loss()

    @staticmethod
    def _summed_log_loss(scores, eps=1e-7):
        """Return the sum over samples of the per-sample mean of ``-log(scores + eps)``.

        ``scores`` must have at least two dimensions, the first being the
        sample axis. An empty batch yields zero.
        """
        per_sample = -torch.log(scores + eps).flatten(start_dim=1).mean(dim=1)
        return per_sample.sum()

    # add code for intermediate supervision for the interpolated poses using pretrained pose-estimator
    def gen_update(self, input, target, interpol_pose, opt):
        """Run one generator optimisation step.

        Args:
            input: Generator input batch, also split by
                ``pose_utils.get_imgpose`` into image / poses.
            target: Ground-truth image batch for the L1 term.
            interpol_pose: Extra generator argument, used only when
                ``opt['gen_type'] == 'stacked'``.
            opt: Mapping with ``gen_type``, ``use_input_pose``, ``pose_dim``,
                ``gan_penalty_weight`` and ``l1_penalty_weight``.

        Returns:
            ``(out_gen, outputs_gen, [total_loss, ll_loss, ad_loss])`` where
            the losses are Python floats and ``outputs_gen`` is the per-stack
            output list (empty for the baseline generator).
        """
        self.gen.zero_grad()

        if opt['gen_type'] == 'stacked':
            outputs_gen = self.gen(input, interpol_pose)
            out_gen = outputs_gen[-1]
        else:
            out_gen = self.gen(input)
            outputs_gen = []

        inp_img, inp_pose, out_pose = pose_utils.get_imgpose(
            input, opt['use_input_pose'], opt['pose_dim'])
        inp_dis = torch.cat([inp_img, inp_pose, out_gen, out_pose], dim=1)
        out_dis = self.disc(inp_dis)

        # computing adversarial loss: the generated batch should be scored
        # as real by the discriminator.
        ad_loss = self._summed_log_loss(out_dis)
        ll_loss = self.ll_loss_criterion(out_gen, target)

        ad_loss = ad_loss * opt['gan_penalty_weight'] / self.batch_size
        ll_loss = ll_loss * opt['l1_penalty_weight']
        total_loss = ad_loss + ll_loss
        total_loss.backward()
        self.gen_opt.step()

        self.gen_ll_loss = ll_loss.item()
        self.gen_ad_loss = ad_loss.item()
        self.gen_total_loss = total_loss.item()
        return out_gen, outputs_gen, [
            self.gen_total_loss, self.gen_ll_loss, self.gen_ad_loss
        ]

    def dis_update(self, input, target, interpol_pose, real_inp, real_target,
                   opt):
        """Run one discriminator optimisation step.

        The real batch and the generated batch are concatenated (real first)
        and scored in a single discriminator pass; the real half is pushed
        towards 1 and the fake half towards 0.

        Args:
            input: Generator input batch used to produce the fakes.
            target: Unused; kept for interface compatibility.
            interpol_pose: Extra generator argument for the stacked generator.
            real_inp: Input batch paired with ``real_target``.
            real_target: Real image batch.
            opt: Mapping with ``gen_type``, ``use_input_pose``, ``pose_dim``
                and ``gan_penalty_weight``.

        Returns:
            ``[total_loss, true_loss, fake_loss]`` as Python floats.
        """
        self.disc.zero_grad()

        # The generator is not trained here, so skip building its graph.
        with torch.no_grad():
            if opt['gen_type'] == 'stacked':
                out_gen = self.gen(input, interpol_pose)[-1]
            else:
                out_gen = self.gen(input)

        inp_img, inp_pose, out_pose = pose_utils.get_imgpose(
            input, opt['use_input_pose'], opt['pose_dim'])
        fake_disc_inp = torch.cat([inp_img, inp_pose, out_gen, out_pose],
                                  dim=1)
        r_inp_img, r_inp_pose, r_out_pose = pose_utils.get_imgpose(
            real_inp, opt['use_input_pose'], opt['pose_dim'])
        real_disc_inp = torch.cat(
            [r_inp_img, r_inp_pose, real_target, r_out_pose], dim=1)
        # Split on the actual number of real samples so a short batch does
        # not leak fakes into the "real" half.
        n_real = real_disc_inp.shape[0]

        data_dis = torch.cat((real_disc_inp, fake_disc_inp), 0)
        res_dis = self.disc(data_dis)

        # real inputs should be 1
        ad_true_loss = self._summed_log_loss(res_dis[:n_real])
        # fake inputs should be 0
        ad_fake_loss = self._summed_log_loss(1 - res_dis[n_real:])

        ad_true_loss = ad_true_loss * opt['gan_penalty_weight'] / self.batch_size
        ad_fake_loss = ad_fake_loss * opt['gan_penalty_weight'] / self.batch_size
        loss = ad_true_loss + ad_fake_loss
        loss.backward()
        self.disc_opt.step()

        self.dis_total_loss = loss.item()
        self.dis_true_loss = ad_true_loss.item()
        self.dis_fake_loss = ad_fake_loss.item()
        return [self.dis_total_loss, self.dis_true_loss, self.dis_fake_loss]

    def resume(self, save_dir):
        """Load the latest generator and discriminator checkpoints.

        Returns:
            The epoch parsed from the discriminator checkpoint name (the three
            characters before the 4-character extension), or ``1`` if either
            checkpoint is missing.
        """
        last_model_name = pose_utils.get_model_list(save_dir, "gen")
        if last_model_name is None:
            return 1
        self.gen.load_state_dict(torch.load(last_model_name))
        epoch = int(last_model_name[-7:-4])
        print('Resume gen from epoch %d' % epoch)

        last_model_name = pose_utils.get_model_list(save_dir, "dis")
        if last_model_name is None:
            return 1
        epoch = int(last_model_name[-7:-4])
        self.disc.load_state_dict(torch.load(last_model_name))
        print('Resume disc from epoch %d' % epoch)
        return epoch

    def save(self, save_dir, epoch):
        """Write generator and discriminator state dicts for ``epoch``."""
        gen_filename = os.path.join(save_dir, 'gen_{0:03d}.pkl'.format(epoch))
        disc_filename = os.path.join(save_dir,
                                     'disc_{0:03d}.pkl'.format(epoch))
        torch.save(self.gen.state_dict(), gen_filename)
        torch.save(self.disc.state_dict(), disc_filename)

    def normalize_image(self, x):
        """Return the first three channels of ``x``."""
        return x[:, 0:3, :, :]
summary.add_scalar(f'loss G/loss Overall', loss_overall.data.cpu().numpy(), iter_count) etime = time.time() - stime rtime = etime * (total_epoch_iter - iter_count) / (iter_count + eps) print( f'Epoch: {epoch+1:03d}/{num_epochs:03d}, Iter: {i+1:04d}/{total_iter:04d}, ', end='') print(f'Loss G: {loss_overall.data:.4f}, Loss D: {loss_D.data:.4f}, ', end='') print(f'Elapsed: {sec2time(etime)}, Remaining: {sec2time(rtime)}') if (i + 1) % 10 == 0: summary.add_image(f'image/sr_image', sr[0], iter_count) summary.add_image(f'image/lr_image', lr[0], iter_count) summary.add_image(f'image/hr_image', hr[0], iter_count) torch.save( G.state_dict(), f'./models/weights/G_epoch_{epoch+1}_loss_{loss_overall.data:.4f}.pth') torch.save( D.state_dict(), f'./models/weights/D_epoch_{epoch+1}_loss_{loss_D.data:.4f}.pth') if (epoch + 1) % 10 == 0: learning_rateG *= 0.5 learning_rateD *= 0.5 update_lr(optimizerG, learning_rateG) update_lr(optimizerD, learning_rateD)