import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

import pytorch_ssim

# NOTE: this module also relies on project-local components that are defined
# elsewhere in the repository and would normally be imported here, e.g.
# (assumed layout, adjust to the actual repository structure):
# from networks import AdaINGen, MsImageDis, UNet
# from utils import (weights_init, get_model_list, get_scheduler,
#                    vgg_preprocess, CrossEntropyLoss2d, jaccard)


class MUNIT_Trainer(nn.Module):
    def __init__(self, hyperparameters):
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.gen_b = AdaINGen(
            hyperparameters['input_dim_b'],
            hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_b = MsImageDis(
            hyperparameters['input_dim_b'],
            hyperparameters['new_size'],
            hyperparameters['dis'])  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']
        self.reg_param = hyperparameters['reg_param']
        self.beta_step = hyperparameters['beta_step']
        self.target_kl = hyperparameters['target_kl']
        self.gan_type = hyperparameters['gan_type']

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_b.parameters())
        gen_params = list(self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr, betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr, betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization
        self.gen_b.apply(weights_init(hyperparameters['init']))
        self.dis_b.apply(weights_init('gaussian'))

        # SSIM Loss
        self.ssim_loss = pytorch_ssim.SSIM()

    def recon_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

    def recon_criterion_l1(self, input, target, mask):
        return torch.sum(torch.abs(input - target)) / torch.sum(mask)

    def forward(self, x_a, x_b):
        self.eval()
        s_b = self.gen_b.enc_style(x_b)
        c_a = self.gen_b.enc_content(x_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        self.train()
        return x_ab

    def gen_update(self, x_a, x_b, hyperparameters):
        toogle_grad(self.dis_b, False)
        toogle_grad(self.gen_b, True)
        self.dis_b.train()
        self.gen_b.train()
        self.gen_opt.zero_grad()
        s_b = torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda()
        # encode
        c_a = self.gen_b.enc_content(x_a)
        c_b, s_b_prime = self.gen_b.encode(x_b)
        # decode
        x_b_recon = self.gen_b.decode(c_b, s_b_prime)
        # decode
        x_ab = self.gen_b.decode(c_a, s_b)
        # encode again
        c_a_recon, s_b_recon = self.gen_b.encode(x_ab)
        x_ab.requires_grad_()
        # reconstruction loss
        self.loss_gen_recon_x_ab_ssim = -self.ssim_loss.forward(x_a, x_ab)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b)
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
        # GAN loss
        _, _, d_fake = self.dis_b(x_ab)
        # d_fake = d_fake['out']
        self.loss_gen_adv_b = self.compute_loss(d_fake, 1)
        # total loss
        self.loss_gen_total = self.loss_gen_adv_b + \
            hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
            hyperparameters['recon_x_ab'] * self.loss_gen_recon_x_ab_ssim
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean(
            (self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2)

    def sample(self, x_a, x_b):
        self.eval()
        x_ab = []
        s_b = self.gen_b.enc_style(x_b)
        for i in range(x_a.size(0)):
            c_a = self.gen_b.enc_content(x_a[i].unsqueeze(0))
            x_ab.append(self.gen_b.decode(c_a, s_b))
        x_ab = torch.cat(x_ab)
        self.train()
        return x_a, x_ab

    def dis_update(self, x_a, x_b, hyperparameters):
        toogle_grad(self.gen_b, False)
        toogle_grad(self.dis_b, True)
        self.gen_b.train()
        self.dis_b.train()
        self.dis_opt.zero_grad()
        # On real data
        x_b.requires_grad_()
        d_real_dict = self.dis_b(x_b)
        d_real = d_real_dict[2]
        dloss_real = self.compute_loss(d_real, 1)
        reg = 0.
        # Both grad penal and vgan!
        dloss_real.backward(retain_graph=True)
        # hard coded 10 weight for grad penal.
        reg += 10. * compute_grad2(d_real, x_b).mean()
        mu = d_real_dict[0]
        logstd = d_real_dict[1]
        kl_real = kl_loss(mu, logstd).mean()
        # On fake data
        with torch.no_grad():
            s_b = torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda()
            c_a = self.gen_b.enc_content(x_a)
            x_ab = self.gen_b.decode(c_a, s_b)
        x_ab.requires_grad_()
        d_fake_dict = self.dis_b(x_ab)
        d_fake = d_fake_dict[2]
        dloss_fake = self.compute_loss(d_fake, 0)
        dloss_fake.backward(retain_graph=True)
        mu_fake = d_fake_dict[0]
        logstd_fake = d_fake_dict[1]
        kl_fake = kl_loss(mu_fake, logstd_fake).mean()
        avg_kl = 0.5 * (kl_real + kl_fake)
        reg += self.reg_param * avg_kl
        reg.backward()
        self.update_beta(avg_kl)
        self.dis_opt.step()
        self.loss_dis_total = (dloss_real + dloss_fake)
        return self.loss_dis_total.item()

    def compute_loss(self, d_out, target):
        targets = d_out.new_full(size=d_out.size(), fill_value=target)
        if self.gan_type == 'standard':
            loss = F.binary_cross_entropy_with_logits(d_out, targets)
        elif self.gan_type == 'wgan':
            loss = (2 * target - 1) * d_out.mean()
        else:
            raise NotImplementedError
        return loss

    def update_beta(self, avg_kl):
        with torch.no_grad():
            # self.target_kl is the constraint I_c.
            new_beta = self.reg_param - self.beta_step * (self.target_kl - avg_kl)
            new_beta = max(new_beta, 0)
            # print('setting beta from %.2f to %.2f' % (self.reg_param, new_beta))
            self.reg_param = new_beta

    def update_learning_rate(self):
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_b.load_state_dict(state_dict['b'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                           iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                           iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % iterations)
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % iterations)
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({'b': self.gen_b.state_dict()}, gen_name)
        torch.save({'b': self.dis_b.state_dict()}, dis_name)
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'dis': self.dis_opt.state_dict()
            }, opt_name)
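
# --------------------------------------------------------------------------
# The trainer above calls three helpers (toogle_grad, compute_grad2, kl_loss)
# whose definitions live elsewhere in the repository. The sketches below are
# illustrative assumptions of what such helpers usually look like -- switching
# parameter gradients on/off, an R1-style gradient penalty, and a KL term
# against a unit Gaussian -- not the repository's actual implementations.


def toogle_grad(model, requires_grad):
    # Enable/disable gradient computation for every parameter of `model`.
    for p in model.parameters():
        p.requires_grad_(requires_grad)


def compute_grad2(d_out, x_in):
    # Per-sample squared gradient norm of the discriminator output with
    # respect to its input, as used for the gradient penalty in dis_update.
    batch_size = x_in.size(0)
    grad_dout = torch.autograd.grad(
        outputs=d_out.sum(), inputs=x_in,
        create_graph=True, retain_graph=True, only_inputs=True)[0]
    return grad_dout.pow(2).reshape(batch_size, -1).sum(1)


def kl_loss(mu, logstd):
    # Per-sample KL( N(mu, exp(logstd)) || N(0, 1) ), summed over all
    # non-batch dimensions; dis_update averages the result over the batch.
    kl = -0.5 * (1. + 2. * logstd - mu.pow(2) - torch.exp(2. * logstd))
    return kl.reshape(kl.size(0), -1).sum(1)
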
# Multi-dataset variant of the trainer: images and content codes are
# concatenated with one-hot dataset labels, and a UNet operating on the
# shared content space is trained jointly for (semi-)supervised segmentation.
class MUNIT_Trainer(nn.Module):
    def __init__(self, hyperparameters, resume_epoch=-1, snapshot_dir=None):
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks.
        self.gen = AdaINGen(
            hyperparameters['input_dim'] + hyperparameters['n_datasets'],
            hyperparameters['gen'],
            hyperparameters['n_datasets'])  # Auto-encoder for domain a.
        self.dis = MsImageDis(
            hyperparameters['input_dim'] + hyperparameters['n_datasets'],
            hyperparameters['dis'])  # Discriminator for domain a.
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']
        self.beta1 = hyperparameters['beta1']
        self.beta2 = hyperparameters['beta2']
        self.weight_decay = hyperparameters['weight_decay']

        # Initiating and loading the pretrained UNet.
        self.sup = UNet(input_channels=hyperparameters['input_dim'],
                        num_classes=2).cuda()

        # Fix the noise used in sampling.
        self.s_a = torch.randn(8, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(8, self.style_dim, 1, 1).cuda()

        # Setup the optimizers.
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis.parameters())
        gen_params = list(self.gen.parameters()) + list(self.sup.parameters())
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr, betas=(self.beta1, self.beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr, betas=(self.beta1, self.beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization.
        self.apply(weights_init(hyperparameters['init']))
        self.dis.apply(weights_init('gaussian'))

        # Presetting one hot encoding vectors.
        self.one_hot_img = torch.zeros(hyperparameters['n_datasets'],
                                       hyperparameters['batch_size'],
                                       hyperparameters['n_datasets'],
                                       256, 256).cuda()
        self.one_hot_c = torch.zeros(hyperparameters['n_datasets'],
                                     hyperparameters['batch_size'],
                                     hyperparameters['n_datasets'],
                                     64, 64).cuda()
        for i in range(hyperparameters['n_datasets']):
            self.one_hot_img[i, :, i, :, :].fill_(1)
            self.one_hot_c[i, :, i, :, :].fill_(1)

        if resume_epoch != -1:
            self.resume(snapshot_dir, hyperparameters)

    def recon_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

    def semi_criterion(self, input, target):
        loss = CrossEntropyLoss2d(size_average=False).cuda()
        return loss(input, target)

    def forward(self, x_a, x_b, d_index_a, d_index_b):
        # d_index_a/d_index_b select the one-hot dataset codes for each input.
        self.eval()
        x_a.volatile = True
        x_b.volatile = True
        s_a = Variable(self.s_a, volatile=True)
        s_b = Variable(self.s_b, volatile=True)
        one_hot_x_a = torch.cat([x_a, self.one_hot_img[d_index_a]], 1)
        one_hot_x_b = torch.cat([x_b, self.one_hot_img[d_index_b]], 1)
        c_a, s_a_fake = self.gen.encode(one_hot_x_a)
        c_b, s_b_fake = self.gen.encode(one_hot_x_b)
        one_hot_c_b = torch.cat([c_b, self.one_hot_c[d_index_a]], 1)
        one_hot_c_a = torch.cat([c_a, self.one_hot_c[d_index_b]], 1)
        x_ba = self.gen.decode(one_hot_c_b, s_a)
        x_ab = self.gen.decode(one_hot_c_a, s_b)
        self.train()
        return x_ab, x_ba

    def set_gen_trainable(self, train_bool):
        if train_bool:
            self.gen.train()
            for param in self.gen.parameters():
                param.requires_grad = True
        else:
            self.gen.eval()
            for param in self.gen.parameters():
                param.requires_grad = False

    def set_sup_trainable(self, train_bool):
        if train_bool:
            self.sup.train()
            for param in self.sup.parameters():
                param.requires_grad = True
        else:
            self.sup.eval()
            for param in self.sup.parameters():
                param.requires_grad = False

    def sup_update(self, x_a, x_b, y_a, y_b, d_index_a, d_index_b, use_a,
                   use_b, hyperparameters):
        self.gen_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        one_hot_x_a = torch.cat([x_a, self.one_hot_img[d_index_a]], 1)
        one_hot_x_b = torch.cat([x_b, self.one_hot_img[d_index_b]], 1)
        # Encode.
        c_a, s_a_prime = self.gen.encode(one_hot_x_a)
        c_b, s_b_prime = self.gen.encode(one_hot_x_b)
        # Decode (within domain).
        one_hot_c_a = torch.cat([c_a, self.one_hot_c[d_index_a]], 1)
        one_hot_c_b = torch.cat([c_b, self.one_hot_c[d_index_b]], 1)
        x_a_recon = self.gen.decode(one_hot_c_a, s_a_prime)
        x_b_recon = self.gen.decode(one_hot_c_b, s_b_prime)
        # Decode (cross domain).
        one_hot_c_ab = torch.cat([c_a, self.one_hot_c[d_index_b]], 1)
        one_hot_c_ba = torch.cat([c_b, self.one_hot_c[d_index_a]], 1)
        x_ba = self.gen.decode(one_hot_c_ba, s_a)
        x_ab = self.gen.decode(one_hot_c_ab, s_b)
        # Encode again.
        one_hot_x_ba = torch.cat([x_ba, self.one_hot_img[d_index_a]], 1)
        one_hot_x_ab = torch.cat([x_ab, self.one_hot_img[d_index_b]], 1)
        c_b_recon, s_a_recon = self.gen.encode(one_hot_x_ba)
        c_a_recon, s_b_recon = self.gen.encode(one_hot_x_ab)
        # Forwarding through supervised model.
        p_a = None
        p_b = None
        loss_semi_a = None
        loss_semi_b = None

        has_a_label = (c_a[use_a, :, :, :].size(0) != 0)
        if has_a_label:
            p_a = self.sup(c_a, use_a, True)
            p_a_recon = self.sup(c_a_recon, use_a, True)
            loss_semi_a = self.semi_criterion(p_a, y_a[use_a, :, :]) + \
                self.semi_criterion(p_a_recon, y_a[use_a, :, :])

        has_b_label = (c_b[use_b, :, :, :].size(0) != 0)
        if has_b_label:
            p_b = self.sup(c_b, use_b, True)
            p_b_recon = self.sup(c_b_recon, use_b, True)
            loss_semi_b = self.semi_criterion(p_b, y_b[use_b, :, :]) + \
                self.semi_criterion(p_b_recon, y_b[use_b, :, :])

        self.loss_gen_total = None
        if loss_semi_a is not None and loss_semi_b is not None:
            self.loss_gen_total = loss_semi_a + loss_semi_b
        elif loss_semi_a is not None:
            self.loss_gen_total = loss_semi_a
        elif loss_semi_b is not None:
            self.loss_gen_total = loss_semi_b

        if self.loss_gen_total is not None:
            self.loss_gen_total.backward()
            self.gen_opt.step()

    def sup_forward(self, x, y, d_index, hyperparameters):
        self.sup.eval()
        # Encoding content image.
        one_hot_x = torch.cat([x, self.one_hot_img[d_index, 0].unsqueeze(0)],
                              1)
        content, _ = self.gen.encode(one_hot_x)
        # Forwarding on supervised model.
        y_pred = self.sup(content, only_prediction=True)
        # Computing metrics.
        pred = y_pred.data.max(1)[1].squeeze_(1).squeeze_(0).cpu().numpy()
        jacc = jaccard(pred, y.cpu().squeeze(0).numpy())
        return jacc, pred, content

    def gen_update(self, x_a, x_b, d_index_a, d_index_b, hyperparameters):
        self.gen_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # Encode.
        one_hot_x_a = torch.cat([x_a, self.one_hot_img[d_index_a]], 1)
        one_hot_x_b = torch.cat([x_b, self.one_hot_img[d_index_b]], 1)
        c_a, s_a_prime = self.gen.encode(one_hot_x_a)
        c_b, s_b_prime = self.gen.encode(one_hot_x_b)
        # Decode (within domain).
        one_hot_c_a = torch.cat([c_a, self.one_hot_c[d_index_a]], 1)
        one_hot_c_b = torch.cat([c_b, self.one_hot_c[d_index_b]], 1)
        x_a_recon = self.gen.decode(one_hot_c_a, s_a_prime)
        x_b_recon = self.gen.decode(one_hot_c_b, s_b_prime)
        # Decode (cross domain).
        one_hot_c_ab = torch.cat([c_a, self.one_hot_c[d_index_b]], 1)
        one_hot_c_ba = torch.cat([c_b, self.one_hot_c[d_index_a]], 1)
        x_ba = self.gen.decode(one_hot_c_ba, s_a)
        x_ab = self.gen.decode(one_hot_c_ab, s_b)
        # Encode again.
        one_hot_x_ba = torch.cat([x_ba, self.one_hot_img[d_index_a]], 1)
        one_hot_x_ab = torch.cat([x_ab, self.one_hot_img[d_index_b]], 1)
        c_b_recon, s_a_recon = self.gen.encode(one_hot_x_ba)
        c_a_recon, s_b_recon = self.gen.encode(one_hot_x_ab)
        # Decode again (if needed).
        one_hot_c_aba_recon = torch.cat(
            [c_a_recon, self.one_hot_c[d_index_a]], 1)
        one_hot_c_bab_recon = torch.cat(
            [c_b_recon, self.one_hot_c[d_index_b]], 1)
        x_aba = self.gen.decode(one_hot_c_aba_recon, s_a_prime)
        x_bab = self.gen.decode(one_hot_c_bab_recon, s_b_prime)
        # Reconstruction loss.
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b)
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(x_aba, x_a)
        self.loss_gen_cycrecon_x_b = self.recon_criterion(x_bab, x_b)
        # GAN loss.
        self.loss_gen_adv_a = self.dis.calc_gen_loss(one_hot_x_ba)
        self.loss_gen_adv_b = self.dis.calc_gen_loss(one_hot_x_ab)
        # Total loss.
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
            hyperparameters['gan_w'] * self.loss_gen_adv_b + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
            hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
            hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
            hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
            hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean(
            (self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2)

    def sample(self, x_a, x_b):
        self.eval()
        x_a.volatile = True
        x_b.volatile = True
        s_a1 = Variable(self.s_a, volatile=True)
        s_b1 = Variable(self.s_b, volatile=True)
        s_a2 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda(),
                        volatile=True)
        s_b2 = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda(),
                        volatile=True)
        x_a_recon, x_b_recon = [], []
        x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], []
        for i in range(x_a.size(0)):
            one_hot_x_a = torch.cat(
                [x_a[i].unsqueeze(0), self.one_hot_img_a[i].unsqueeze(0)], 1)
            one_hot_x_b = torch.cat(
                [x_b[i].unsqueeze(0), self.one_hot_img_b[i].unsqueeze(0)], 1)
            c_a, s_a_fake = self.gen.encode(one_hot_x_a)
            c_b, s_b_fake = self.gen.encode(one_hot_x_b)
            x_a_recon.append(self.gen.decode(c_a, s_a_fake))
            x_b_recon.append(self.gen.decode(c_b, s_b_fake))
            x_ba1.append(self.gen.decode(c_b, s_a1[i].unsqueeze(0)))
            x_ba2.append(self.gen.decode(c_b, s_a2[i].unsqueeze(0)))
            x_ab1.append(self.gen.decode(c_a, s_b1[i].unsqueeze(0)))
            x_ab2.append(self.gen.decode(c_a, s_b2[i].unsqueeze(0)))
        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2)
        x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2)
        self.train()
        return x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2

    def dis_update(self, x_a, x_b, d_index_a, d_index_b, hyperparameters):
        self.dis_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # Encode.
        one_hot_x_a = torch.cat([x_a, self.one_hot_img[d_index_a]], 1)
        one_hot_x_b = torch.cat([x_b, self.one_hot_img[d_index_b]], 1)
        c_a, _ = self.gen.encode(one_hot_x_a)
        c_b, _ = self.gen.encode(one_hot_x_b)
        one_hot_c_ba = torch.cat([c_b, self.one_hot_c[d_index_a]], 1)
        one_hot_c_ab = torch.cat([c_a, self.one_hot_c[d_index_b]], 1)
        # Decode (cross domain).
        x_ba = self.gen.decode(one_hot_c_ba, s_a)
        x_ab = self.gen.decode(one_hot_c_ab, s_b)
        # D loss.
        one_hot_x_ba = torch.cat([x_ba, self.one_hot_img[d_index_a]], 1)
        one_hot_x_ab = torch.cat([x_ab, self.one_hot_img[d_index_b]], 1)
        self.loss_dis_a = self.dis.calc_dis_loss(one_hot_x_ba, one_hot_x_a)
        self.loss_dis_b = self.dis.calc_dis_loss(one_hot_x_ab, one_hot_x_b)
        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + \
            hyperparameters['gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        print("--> " + checkpoint_dir)
        # Load generator.
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen.load_state_dict(state_dict)
        epochs = int(last_model_name[-11:-3])
        # Load supervised model.
        last_model_name = get_model_list(checkpoint_dir, "sup")
        state_dict = torch.load(last_model_name)
        self.sup.load_state_dict(state_dict)
        # Load discriminator.
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis.load_state_dict(state_dict)
        # Load optimizers.
        last_model_name = get_model_list(checkpoint_dir, "opt")
        state_dict = torch.load(last_model_name)
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Move optimizer state onto the GPU.
        for state in self.dis_opt.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.cuda()
        for state in self.gen_opt.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.cuda()
        # Reinitialize schedulers.
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                           epochs)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                           epochs)
        print('Resume from epoch %d' % epochs)
        return epochs

    def save(self, snapshot_dir, epoch):
        # Save generators, discriminators, and optimizers.
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % epoch)
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % epoch)
        sup_name = os.path.join(snapshot_dir, 'sup_%08d.pt' % epoch)
        opt_name = os.path.join(snapshot_dir, 'opt_%08d.pt' % epoch)
        torch.save(self.gen.state_dict(), gen_name)
        torch.save(self.dis.state_dict(), dis_name)
        torch.save(self.sup.state_dict(), sup_name)
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'dis': self.dis_opt.state_dict()
            }, opt_name)
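
# --------------------------------------------------------------------------
# Illustrative usage sketch (an assumption, not part of the original file):
# a minimal loop driving the multi-dataset trainer above. The hyperparameter
# dict only lists keys referenced in this file; `get_scheduler` and the
# sub-configs under 'gen'/'dis' need whatever additional keys the rest of the
# repository expects. `loader`, `num_epochs` and `snapshot_dir` are
# hypothetical: the loader is assumed to yield image batches, label maps,
# label-availability masks and dataset indices.
#
# if __name__ == '__main__':
#     hyperparameters = {
#         'lr': 1e-4, 'beta1': 0.5, 'beta2': 0.999, 'weight_decay': 1e-4,
#         'init': 'kaiming', 'input_dim': 1, 'n_datasets': 2, 'batch_size': 4,
#         'gan_w': 1.0, 'recon_x_w': 10.0, 'recon_s_w': 1.0, 'recon_c_w': 1.0,
#         'recon_x_cyc_w': 10.0,
#         'gen': {'style_dim': 8},   # plus the remaining AdaINGen options
#         'dis': {},                 # plus the MsImageDis options
#     }
#     trainer = MUNIT_Trainer(hyperparameters).cuda()
#     for epoch in range(num_epochs):
#         for x_a, x_b, y_a, y_b, use_a, use_b, d_a, d_b in loader:
#             trainer.dis_update(x_a, x_b, d_a, d_b, hyperparameters)
#             trainer.gen_update(x_a, x_b, d_a, d_b, hyperparameters)
#             trainer.sup_update(x_a, x_b, y_a, y_b, d_a, d_b,
#                                use_a, use_b, hyperparameters)
#         trainer.update_learning_rate()
#         trainer.save(snapshot_dir, epoch)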