import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as scheduler
from torchvision.utils import save_image

# define_G, define_D, define_GAN_loss and ImagePool are assumed to come from
# the project's networks/utility modules.


class CycleGANModel(nn.Module):

    def __init__(self, num_iter=100, num_iter_decay=100, lambda_A=10,
                 lambda_B=10, lambda_identity=0.5):
        super(CycleGANModel, self).__init__()
        self.name = None
        self.epoch_count = 1
        self.num_iter = num_iter
        self.num_iter_decay = num_iter_decay
        self.lambda_A = lambda_A
        self.lambda_B = lambda_B
        self.lambda_identity = lambda_identity

        # Two generators (A->B and B->A) and two discriminators.
        self.netG_A = define_G(num_res_blocks=9)
        self.netG_B = define_G(num_res_blocks=9)
        self.netD_A = define_D()
        self.netD_B = define_D()

        # History buffers of generated images used to stabilize the
        # discriminator updates.
        self.fake_A_pool = ImagePool(pool_size=50)
        self.fake_B_pool = ImagePool(pool_size=50)

        self.criterionGAN = define_GAN_loss()
        self.criterionCycle = torch.nn.L1Loss()
        self.criterionIdt = torch.nn.L1Loss()

        self.optimizer_G_A = optim.Adam(self.netG_A.parameters(),
                                        lr=0.0002, betas=(0.5, 0.999))
        self.optimizer_G_B = optim.Adam(self.netG_B.parameters(),
                                        lr=0.0002, betas=(0.5, 0.999))
        self.optimizer_D_A = optim.Adam(self.netD_A.parameters(),
                                        lr=0.0002, betas=(0.5, 0.999))
        self.optimizer_D_B = optim.Adam(self.netD_B.parameters(),
                                        lr=0.0002, betas=(0.5, 0.999))

        # Keep the learning rate constant for num_iter epochs, then decay it
        # linearly to zero over num_iter_decay epochs.
        lambda_rule = lambda epoch: 1.0 - max(
            0, epoch + self.epoch_count - self.num_iter) / float(
                self.num_iter_decay + 1)
        self.scheduler_G_A = scheduler.LambdaLR(self.optimizer_G_A, lr_lambda=lambda_rule)
        self.scheduler_G_B = scheduler.LambdaLR(self.optimizer_G_B, lr_lambda=lambda_rule)
        self.scheduler_D_A = scheduler.LambdaLR(self.optimizer_D_A, lr_lambda=lambda_rule)
        self.scheduler_D_B = scheduler.LambdaLR(self.optimizer_D_B, lr_lambda=lambda_rule)

    def set_input(self, batch_A, batch_B):
        self.real_A = batch_A
        self.real_B = batch_B

    def forward(self):
        self.fake_B = self.netG_A(self.real_A)  # G_A(A)
        self.rec_A = self.netG_B(self.fake_B)   # G_B(G_A(A))
        self.fake_A = self.netG_B(self.real_B)  # G_B(B)
        self.rec_B = self.netG_A(self.fake_A)   # G_A(G_B(B))

    def save_images(self, iter_count, batch_size):
        # Assumption: tag the results directory with the model name (the
        # original referenced an undefined global `model_num` here).
        path = "./datasets/night2day/test_results/test_results_" + str(self.name) + "/"
        for i in range(batch_size):
            img_num = iter_count * batch_size + i
            fake_A_numpy = self.fake_A[i].data.cpu().numpy()
            real_A_numpy = self.real_A[i].data.cpu().numpy()
            rec_A_numpy = self.rec_A[i].data.cpu().numpy()
            fake_B_numpy = self.fake_B[i].data.cpu().numpy()
            real_B_numpy = self.real_B[i].data.cpu().numpy()
            rec_B_numpy = self.rec_B[i].data.cpu().numpy()
            # Concatenate along axis 2 (width) so the six images form one strip.
            image = np.concatenate((fake_A_numpy, real_A_numpy, rec_A_numpy,
                                    fake_B_numpy, real_B_numpy, rec_B_numpy), 2)
            # Images are in [-1, 1]; map back to [0, 1] before saving.
            save_image(torch.from_numpy(image).squeeze() / 2 + 0.5,
                       path + self.name + "_" + str(img_num) + '.png',
                       nrow=batch_size)

    def backward_D_basic(self, netD, real, fake):
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Detach so the generator receives no gradient from the D update.
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)

    def backward_D_B(self):
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)

    def backward_G(self):
        lambda_idt = self.lambda_identity
        lambda_A = self.lambda_A
        lambda_B = self.lambda_B
        if lambda_idt > 0:
            # Identity loss: each generator should be near-identity on images
            # that already belong to its target domain.
            self.idt_A = self.netG_A(self.real_B)
            self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
            self.idt_B = self.netG_B(self.real_A)
            self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0
        self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)
        self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
        self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
        self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
        self.loss_G = (self.loss_G_A + self.loss_G_B + self.loss_cycle_A +
                       self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B)
        self.loss_G.backward()

    def set_requires_grad(self, nets, requires_grad=False):
        for net in nets:
            for param in net.parameters():
                param.requires_grad = requires_grad

    def optimize_parameters(self):
        self.forward()
        # Update the generators first; freeze the discriminators meanwhile.
        self.set_requires_grad([self.netD_A, self.netD_B], False)
        self.optimizer_G_A.zero_grad()
        self.optimizer_G_B.zero_grad()
        self.backward_G()
        self.optimizer_G_A.step()
        self.optimizer_G_B.step()
        # Then update the discriminators.
        self.set_requires_grad([self.netD_A, self.netD_B], True)
        self.optimizer_D_A.zero_grad()
        self.optimizer_D_B.zero_grad()
        self.backward_D_A()
        self.backward_D_B()
        self.optimizer_D_A.step()
        self.optimizer_D_B.step()
        self.loss_D = self.loss_D_A + self.loss_D_B

    def update_learning_rates(self):
        self.scheduler_G_A.step()
        self.scheduler_G_B.step()
        self.scheduler_D_A.step()
        self.scheduler_D_B.step()

    def get_current_losses(self):
        return self.loss_G.item(), self.loss_D.item()
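# Every model in this file relies on an ImagePool helper. For reference, a
# minimal sketch of the generated-image history buffer used by CycleGAN
# (Shrivastava et al., 2017); the project's own ImagePool may differ in
# details, so this is illustrative only.
import random


class ReferenceImagePool:

    def __init__(self, pool_size):
        self.pool_size = pool_size
        self.images = []  # previously generated images

    def query(self, images):
        if self.pool_size == 0:
            return images  # pool disabled: pass images straight through
        out = []
        for image in images:
            image = image.unsqueeze(0)
            if len(self.images) < self.pool_size:
                # Fill the pool first, returning the current image.
                self.images.append(image)
                out.append(image)
            elif random.random() > 0.5:
                # With probability 0.5, return an older image and store the
                # current one in its place; this keeps the discriminator
                # trained on a history of fakes, not just the latest ones.
                idx = random.randrange(self.pool_size)
                out.append(self.images[idx].clone())
                self.images[idx] = image
            else:
                out.append(image)
        return torch.cat(out, 0)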
import os
from collections import OrderedDict

import torch

# uNet, ImageDiscriminatorConv, GANLoss, BinarySelectiveCrossEntropyLoss and
# ImagePool are assumed to come from the project's networks/utility modules.


class Network(torch.nn.Module):

    def __init__(self, n_input_channels=3, n_output_channels=1, n_blocks=9,
                 initial_filters=64, dropout_value=0.25, lr=1e-3, decay=0,
                 decay_epochs=0, batch_size=1, image_width=640,
                 image_height=640, load_network=False, load_epoch=0,
                 model_path='', name='', gpu_ids=[], gan=False, pool_size=50,
                 lambda_gan=1, n_blocks_discr=3):
        super(Network, self).__init__()
        self.input_nc = n_input_channels
        self.output_nc = n_output_channels
        self.n_blocks = n_blocks
        self.initial_filters = initial_filters
        self.dropout_value = dropout_value
        self.lr = lr
        self.gpu_ids = gpu_ids
        self.batch_size = batch_size
        self.image_width = image_width
        self.image_height = image_height
        self.generator = torch.nn.Module()      # placeholder, set in initialize()
        self.discriminator = torch.nn.Module()  # placeholder, set in initialize()
        self.decay = decay
        self.decay_epochs = decay_epochs
        self.save_dir = model_path
        os.makedirs(self.save_dir, exist_ok=True)
        self.input_img = None
        self.input_gt = None
        self.var_img = None
        self.var_gt = None
        self.fake_mask = None
        self.dont_care_mask = None
        self.criterion_seg = None
        self.criterion_gan = None
        self.optimizer_seg = None
        self.optimizer_dis = None
        self.fake_mask_pool = None
        self.loss = None
        self.loss_seg = None
        self.loss_g = None
        self.loss_g_gan = None
        self.loss_d_gan = None
        self.gan = gan
        self.pool_size = pool_size
        self.lambda_gan = lambda_gan
        self.n_blocks_discr = n_blocks_discr
        self.load_network = load_network
        self.name = name
        self.load_epoch = load_epoch
        if len(gpu_ids):
            self.tensor = torch.cuda.FloatTensor
        else:
            self.tensor = torch.FloatTensor
        self.initialize(n_input_channels, n_output_channels, n_blocks,
                        initial_filters, dropout_value, lr, batch_size,
                        image_width, image_height, gpu_ids, gan, pool_size,
                        n_blocks_discr)

    def cuda(self):
        self.generator.cuda()
        if self.gan:
            self.discriminator.cuda()

    def initialize(self, n_input_channels, n_output_channels, n_blocks,
                   initial_filters, dropout_value, lr, batch_size,
                   image_width, image_height, gpu_ids, gan, pool_size,
                   n_blocks_discr):
        # Pre-allocated input buffers; set_input() copies each batch into them.
        self.input_img = self.tensor(batch_size, n_input_channels,
                                     image_height, image_width)
        self.input_gt = self.tensor(batch_size, n_output_channels,
                                    image_height, image_width)
        self.generator = uNet(n_input_channels, n_output_channels, n_blocks,
                              initial_filters, dropout_value, gpu_ids)
        if gan:
            self.discriminator = ImageDiscriminatorConv(
                n_output_channels, initial_filters, dropout_value,
                gpu_ids=gpu_ids, n_blocks=n_blocks_discr)
            self.criterion_gan = GANLoss(tensor=self.tensor)
            self.optimizer_dis = torch.optim.Adam(
                self.discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
            self.fake_mask_pool = ImagePool(pool_size)
        if self.load_network:
            self._load_network(self.generator, 'Model', self.load_epoch)
            if gan:
                self._load_network(self.discriminator, 'Discriminator',
                                   self.load_epoch)
        self.criterion_seg = BinarySelectiveCrossEntropyLoss()
        self.optimizer_seg = torch.optim.Adam(self.generator.parameters(),
                                              lr=lr, betas=(0.5, 0.999))
        print('---------- Network initialized -------------')
        self.print_network(self.generator)
        if gan:
            self.print_network(self.discriminator)
        print('-----------------------------------------------')

    def set_input(self, input_img, input_gt=None):
        if input_img is not None:
            self.input_img.resize_(input_img.size()).copy_(input_img)
        if input_gt is not None:
            self.input_gt.resize_(input_gt.size()).copy_(input_gt)

    def forward(self, vol=False):
        """
        Expose the current inputs for back-propagation.

        :param vol: kept for API compatibility; gradient-free inference is
            handled with torch.no_grad() in predict()
        """
        self.var_img = self.input_img
        self.var_gt = self.input_gt

    def predict(self):
        """
        Predict a segmentation mask for the current input image.

        :return: fake_mask: the generated mask for the current input
        """
        assert self.input_img is not None
        with torch.no_grad():
            self.fake_mask = self.generator(self.input_img)
        return self.fake_mask

    def backward_seg(self):
        self.fake_mask = self.generator(self.var_img)
        self.loss_seg = self.criterion_seg(self.fake_mask, self.var_gt)
        self.loss_g = self.loss_seg
        if self.gan:
            pred_fake = self.discriminator(self.fake_mask)
            self.loss_g_gan = self.criterion_gan(pred_fake, True)
            self.loss_g = self.loss_seg + self.loss_g_gan * self.lambda_gan
        self.loss_g.backward()

    def backward_d(self):
        fake_mask = self.fake_mask_pool.query(self.fake_mask)
        pred_real = self.discriminator(self.var_gt)
        loss_d_real = self.criterion_gan(input_tensor=pred_real,
                                         target_is_real=True)
        # Detach so no gradient flows back into the generator.
        pred_fake = self.discriminator(fake_mask.detach())
        loss_d_fake = self.criterion_gan(input_tensor=pred_fake,
                                         target_is_real=False)
        loss_d = (loss_d_real + loss_d_fake) * 0.5
        loss_d.backward()
        self.loss_d_gan = loss_d

    def optimize(self):
        """
        Run one optimization step for the generator (and the discriminator
        when GAN training is enabled).

        :return: None
        """
        self.forward()
        self.optimizer_seg.zero_grad()
        self.backward_seg()
        self.optimizer_seg.step()
        if self.gan:
            self.optimizer_dis.zero_grad()
            self.backward_d()
            self.optimizer_dis.step()

    def get_current_errors(self):
        """
        Get access to the current losses from outside the class.

        :return: OrderedDict mapping loss labels to values
        """
        errors = [self.loss_seg.item()]
        labels = ["Seg"]
        if self.gan:
            errors.append(self.loss_d_gan.item())
            errors.append(self.loss_g_gan.item())
            errors.append(self.loss_g.item())
            labels.append("Discr")
            labels.append("Seg_GAN")
            labels.append("Seg_total")
        tuple_list = list(zip(labels, errors))
        return OrderedDict(tuple_list)

    def save(self, label):
        """
        Save the subnets.

        :param label: label (part of the file name the subnet is saved to)
        :return: None
        """
        self._save_network(self.generator, 'Model', label, self.gpu_ids)
        if self.gan:
            self._save_network(self.discriminator, 'Discriminator', label,
                               self.gpu_ids)

    def _save_network(self, network, network_label, epoch_label, gpu_ids):
        """
        Helper function for saving pytorch networks (usable in subclasses).

        :param network: the network to save
        :param network_label: the network label (name)
        :param epoch_label: the epoch to save
        :param gpu_ids: the gpu ids to continue training on after saving
        :return: None
        """
        save_filename = str(self.name) + '_%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        torch.save(network.cpu().state_dict(), save_path)
        if len(gpu_ids) and torch.cuda.is_available():
            network.cuda(gpu_ids[0])

    def _load_network(self, network, network_label, epoch_label):
        """
        Helper function for loading pytorch networks (usable in subclasses).

        :param network: the network variable to store the loaded weights in
        :param network_label: part of the filename the network is loaded from
        :param epoch_label: the epoch to load
        :return: None
        """
        save_filename = str(self.name) + '_%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        network.load_state_dict(torch.load(save_path))

    def update_learning_rate(self):
        """
        Linearly decay the learning rate of all optimizers.

        :return: None
        """
        tmp = self.lr
        self.lr -= self.decay / self.decay_epochs
        for param_group in self.optimizer_seg.param_groups:
            param_group['lr'] = self.lr
        if self.gan:
            for param_group in self.optimizer_dis.param_groups:
                param_group['lr'] = self.lr
        print('update learning rate: %f -> %f' % (tmp, self.lr))

    @staticmethod
    def print_network(network):
        """
        Static helper to print a network summary.

        :param network: the network to summarize
        :return: None
        """
        num_params = 0
        for param in network.parameters():
            num_params += param.numel()
        print(network)
        print('Total number of parameters: %d' % num_params)
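# A minimal usage sketch for Network, assuming uNet and the loss helpers are
# importable; the random tensors stand in for a real image/ground-truth pair
# and the model path and name are illustrative.
if __name__ == "__main__":
    net = Network(model_path="./checkpoints", name="unet_seg")
    img = torch.rand(1, 3, 640, 640)                      # stand-in input image
    gt = (torch.rand(1, 1, 640, 640) > 0.5).float()       # stand-in binary mask
    net.set_input(img, gt)
    net.optimize()                     # one generator step (plus D if gan=True)
    print(net.get_current_errors())
    net.save("latest")                 # writes unet_seg_latest_net_Model.pth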
import functools
from collections import OrderedDict

import torch
import torch.nn as nn
from torch.optim import lr_scheduler

# Generator, Discriminator, GANLoss, ImagePool, init_weights, print_net and
# util are assumed to come from the project's modules.

# Assumption: the original referenced a module-level `use_gpu` flag.
use_gpu = torch.cuda.is_available()


class Pix2Pix(nn.Module):

    def __init__(self, opt):
        super(Pix2Pix, self).__init__()
        self.opt = opt
        self.isTrain = opt.isTrain
        self.Tensor = torch.cuda.FloatTensor if use_gpu else torch.Tensor
        self.input_A = self.Tensor(opt.batchSize, opt.input_nc,
                                   opt.fineSize, opt.fineSize)
        self.input_B = self.Tensor(opt.batchSize, opt.output_nc,
                                   opt.fineSize, opt.fineSize)
        # Assuming norm_type == 'batch'
        norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
        # The generator is a U-Net with 8 downsampling steps (unet_256).
        self.GeneratorNet = Generator(opt.input_nc, opt.output_nc, 8, opt.ngf,
                                      norm_layer=norm_layer,
                                      use_dropout=not opt.no_dropout)
        if use_gpu:
            self.GeneratorNet.cuda()
        self.GeneratorNet.apply(init_weights)
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            # The discriminator is the 'basic' PatchGAN with 3 layers.
            self.DiscriminatorNet = Discriminator(opt.input_nc + opt.output_nc,
                                                  opt.ndf, n_layers=3,
                                                  norm_layer=norm_layer,
                                                  use_sigmoid=use_sigmoid)
            if use_gpu:
                self.DiscriminatorNet.cuda()
            self.DiscriminatorNet.apply(init_weights)
        if not self.isTrain or opt.continue_train:
            self.load_network(self.GeneratorNet, 'Generator', opt.which_epoch)
            if self.isTrain:
                self.load_network(self.DiscriminatorNet, 'Discriminator',
                                  opt.which_epoch)
        if self.isTrain:
            self.fake_AB_pool = ImagePool(opt.pool_size)
            self.learning_rate = opt.lr
            # Loss functions
            self.criterionGAN = GANLoss(use_lsgan=not opt.no_lsgan,
                                        tensor=self.Tensor)
            self.criterionL1 = torch.nn.L1Loss()
            self.MySchedulers = []  # initialising schedulers
            self.MyOptimizers = []  # initialising optimizers
            self.generator_optimizer = torch.optim.Adam(
                self.GeneratorNet.parameters(), lr=self.learning_rate,
                betas=(opt.beta1, 0.999))
            self.discriminator_optimizer = torch.optim.Adam(
                self.DiscriminatorNet.parameters(), lr=self.learning_rate,
                betas=(opt.beta1, 0.999))
            self.MyOptimizers.append(self.generator_optimizer)
            self.MyOptimizers.append(self.discriminator_optimizer)

            def lambda_rule(epoch):
                # Constant learning rate for opt.niter epochs, then linear
                # decay to zero over opt.niter_decay epochs.
                lr_l = 1.0 - max(0, epoch - opt.niter) / float(opt.niter_decay + 1)
                return lr_l

            # Assuming opt.lr_policy == 'lambda'
            for optimizer in self.MyOptimizers:
                self.MySchedulers.append(
                    lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule))
        print('<============ NETWORKS INITIATED ============>')
        print_net(self.GeneratorNet)
        if self.isTrain:
            print_net(self.DiscriminatorNet)
        print('<=============================================>')

    def save_network(self, network, network_label, epoch_label):
        save_path = "./saved_models/%s_net_%s.pth" % (epoch_label, network_label)
        torch.save(network.cpu().state_dict(), save_path)
        if use_gpu:
            network.cuda()

    def load_network(self, network, network_label, epoch_label):
        save_path = "./saved_models/%s_net_%s.pth" % (epoch_label, network_label)
        network.load_state_dict(torch.load(save_path))

    def update_learning_rate(self):
        for scheduler in self.MySchedulers:
            scheduler.step()
        lr = self.MyOptimizers[0].param_groups[0]['lr']
        print('learning rate = %.7f' % lr)

    def set_input(self, input):
        self.input = input
        if self.opt.which_direction == 'AtoB':
            input_A = input['A']
            input_B = input['B']
            self.image_paths = input['A_paths']
        else:
            input_A = input['B']
            input_B = input['A']
            self.image_paths = input['B_paths']
        self.input_A.resize_(input_A.size()).copy_(input_A)
        self.input_B.resize_(input_B.size()).copy_(input_B)

    def forward(self):
        self.real_A = self.input_A
        self.generated_B = self.GeneratorNet(self.real_A)
        self.real_B = self.input_B

    def get_image_paths(self):
        return self.image_paths

    def backward_Discriminator(self):
        # Conditional GAN: the discriminator judges (input, output) pairs.
        fake_AB = self.fake_AB_pool.query(
            torch.cat((self.real_A, self.generated_B), 1))
        # Detach so no gradient flows back into the generator.
        self.prediction_fake = self.DiscriminatorNet(fake_AB.detach())
        self.loss_D_fake = self.criterionGAN(self.prediction_fake, False)
        real_AB = torch.cat((self.real_A, self.real_B), 1)
        self.prediction_real = self.DiscriminatorNet(real_AB)
        # Real pairs take the "real" target label (the original passed False
        # here, which is a bug).
        self.loss_D_real = self.criterionGAN(self.prediction_real, True)
        self.loss_Discriminator = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_Discriminator.backward()

    def backward_Generator(self):
        fake_AB = torch.cat((self.real_A, self.generated_B), 1)
        prediction_fake = self.DiscriminatorNet(fake_AB)
        self.loss_G_GAN = self.criterionGAN(prediction_fake, True)
        self.loss_G_L1 = self.criterionL1(self.generated_B,
                                          self.real_B) * self.opt.lambda_A
        self.loss_Generator = self.loss_G_GAN + self.loss_G_L1
        self.loss_Generator.backward()

    def optimize_parameters(self):
        self.forward()
        self.discriminator_optimizer.zero_grad()
        self.backward_Discriminator()
        self.discriminator_optimizer.step()
        self.generator_optimizer.zero_grad()
        self.backward_Generator()
        self.generator_optimizer.step()

    def get_current_errors(self):
        return OrderedDict([('G_GAN', self.loss_G_GAN.item()),
                            ('G_L1', self.loss_G_L1.item()),
                            ('D_real', self.loss_D_real.item()),
                            ('D_fake', self.loss_D_fake.item())])

    def get_current_visuals(self):
        real_A = util.tensor2im(self.real_A.data)
        fake_B = util.tensor2im(self.generated_B.data)
        real_B = util.tensor2im(self.real_B.data)
        return OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                            ('real_B', real_B)])

    def save(self, label):
        self.save_network(self.GeneratorNet, 'Generator', label)
        self.save_network(self.DiscriminatorNet, 'Discriminator', label)
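# A sketch of the options object Pix2Pix expects. The field list is inferred
# from the attribute accesses above; the values are illustrative defaults,
# not values prescribed by the original code.
from types import SimpleNamespace

example_opt = SimpleNamespace(
    isTrain=True, batchSize=1, input_nc=3, output_nc=3, fineSize=256,
    ngf=64, ndf=64, no_dropout=False, no_lsgan=False, continue_train=False,
    which_epoch='latest', pool_size=50, lr=0.0002, beta1=0.5,
    niter=100, niter_decay=100, which_direction='AtoB', lambda_A=100.0)

# model = Pix2Pix(example_opt)
# model.set_input({'A': a, 'B': b, 'A_paths': paths, 'B_paths': paths})
# model.optimize_parameters()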
import os
import random
import time

import mxnet as mx

# ResnetGenerator, PatchDiscriminator, GANInitializer, ImagePool,
# load_dataset and get_batches are assumed to come from the project's
# modules. The generators and discriminators each return a CAM (class
# activation map) auxiliary output in addition to their main output.


def train(dataset, start_epoch, max_epochs, lr_d, lr_g, batch_size, lmda_cyc,
          lmda_idt, pool_size, context):
    mx.random.seed(int(time.time()))

    print("Loading dataset...", flush=True)
    training_set_a = load_dataset(dataset, "trainA")
    training_set_b = load_dataset(dataset, "trainB")

    gen_ab = ResnetGenerator()
    dis_b = PatchDiscriminator()
    gen_ba = ResnetGenerator()
    dis_a = PatchDiscriminator()

    bce_loss = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()
    l1_loss = mx.gluon.loss.L1Loss()

    gen_ab_params_file = "model/{}.gen_ab.params".format(dataset)
    dis_b_params_file = "model/{}.dis_b.params".format(dataset)
    gen_ab_state_file = "model/{}.gen_ab.state".format(dataset)
    dis_b_state_file = "model/{}.dis_b.state".format(dataset)
    gen_ba_params_file = "model/{}.gen_ba.params".format(dataset)
    dis_a_params_file = "model/{}.dis_a.params".format(dataset)
    gen_ba_state_file = "model/{}.gen_ba.state".format(dataset)
    dis_a_state_file = "model/{}.dis_a.state".format(dataset)

    # Resume from saved parameters where available, otherwise initialize.
    if os.path.isfile(gen_ab_params_file):
        gen_ab.load_parameters(gen_ab_params_file, ctx=context)
    else:
        gen_ab.initialize(GANInitializer(), ctx=context)
    if os.path.isfile(dis_b_params_file):
        dis_b.load_parameters(dis_b_params_file, ctx=context)
    else:
        dis_b.initialize(GANInitializer(), ctx=context)
    if os.path.isfile(gen_ba_params_file):
        gen_ba.load_parameters(gen_ba_params_file, ctx=context)
    else:
        gen_ba.initialize(GANInitializer(), ctx=context)
    if os.path.isfile(dis_a_params_file):
        dis_a.load_parameters(dis_a_params_file, ctx=context)
    else:
        dis_a.initialize(GANInitializer(), ctx=context)

    print("Learning rate of discriminator:", lr_d, flush=True)
    print("Learning rate of generator:", lr_g, flush=True)

    trainer_gen_ab = mx.gluon.Trainer(gen_ab.collect_params(), "Nadam",
                                      {"learning_rate": lr_g, "beta1": 0.5})
    trainer_dis_b = mx.gluon.Trainer(dis_b.collect_params(), "Nadam",
                                     {"learning_rate": lr_d, "beta1": 0.5})
    trainer_gen_ba = mx.gluon.Trainer(gen_ba.collect_params(), "Nadam",
                                      {"learning_rate": lr_g, "beta1": 0.5})
    trainer_dis_a = mx.gluon.Trainer(dis_a.collect_params(), "Nadam",
                                     {"learning_rate": lr_d, "beta1": 0.5})

    if os.path.isfile(gen_ab_state_file):
        trainer_gen_ab.load_states(gen_ab_state_file)
    if os.path.isfile(dis_b_state_file):
        trainer_dis_b.load_states(dis_b_state_file)
    if os.path.isfile(gen_ba_state_file):
        trainer_gen_ba.load_states(gen_ba_state_file)
    if os.path.isfile(dis_a_state_file):
        trainer_dis_a.load_states(dis_a_state_file)

    fake_a_pool = ImagePool(pool_size)
    fake_b_pool = ImagePool(pool_size)

    print("Training...", flush=True)
    for epoch in range(start_epoch, max_epochs):
        ts = time.time()
        random.shuffle(training_set_a)
        random.shuffle(training_set_b)
        training_dis_a_L = 0.0
        training_dis_b_L = 0.0
        training_gen_L = 0.0
        training_batch = 0
        for real_a, real_b in get_batches(training_set_a, training_set_b,
                                          batch_size, ctx=context):
            training_batch += 1

            # Generate translations outside autograd: the discriminator
            # updates must not backpropagate into the generators.
            fake_a, _ = gen_ba(real_b)
            fake_b, _ = gen_ab(real_a)

            # Update discriminator A.
            with mx.autograd.record():
                real_a_y, real_a_cam_y = dis_a(real_a)
                real_a_L = bce_loss(real_a_y,
                                    mx.nd.ones_like(real_a_y, ctx=context))
                real_a_cam_L = bce_loss(
                    real_a_cam_y, mx.nd.ones_like(real_a_cam_y, ctx=context))
                fake_a_y, fake_a_cam_y = dis_a(fake_a_pool.query(fake_a))
                fake_a_L = bce_loss(fake_a_y,
                                    mx.nd.zeros_like(fake_a_y, ctx=context))
                fake_a_cam_L = bce_loss(
                    fake_a_cam_y, mx.nd.zeros_like(fake_a_cam_y, ctx=context))
                L = real_a_L + real_a_cam_L + fake_a_L + fake_a_cam_L
            L.backward()
            trainer_dis_a.step(batch_size)
            dis_a_L = mx.nd.mean(L).asscalar()
            if dis_a_L != dis_a_L:  # NaN check
                raise ValueError("dis_a loss is NaN")

            # Update discriminator B.
            with mx.autograd.record():
                real_b_y, real_b_cam_y = dis_b(real_b)
                real_b_L = bce_loss(real_b_y,
                                    mx.nd.ones_like(real_b_y, ctx=context))
                real_b_cam_L = bce_loss(
                    real_b_cam_y, mx.nd.ones_like(real_b_cam_y, ctx=context))
                fake_b_y, fake_b_cam_y = dis_b(fake_b_pool.query(fake_b))
                fake_b_L = bce_loss(fake_b_y,
                                    mx.nd.zeros_like(fake_b_y, ctx=context))
                fake_b_cam_L = bce_loss(
                    fake_b_cam_y, mx.nd.zeros_like(fake_b_cam_y, ctx=context))
                L = real_b_L + real_b_cam_L + fake_b_L + fake_b_cam_L
            L.backward()
            trainer_dis_b.step(batch_size)
            dis_b_L = mx.nd.mean(L).asscalar()
            if dis_b_L != dis_b_L:  # NaN check
                raise ValueError("dis_b loss is NaN")

            # Update both generators jointly.
            with mx.autograd.record():
                # B -> A direction: adversarial, CAM, cycle and identity terms.
                fake_a, gen_a_cam_y = gen_ba(real_b)
                fake_a_y, fake_a_cam_y = dis_a(fake_a)
                gan_a_L = bce_loss(fake_a_y,
                                   mx.nd.ones_like(fake_a_y, ctx=context))
                gan_a_cam_L = bce_loss(
                    fake_a_cam_y, mx.nd.ones_like(fake_a_cam_y, ctx=context))
                rec_b, _ = gen_ab(fake_a)
                cyc_b_L = l1_loss(rec_b, real_b)
                idt_a, idt_a_cam_y = gen_ba(real_a)
                idt_a_L = l1_loss(idt_a, real_a)
                gen_a_cam_L = bce_loss(
                    gen_a_cam_y,
                    mx.nd.ones_like(gen_a_cam_y, ctx=context)) + bce_loss(
                        idt_a_cam_y,
                        mx.nd.zeros_like(idt_a_cam_y, ctx=context))
                gen_ba_L = (gan_a_L + gan_a_cam_L + cyc_b_L * lmda_cyc +
                            idt_a_L * lmda_cyc * lmda_idt + gen_a_cam_L)

                # A -> B direction.
                fake_b, gen_b_cam_y = gen_ab(real_a)
                fake_b_y, fake_b_cam_y = dis_b(fake_b)
                gan_b_L = bce_loss(fake_b_y,
                                   mx.nd.ones_like(fake_b_y, ctx=context))
                gan_b_cam_L = bce_loss(
                    fake_b_cam_y, mx.nd.ones_like(fake_b_cam_y, ctx=context))
                rec_a, _ = gen_ba(fake_b)
                cyc_a_L = l1_loss(rec_a, real_a)
                idt_b, idt_b_cam_y = gen_ab(real_b)
                idt_b_L = l1_loss(idt_b, real_b)
                gen_b_cam_L = bce_loss(
                    gen_b_cam_y,
                    mx.nd.ones_like(gen_b_cam_y, ctx=context)) + bce_loss(
                        idt_b_cam_y,
                        mx.nd.zeros_like(idt_b_cam_y, ctx=context))
                gen_ab_L = (gan_b_L + gan_b_cam_L + cyc_a_L * lmda_cyc +
                            idt_b_L * lmda_cyc * lmda_idt + gen_b_cam_L)

                L = gen_ba_L + gen_ab_L
            L.backward()
            trainer_gen_ba.step(batch_size)
            trainer_gen_ab.step(batch_size)
            gen_L = mx.nd.mean(L).asscalar()
            if gen_L != gen_L:  # NaN check
                raise ValueError("generator loss is NaN")

            training_dis_a_L += dis_a_L
            training_dis_b_L += dis_b_L
            training_gen_L += gen_L
            print("[Epoch %d Batch %d] dis_a_loss %.10f dis_b_loss %.10f "
                  "gen_loss %.10f elapsed %.2fs"
                  % (epoch, training_batch, dis_a_L, dis_b_L, gen_L,
                     time.time() - ts), flush=True)

        print("[Epoch %d] training_dis_a_loss %.10f training_dis_b_loss %.10f "
              "training_gen_loss %.10f duration %.2fs"
              % (epoch + 1, training_dis_a_L / training_batch,
                 training_dis_b_L / training_batch,
                 training_gen_L / training_batch, time.time() - ts),
              flush=True)

        # Checkpoint parameters and optimizer states after every epoch.
        gen_ab.save_parameters(gen_ab_params_file)
        gen_ba.save_parameters(gen_ba_params_file)
        dis_a.save_parameters(dis_a_params_file)
        dis_b.save_parameters(dis_b_params_file)
        trainer_gen_ab.save_states(gen_ab_state_file)
        trainer_gen_ba.save_states(gen_ba_state_file)
        trainer_dis_a.save_states(dis_a_state_file)
        trainer_dis_b.save_states(dis_b_state_file)
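# A sketch of how train() might be invoked; the dataset name and
# hyperparameters are illustrative, not values prescribed by the original
# code.
if __name__ == "__main__":
    ctx = mx.gpu(0) if mx.context.num_gpus() > 0 else mx.cpu()
    train(dataset="horse2zebra", start_epoch=0, max_epochs=100,
          lr_d=1e-4, lr_g=1e-4, batch_size=1, lmda_cyc=10.0, lmda_idt=1.0,
          pool_size=50, context=ctx)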
import datetime

import numpy as np
from keras.callbacks import TensorBoard

# This is a method of a Keras CycleGAN class that provides self.g_AB,
# self.g_BA, self.d_A, self.d_B, self.combined, self.disc_patch,
# self.data_loader, self.dataset_name and self.sample_images(); ImagePool
# comes from the project's utilities.


def train(self, epochs, batch_size=1, sample_interval=50, pool_size=50):
    start_time = datetime.datetime.now()

    # Adversarial loss ground truths (PatchGAN output shape).
    valid = np.ones((batch_size,) + self.disc_patch)
    fake = np.zeros((batch_size,) + self.disc_patch)

    fake_a_pool = ImagePool(pool_size)
    fake_b_pool = ImagePool(pool_size)

    tensorboard = TensorBoard(batch_size=batch_size, write_grads=True)
    tensorboard.set_model(self.combined)

    def named_logs(model, logs):
        # Pair metric names with their values for TensorBoard logging.
        result = {}
        for l in zip(model.metrics_names, logs):
            result[l[0]] = l[1]
        return result

    for epoch in range(epochs):
        for batch_i, (imgs_A, imgs_B) in enumerate(
                self.data_loader.load_batch(batch_size)):

            # ----------------------
            #  Train Discriminators
            # ----------------------

            # Translate images to the opposite domain and draw the
            # discriminator inputs from the history pools.
            fake_B = fake_b_pool.query(self.g_AB.predict(imgs_A))
            fake_A = fake_a_pool.query(self.g_BA.predict(imgs_B))

            # Train the discriminators (original images = real, translated = fake).
            dA_loss_real = self.d_A.train_on_batch(imgs_A, valid)
            dA_loss_fake = self.d_A.train_on_batch(fake_A, fake)
            dA_loss = 0.5 * np.add(dA_loss_real, dA_loss_fake)

            dB_loss_real = self.d_B.train_on_batch(imgs_B, valid)
            dB_loss_fake = self.d_B.train_on_batch(fake_B, fake)
            dB_loss = 0.5 * np.add(dB_loss_real, dB_loss_fake)

            # Total discriminator loss
            d_loss = 0.5 * np.add(dA_loss, dB_loss)

            # ------------------
            #  Train Generators
            # ------------------

            g_loss = self.combined.train_on_batch(
                [imgs_A, imgs_B],
                [valid, valid, imgs_A, imgs_B, imgs_A, imgs_B])

            elapsed_time = datetime.datetime.now() - start_time

            # Plot the progress
            print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %3d%%] "
                  "[G loss: %05f, adv: %05f, recon: %05f, id: %05f] time: %s "
                  % (epoch, epochs, batch_i, self.data_loader.n_batches,
                     d_loss[0], 100 * d_loss[1], g_loss[0],
                     np.mean(g_loss[1:3]), np.mean(g_loss[3:5]),
                     np.mean(g_loss[5:6]), elapsed_time))

            # If at save interval => save generated image samples
            if batch_i % sample_interval == 0:
                self.sample_images(epoch, batch_i)

        # Save the combined model's weights after every epoch.
        self.combined.save_weights(f"saved_model/{self.dataset_name}/{epoch}.h5")
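# Usage sketch for the method above (the class name is assumed): the combined
# model must output [valid_A, valid_B, recon_A, recon_B, id_A, id_B] in that
# order to match the targets passed to train_on_batch.
# gan = CycleGAN()
# gan.train(epochs=200, batch_size=1, sample_interval=200, pool_size=50)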
import os
import pickle
import shutil
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn
import torchvision.utils as vutils

# UnetGenerator, UnetDescriminator, init_net, ImagePool and Logger are
# assumed to come from the project's modules.


class BEGANModel():

    def __init__(self, opt, gpu_ids=[0], continue_run=None):
        self.opt = opt
        self.kt = 0           # BEGAN balance term k_t
        self.lamk = 0.001     # proportional gain lambda_k for updating k_t
        self.lambdaImg = 100  # weight of the L1 image loss
        self.lambdaGan = 1.0  # weight of the adversarial loss
        self.model_names = ['netD', 'netG']
        self.gpu_ids = gpu_ids
        if not continue_run:
            expname = '-'.join([
                'b_' + str(self.opt.batchSize), 'ngf_' + str(self.opt.ngf),
                'ndf_' + str(self.opt.ndf), 'gm_' + str(self.opt.gamma)
            ])
            self.rundir = (self.opt.rundir + '/pix2pixBEGAN-' +
                           datetime.now().strftime('%B%d-%H-%M-%S') +
                           expname + self.opt.comment)
            if not os.path.isdir(self.rundir):
                os.mkdir(self.rundir)
            with open(self.rundir + '/options.pkl', 'wb') as file:
                pickle.dump(opt, file)
        else:
            self.rundir = continue_run
            if os.path.isfile(self.rundir + '/options.pkl'):
                with open(self.rundir + '/options.pkl', 'rb') as file:
                    # Restore the saved options, but keep the current run
                    # directory and learning rate.
                    tmp = opt.rundir
                    tmp_lr = opt.lr
                    self.opt = pickle.load(file)
                    self.opt.rundir = tmp
                    self.opt.lr = tmp_lr
        self.netG = UnetGenerator(input_nc=3, output_nc=3, num_downs=7,
                                  ngf=self.opt.ngf,
                                  norm_layer=nn.BatchNorm2d, use_dropout=True)
        # BEGAN uses an autoencoder as the discriminator.
        self.netD = UnetDescriminator(input_nc=3, output_nc=3, num_downs=7,
                                      ngf=self.opt.ndf,
                                      norm_layer=nn.BatchNorm2d,
                                      use_dropout=True)
        # Decide which device we want to run on
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")
        init_net(self.netG, 'normal', 0.002, [0])
        init_net(self.netD, 'normal', 0.002, [0])
        self.netG.to(self.device)
        self.netD.to(self.device)
        # Assumption: the original referenced an undefined `pool_size`; read
        # it from the options instead.
        self.imagePool = ImagePool(self.opt.pool_size)
        self.criterionL1 = torch.nn.L1Loss()
        if continue_run:
            self.load_networks('latest')
        self.writer = Logger(self.rundir)
        self.start_step, self.opt.lr = self.writer.get_latest('misc/lr', self.opt.lr)
        # Initialize optimizers. Assumption: the original referenced an
        # undefined `beta1`; read it from the options instead.
        self.optimG = torch.optim.Adam(self.netG.parameters(), lr=self.opt.lr,
                                       betas=(self.opt.beta1, 0.999))
        self.optimD = torch.optim.Adam(self.netD.parameters(), lr=self.opt.lr,
                                       betas=(self.opt.beta1, 0.999))

    def set_input(self, data):
        self.real_A = data['A'].to(self.device)
        self.real_B = data['B'].to(self.device)

    def forward(self):
        self.fake_B = self.netG(self.real_A)

    def backward_D(self):
        for p in self.netD.parameters():
            p.requires_grad = True
        self.optimD.zero_grad()
        fake = self.imagePool.query(self.fake_B.detach())
        recon_real_B = self.netD(self.real_B)
        recon_fake = self.netD(fake)
        # BEGAN losses: reconstruction error of the autoencoding discriminator.
        d_real = torch.mean(torch.abs(recon_real_B - self.real_B))
        d_fake = torch.mean(torch.abs(recon_fake - fake))
        L_D = d_real - self.kt * d_fake
        L_D.backward()
        self.optimD.step()
        self.L_D_val = L_D.item()
        self.d_fake_cpu = d_fake.detach().cpu().item()
        self.d_real_cpu = d_real.detach().cpu().item()
        self.recon_real_B_cpu = recon_real_B.detach().cpu()
        self.recon_fake_cpu = recon_fake.detach().cpu()
        self.fake_cpu = fake.detach().cpu()

    def backward_G(self):
        for p in self.netD.parameters():
            p.requires_grad = False
        self.optimG.zero_grad()
        L_Img = self.lambdaImg * self.criterionL1(self.fake_B, self.real_B)
        L_Img.backward(retain_graph=True)
        recon_fake_B = self.netD(self.fake_B)
        self.L_G_fake = self.lambdaGan * torch.mean(
            torch.abs(recon_fake_B - self.fake_B))
        if self.lambdaGan > 0:
            self.L_G_fake.backward()
        self.optimG.step()
        self.L_Img_cpu = L_Img.detach().cpu()
        self.L_G_fake_cpu = self.L_G_fake.detach().cpu()

    def update_K(self):
        # Proportional control of k_t to keep E[d_fake] near gamma * E[d_real].
        balance = self.opt.gamma * self.d_real_cpu - self.d_fake_cpu
        self.kt = min(max(self.kt + self.lamk * balance, 0), 1)
        # Global convergence measure from the BEGAN paper.
        self.M_global = self.d_real_cpu + np.abs(balance)

    def updatelr(self):
        self.opt.lr = self.opt.lr / 2
        for param_group in self.optimD.param_groups:
            param_group['lr'] = self.opt.lr
        for param_group in self.optimG.param_groups:
            param_group['lr'] = self.opt.lr

    def log(self, epoch, batchn, n_iter):
        print('Writing summaries....')
        self.writer.scalar_summary('misc/M_global', self.M_global, n_iter)
        self.writer.scalar_summary('misc/kt', self.kt, n_iter)
        self.writer.scalar_summary('misc/lr', self.opt.lr, n_iter)
        self.writer.scalar_summary('loss/L_D', self.L_D_val, n_iter)
        self.writer.scalar_summary('loss/d_real', self.d_real_cpu, n_iter)
        self.writer.scalar_summary('loss/d_fake', self.d_fake_cpu, n_iter)
        self.writer.scalar_summary('loss/L_G', self.L_G_fake_cpu, n_iter)
        self.writer.scalar_summary('loss/L1', self.L_Img_cpu, n_iter)
        test_A = self.test_data['A']
        test_B = self.test_data['B']
        val_A = self.val_data['A']
        with torch.no_grad():
            fake_test_B = self.netG(test_A.to(self.device))
            fake_val_B = self.netG(val_A.to(self.device))
        images = torch.cat([test_A, test_B, fake_test_B.cpu()])
        x = vutils.make_grid(images / 2 + 0.5, normalize=True,
                             scale_each=True, nrow=4)
        self.writer.image_summary('Test/Fixed', [x], n_iter)
        images = torch.cat([
            self.real_A.detach().cpu(), self.real_B.cpu(),
            self.fake_B.detach().cpu()
        ])
        x = vutils.make_grid(images / 2 + 0.5, normalize=True,
                             scale_each=True, nrow=4)
        self.writer.image_summary('Test/Last', [x], n_iter)
        images = torch.cat([val_A, fake_val_B.cpu()])
        x = vutils.make_grid(images / 2 + 0.5, normalize=True,
                             scale_each=True, nrow=4)
        self.writer.image_summary('Test/Validation', [x], n_iter)
        images = torch.cat([self.real_B.cpu(), self.recon_real_B_cpu])
        x = vutils.make_grid(images / 2 + 0.5, normalize=True,
                             scale_each=True, nrow=4)
        self.writer.image_summary('Discriminator/Recon_Real', [x], n_iter)
        images = torch.cat([self.fake_cpu, self.recon_fake_cpu])
        x = vutils.make_grid(images / 2 + 0.5, normalize=True,
                             scale_each=True, nrow=4)
        self.writer.image_summary('Discriminator/Recon_Fake', [x], n_iter)
        self.save_networks(epoch)
        for name, param in self.netG.named_parameters():
            if 'bn' in name:
                continue
            self.writer.histo_summary('weight_G/' + name,
                                      param.clone().cpu().data.numpy(),
                                      n_iter)
            self.writer.histo_summary('grad_G/' + name,
                                      param.grad.clone().cpu().data.numpy(),
                                      n_iter)
        for name, param in self.netD.named_parameters():
            if 'bn' in name:
                continue
            self.writer.histo_summary('weight_D/' + name,
                                      param.clone().cpu().data.numpy(),
                                      n_iter)
            self.writer.histo_summary('grad_D/' + name,
                                      param.grad.clone().cpu().data.numpy(),
                                      n_iter)

    def set_test_input(self, test_data):
        self.test_data = test_data

    def set_val_input(self, val_data):
        self.val_data = val_data

    # save models to the disk
    def save_networks(self, epoch):
        for name in self.model_names:
            if isinstance(name, str):
                save_filename = '{}_{}.pth'.format(epoch, name)
                save_path = os.path.join(self.rundir, save_filename)
                net = getattr(self, name)
                torch.save(net.cpu().state_dict(), save_path)
                if len(self.gpu_ids) > 0 and torch.cuda.is_available():
                    net.cuda(self.gpu_ids[0])
                self.update_link(
                    save_path,
                    os.path.join(self.rundir, 'latest_{}.pth'.format(name)))

    def update_link(self, src, dst):
        shutil.copy2(src, dst)

    # load models from the disk
    def load_networks(self, epoch):
        for name in self.model_names:
            if isinstance(name, str):
                load_filename = '%s_%s.pth' % (epoch, name)
                load_path = os.path.join(self.rundir, load_filename)
                if not os.path.isfile(load_path):
                    return
                net = getattr(self, name)
                if isinstance(net, torch.nn.DataParallel):
                    net = net.module
                print('loading the model from %s' % load_path)
                state_dict = torch.load(load_path, map_location=self.device)
                if hasattr(state_dict, '_metadata'):
                    del state_dict._metadata
                net.load_state_dict(state_dict)
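# A minimal training-loop sketch for BEGANModel (the loop structure and names
# `dataloader` and `log_every` are assumed, not from the original code): one
# D step, one G step, then the k_t update.
# for n_iter, data in enumerate(dataloader):
#     model.set_input(data)
#     model.forward()
#     model.backward_D()
#     model.backward_G()
#     model.update_K()
#     if n_iter % log_every == 0:
#         model.log(epoch, n_iter, n_iter)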
import torch
import torch.nn as nn

import settings

# Generator, NLayerDiscriminator, CropDrones, GANLoss, init_weights,
# ImagePool and the settings module are assumed to come from the project.


class GAN(nn.Module):

    def __init__(self, lambda_ABA=settings.lambda_ABA,
                 lambda_BAB=settings.lambda_BAB,
                 lambda_local=settings.lambda_local,
                 pool_size=settings.pool_size,
                 max_crop_side=settings.max_crop_side,
                 decay_start=settings.decay_start,
                 epochs_to_zero_lr=settings.epochs_to_zero_lr,
                 warm_epochs=settings.warmup_epochs):
        super(GAN, self).__init__()
        self.r = 0  # ramp factor for the cycle losses
        self.lambda_ABA = lambda_ABA
        self.lambda_BAB = lambda_BAB
        self.lambda_local = lambda_local
        self.max_crop_side = max_crop_side
        # Assumption: the original used self.device without defining it.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # The generators take an extra "windows" channel as input.
        self.netG_A = Generator(input_nc=4, output_nc=3)
        self.netG_B = Generator(input_nc=4, output_nc=3)
        self.netD_A = NLayerDiscriminator(input_nc=3)
        self.netD_B = NLayerDiscriminator(input_nc=3)
        # An additional local discriminator on crops around the drones.
        self.localD = NLayerDiscriminator(input_nc=3)
        self.crop_drones = CropDrones()
        self.criterionGAN = GANLoss("lsgan")
        self.criterionCycle = nn.L1Loss()
        init_weights(self.netG_A)
        init_weights(self.netG_B)
        init_weights(self.netD_A)
        init_weights(self.netD_B)
        init_weights(self.localD)
        self.fake_B_pool = ImagePool(pool_size)
        self.fake_A_pool = ImagePool(pool_size)
        self.fake_drones_pool = ImagePool(pool_size)

    def get_inputs(self, input_):
        self.real_A_with_windows = torch.as_tensor(input_['A'], device=self.device)
        self.real_B_with_windows = torch.as_tensor(input_['B'], device=self.device)
        # The last channel encodes the window mask; the rest is the image.
        self.real_A = self.real_A_with_windows[:, :-1]
        self.real_B = self.real_B_with_windows[:, :-1]
        self.A_windows = self.real_A_with_windows[:, -1:]
        self.B_windows = self.real_B_with_windows[:, -1:]
        self.real_drones = torch.zeros(self.real_B.shape[0], 3,
                                       self.max_crop_side, self.max_crop_side,
                                       device=self.device)
        self.fake_drones = torch.zeros(self.real_A.shape[0], 3,
                                       self.max_crop_side, self.max_crop_side,
                                       device=self.device)

    def forward(self, input_):
        self.get_inputs(input_)
        self.fake_A = self.netG_A(self.real_B_with_windows)
        self.rest_B = self.netG_B(
            torch.cat([self.fake_A, self.B_windows], dim=1))
        self.real_drones = self.crop_drones(
            (self.real_B_with_windows, self.real_drones))
        self.fake_B = self.netG_B(self.real_A_with_windows)
        self.rest_A = self.netG_A(
            torch.cat([self.fake_B, self.A_windows], dim=1))
        self.fake_drones = self.crop_drones(
            (torch.cat([self.fake_B, self.A_windows], dim=1),
             self.fake_drones))

    def update_learning_rate(self):
        # self.schedulers is expected to be attached by the training script.
        for scheduler in self.schedulers:
            scheduler.step()

    def iteration(self, input_):
        self.forward(input_)
        loss_dict = dict()
        # Loss for D_A
        real_output_D_A = self.netD_A(self.real_A)
        real_GAN_loss_D_A = self.criterionGAN(real_output_D_A, True)
        # Draw from the A history pool (the original queried fake_B_pool
        # here, which is a bug).
        fake_A = self.fake_A_pool.query(self.fake_A)
        fake_output_D_A = self.netD_A(fake_A.detach())
        fake_GAN_loss_D_A = self.criterionGAN(fake_output_D_A, False)
        D_A_loss = (real_GAN_loss_D_A + fake_GAN_loss_D_A) * 0.5
        loss_dict['D_A'] = D_A_loss
        # Loss for D_B
        real_output_D_B = self.netD_B(self.real_B)
        real_GAN_loss_D_B = self.criterionGAN(real_output_D_B, True)
        fake_B = self.fake_B_pool.query(self.fake_B)
        fake_output_D_B = self.netD_B(fake_B.detach())
        fake_GAN_loss_D_B = self.criterionGAN(fake_output_D_B, False)
        D_B_loss = (real_GAN_loss_D_B + fake_GAN_loss_D_B) * 0.5
        loss_dict['D_B'] = D_B_loss
        # Loss for localD
        real_output_localD = self.localD(self.real_drones)
        real_GAN_loss_localD = self.criterionGAN(real_output_localD, True)
        fake_drones = self.fake_drones_pool.query(self.fake_drones)
        fake_output_localD = self.localD(fake_drones.detach())
        fake_GAN_loss_localD = self.criterionGAN(fake_output_localD, False)
        localD_loss = (real_GAN_loss_localD + fake_GAN_loss_localD) * 0.5
        loss_dict['local_D'] = localD_loss
        # Losses for G_A and G_B
        G_A_GAN_loss = self.criterionGAN(self.netD_A(self.fake_A), True)
        BAB_cycle_loss = self.criterionCycle(self.real_B, self.rest_B)
        G_B_GAN_loss = self.criterionGAN(self.netD_B(self.fake_B), True)
        G_B_local_loss = self.criterionGAN(self.localD(self.fake_drones), True)
        ABA_cycle_loss = self.criterionCycle(self.real_A, self.rest_A)
        G_loss = (G_B_GAN_loss + G_A_GAN_loss +
                  G_B_local_loss * self.lambda_local +
                  ABA_cycle_loss * self.lambda_ABA * self.r +
                  BAB_cycle_loss * self.lambda_BAB * self.r)
        loss_dict['G_B'] = G_B_GAN_loss
        loss_dict['G_A'] = G_A_GAN_loss
        loss_dict['G_local'] = G_B_local_loss
        loss_dict['G'] = G_loss
        return loss_dict
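# Usage sketch: iteration() only computes the losses; the training script is
# expected to own the optimizers (and self.schedulers) and call backward()
# itself, roughly as below (opt_D and opt_G are assumed names):
# losses = gan.iteration(batch)
# opt_D.zero_grad()
# (losses['D_A'] + losses['D_B'] + losses['local_D']).backward()
# opt_D.step()
# opt_G.zero_grad()
# losses['G'].backward()
# opt_G.step()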
import os
from collections import OrderedDict

import torch
import torch.nn as nn

# networks, util and ImagePool are assumed to come from the project's modules.


class Pix2PixModel(object):

    def __init__(self, name="experiment", phase="train", which_epoch="latest",
                 batch_size=1, image_size=128, map_nc=1, input_nc=3,
                 output_nc=3, num_downs=7, ngf=64, ndf=64, norm_layer="batch",
                 pool_size=50, lr=0.0002, beta1=0.5, lambda_D=0.5,
                 lambda_MSE=10, lambda_P=5.0, use_dropout=True, gpu_ids=[],
                 n_layers=3, use_sigmoid=False, use_lsgan=True,
                 upsampling="nearest", continue_train=False,
                 checkpoints_dir="checkpoints/"):
        # Define the input buffers that the networks consume.
        self.input_A = torch.FloatTensor(batch_size, 3, image_size, image_size)
        self.input_map = torch.FloatTensor(batch_size, map_nc, image_size, image_size)
        norm_layer = nn.BatchNorm2d if norm_layer == "batch" else nn.InstanceNorm2d
        # Define netG and netD
        self.netG = networks.UnetGenerator(
            input_nc=input_nc, output_nc=map_nc, num_downs=num_downs,
            ngf=ngf, use_dropout=use_dropout, gpu_ids=gpu_ids,
            norm_layer=norm_layer, upsampling_layer=upsampling)
        self.netD = networks.NLayerDiscriminator(
            input_nc=input_nc + map_nc, ndf=ndf, n_layers=n_layers,
            use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)
        # Transfer data to GPU
        if len(gpu_ids) > 0:
            self.input_A = self.input_A.cuda()
            self.input_map = self.input_map.cuda()
            self.netD.cuda()
            self.netG.cuda()
        # Initialize parameters of netD and netG
        self.netG.apply(networks.weights_init)
        self.netD.apply(networks.weights_init)
        # Load trained netD and netG
        if phase == "test" or continue_train:
            netG_checkpoint_file = os.path.join(
                checkpoints_dir, name, "netG_{}.pth".format(which_epoch))
            self.netG.load_state_dict(torch.load(netG_checkpoint_file))
            print("Restoring netG from {}".format(netG_checkpoint_file))
            if continue_train:
                netD_checkpoint_file = os.path.join(
                    checkpoints_dir, name, "netD_{}.pth".format(which_epoch))
                self.netD.load_state_dict(torch.load(netD_checkpoint_file))
                print("Restoring netD from {}".format(netD_checkpoint_file))
        self.name = name
        self.gpu_ids = gpu_ids
        self.checkpoints_dir = checkpoints_dir
        # Criterions and optimizers
        if phase == "train":
            self.count = 0
            self.lr = lr
            self.lambda_D = lambda_D
            self.lambda_MSE = lambda_MSE
            self.image_pool = ImagePool(pool_size)
            self.criterionGAN = networks.GANLoss(use_lsgan=use_lsgan)
            self.criterionL1 = torch.nn.L1Loss()
            self.criterionMSE = torch.nn.MSELoss()  # landmark-map loss
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=self.lr,
                                                betas=(beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=self.lr,
                                                betas=(beta1, 0.999))
            print('---------- Networks initialized -------------')
            networks.print_network(self.netG)
            networks.print_network(self.netD)
            print('-----------------------------------------------')

    def set_input(self, input_A, input_map, input_name):
        self.input_A.resize_(input_A.size()).copy_(input_A)
        self.input_map.resize_(input_map.size()).copy_(input_map)
        self.input_name = input_name

    def get_image_paths(self):
        return self.input_name[0]

    def forward(self):
        self.real_A = self.input_A
        self.fake_map = self.netG(self.real_A)
        self.real_map = self.input_map

    def test(self):
        # No backprop gradients
        with torch.no_grad():
            self.real_A = self.input_A
            self.fake_map = self.netG(self.real_A)
            self.real_map = self.input_map

    def backward_D(self):
        # Fake: stop backprop into the generator by detaching fake_map.
        fake_Amap = self.image_pool.query(
            torch.cat((self.real_A, self.fake_map), 1))
        self.pred_fake = self.netD(fake_Amap.detach())
        self.loss_D_fake = self.criterionGAN(self.pred_fake, False)
        # Real
        real_Amap = torch.cat((self.real_A, self.real_map), 1)
        self.pred_real = self.netD(real_Amap)
        self.loss_D_real = self.criterionGAN(self.pred_real, True)
        # Combined loss
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * self.lambda_D
        self.loss_D.backward()

    def backward_G(self):
        # G(A) should reproduce the ground-truth landmark map...
        self.loss_G_MSE = self.criterionMSE(self.fake_map,
                                            self.real_map) * self.lambda_MSE
        # ...and fool the discriminator on the (image, map) pair.
        fake_Amap = torch.cat((self.real_A, self.fake_map), 1)
        pred_fake = self.netD(fake_Amap)
        self.loss_G_GAN = self.criterionGAN(pred_fake, True)
        self.loss_G = self.loss_G_GAN + self.loss_G_MSE
        self.loss_G.backward()

    def optimize_parameters(self):
        self.forward()
        self.optimizer_D.zero_grad()
        self.backward_D()
        self.optimizer_D.step()
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

    def get_current_errors(self):
        return OrderedDict([('G_GAN', self.loss_G_GAN.item()),
                            ('G_MSE', self.loss_G_MSE.item()),
                            ('D_real', self.loss_D_real.item()),
                            ('D_fake', self.loss_D_fake.item())])

    def get_current_visuals(self):
        real_A = util.tensor2im(self.real_A.data)
        fake_map = util.tensor2im(self.fake_map.data)
        real_map = util.tensor2im(self.real_map.data)
        return OrderedDict([('real_A', real_A), ('fake_map', fake_map),
                            ('real_map', real_map)])

    def save(self, which_epoch):
        netD_path = os.path.join(self.checkpoints_dir, self.name,
                                 "netD_{}.pth".format(which_epoch))
        netG_path = os.path.join(self.checkpoints_dir, self.name,
                                 "netG_{}.pth".format(which_epoch))
        torch.save(self.netD.cpu().state_dict(), netD_path)
        torch.save(self.netG.cpu().state_dict(), netG_path)
        if len(self.gpu_ids) > 0:
            self.netG.cuda()
            self.netD.cuda()

    def update_learning_rate(self, decay):
        old_lr = self.lr
        self.lr = self.lr * decay
        for param_group in self.optimizer_D.param_groups:
            param_group['lr'] = self.lr
        for param_group in self.optimizer_G.param_groups:
            param_group['lr'] = self.lr
        print('update learning rate: %f -> %f' % (old_lr, self.lr))
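# A minimal usage sketch (the random tensors and file name are illustrative):
# one optimization step on an image/landmark-map pair.
if __name__ == "__main__":
    model = Pix2PixModel(name="experiment", phase="train", image_size=128)
    img = torch.rand(1, 3, 128, 128)      # stand-in input image
    lm_map = torch.rand(1, 1, 128, 128)   # stand-in landmark map
    model.set_input(img, lm_map, ["sample.png"])
    model.optimize_parameters()
    print(model.get_current_errors())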