def forward(self, imgs, sentence=None, word_emb=None, train_perm=True):
    """Score images for realism and for word-level image/text agreement.

    Only ``imgs[0]`` is used (fed through ``from_64``); ``sentence`` is
    accepted for interface compatibility but not read here.

    Args:
        imgs: sequence of image batches; only the first entry is consumed.
        sentence: unused.
        word_emb: word embeddings passed straight to ``wordLevelDis``.
        train_perm: when true, also score images shuffled with
            ``true_randperm`` against the unshuffled ``word_emb``.

    Returns:
        tuple ``(logits, match, match_perm, match_word, match_perm_word)``:
        ``logits`` is the flattened ``rf_64`` output, ``match`` and
        ``match_perm`` are always ``None`` in this variant, ``match_word`` is
        ``wordLevelDis`` on the true pairs and ``match_perm_word`` is its
        value on shuffled images (``None`` unless ``train_perm``).
    """
    low_res = imgs[0]
    img_feat = self.from_64(low_res)
    logits = self.rf_64(img_feat).view(-1)

    # The sentence-level matching heads are disabled for this discriminator.
    match = None
    match_perm = None
    match_perm_word = None

    match_word = self.wordLevelDis(img_feat, word_emb)
    if train_perm:
        # Pair every word embedding with a different image to form negatives.
        shuffled = true_randperm(low_res.shape[0])
        match_perm_word = self.wordLevelDis(img_feat[shuffled], word_emb)

    return torch.cat([logits]), match, match_perm, match_word, match_perm_word
def forward(self, imgs, sentence=None, word_emb=None, train_perm=True):
    """Score two-resolution images for realism and sentence matching.

    Args:
        imgs: sequence whose first two entries are fed to ``from_64`` and
            ``from_256`` respectively.
        sentence: 2-D sentence embedding; it is tiled over the spatial grid of
            the concatenated image features before being passed to
            ``matcher`` (assumed ``(batch, dim)`` -- confirm with caller).
        word_emb: unused; kept for interface compatibility.
        train_perm: when true, also score images shuffled with
            ``true_randperm`` against the unshuffled sentences (negatives).

    Returns:
        tuple ``(logits, match, match_perm)`` where ``logits`` concatenates
        the flattened ``rf_64`` and ``rf_256`` outputs, ``match`` is the
        matcher output for true pairs and ``match_perm`` the output for
        mismatched pairs (``None`` unless ``train_perm``).
    """
    feat_64 = self.from_64(imgs[0])
    feat_256 = self.from_256(imgs[1])
    logit_1 = self.rf_64(feat_64).view(-1)
    logit_3 = self.rf_256(feat_256).view(-1)

    img_feat = torch.cat([feat_64, feat_256], dim=1)

    # Tile the sentence vector over the feature map. The spatial size is taken
    # from img_feat instead of being hard-coded to 8x8, and the tiled tensor is
    # built once and reused for the permuted branch.
    height, width = img_feat.shape[-2:]
    sent_map = sentence.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, height, width)

    match = self.matcher(torch.cat([img_feat, sent_map], dim=1))

    match_perm = None
    if train_perm:
        perm = true_randperm(imgs[0].shape[0])
        match_perm = self.matcher(torch.cat([img_feat[perm], sent_map], dim=1))

    return torch.cat([logit_1, logit_3]), match, match_perm
def loss_for_list_perm(loss, fl1, fl2, detach_second=True, margin=2):
    """Hinge ranking loss summed over paired lists of feature tensors.

    For each index ``i`` the matched distance ``loss(fl1[i], fl2[i])`` is
    pushed below the mismatched distance ``loss(fl1[i][perm], fl2[i])`` by at
    least ``margin``. ``perm`` is a fresh ``true_randperm`` of the batch,
    drawn for every pair.

    Args:
        loss: callable ``loss(a, b)`` returning a scalar distance
            (e.g. ``F.mse_loss``).
        fl1: sequence of feature tensors; defines how many pairs are used.
        fl2: sequence indexed in step with ``fl1`` (must be at least as long).
        detach_second: if true, no gradient flows into ``fl2``.
        margin: hinge margin; defaults to the previously hard-coded 2.

    Returns:
        The summed hinge terms, or ``0`` when ``fl1`` is empty.
    """
    result_loss = 0
    for f_idx, feat_a in enumerate(fl1):
        feat_b = fl2[f_idx]
        if detach_second:
            feat_b = feat_b.detach()
        perm = true_randperm(feat_a.shape[0], feat_a.device)
        result_loss += F.relu(margin + loss(feat_a, feat_b) - loss(feat_a[perm], feat_b))
    return result_loss
def image_generator_perm(dataset, net_ae, net_ig, BATCH_SIZE=8, n_batches=500, im_size=512):
    """Yield generated images whose style image is shuffled within the batch.

    Each sketch is paired with a colour image from a different sample (via
    ``true_randperm``), so the outputs are mismatched content/style samples,
    e.g. for FID evaluation. The first generated batch is also written to
    ``tmp.jpg``.

    Args:
        dataset: dataset whose items unpack as ``(rgb, _, _, sketch)``.
        net_ae: autoencoder called as ``net_ae(sketch, rgb)`` returning
            ``(image, style_features)``.
        net_ig: refinement generator called as ``net_ig(image, style_features)``.
        BATCH_SIZE: images per batch.
        n_batches: maximum number of batches; capped at
            ``len(dataset) // BATCH_SIZE - 1``.
        im_size: size both inputs are resized to (previously fixed at 512).

    Yields:
        One generated batch per iteration, computed without autograd.
    """
    dataloader = iter(
        DataLoader(dataset,
                   BATCH_SIZE,
                   shuffle=False,
                   num_workers=4,
                   pin_memory=False))
    n_batches = min(n_batches, len(dataset) // BATCH_SIZE - 1)
    for batch_idx in range(n_batches):
        rgb_img, _, _, skt_img = next(dataloader)
        # Evaluation only: don't build (and hold) an autograd graph.
        with torch.no_grad():
            rgb_img = F.interpolate(rgb_img, size=im_size).cuda()
            skt_img = F.interpolate(skt_img, size=im_size).cuda()
            perm = true_randperm(rgb_img.shape[0], device=rgb_img.device)
            gimg_ae, style_feat = net_ae(skt_img, rgb_img[perm])
            g_image = net_ig(gimg_ae, style_feat)
        if batch_idx == 0:
            vutils.save_image(0.5 * (g_image + 1), 'tmp.jpg')
        # Yield outside the no_grad block: grad mode is thread-local state, so
        # suspending inside the context would leave autograd disabled for the
        # consumer until the generator is resumed.
        yield g_image
def train():
    """Train the refinement GAN (``net_ig`` / ``net_id``) on top of a frozen AE.

    All hyper-parameters, paths and data roots are read from ``config``. The
    loop runs ``EPOCH_GAN`` epochs of 10000 iterations. It periodically:
    - writes sample grids using the EMA generator weights,
    - appends log lines to ``gan_log.txt``,
    - saves checkpoints,
    - computes FID for matched and style-shuffled generations.
    Real-image FID statistics are cached in ``<DATA_NAME>_fid_feats.npy``.
    """
    from benchmark import (calc_fid, extract_feature_from_generator_fn,
                           load_patched_inception_v3, real_image_loader,
                           image_generator, image_generator_perm)
    import lpips
    from config import IM_SIZE_GAN, BATCH_SIZE_GAN, NFC, NBR_CLS, DATALOADER_WORKERS, EPOCH_GAN, ITERATION_AE, GAN_CKECKPOINT
    from config import SAVE_IMAGE_INTERVAL, SAVE_MODEL_INTERVAL, LOG_INTERVAL, SAVE_FOLDER, TRIAL_NAME, DATA_NAME, MULTI_GPU
    from config import FID_INTERVAL, FID_BATCH_NBR, PRETRAINED_AE_PATH
    from config import data_root_colorful, data_root_sketch_1, data_root_sketch_2, data_root_sketch_3

    real_features = None
    inception = load_patched_inception_v3().cuda()
    inception.eval()

    percept = lpips.PerceptualLoss(model='net-lin', net='vgg', use_gpu=True)

    saved_image_folder, saved_model_folder = make_folders(
        SAVE_FOLDER, 'GAN_' + TRIAL_NAME)
    log_file_path = saved_image_folder + '/../gan_log.txt'
    # Truncate any log left over from a previous run with the same name.
    with open(log_file_path, 'w'):
        pass

    def append_log(msg):
        # Append one line to the run log; `with` closes the file even on error.
        with open(log_file_path, 'a') as log_file:
            log_file.write(msg + '\n')

    dataset = PairedMultiDataset(data_root_colorful,
                                 data_root_sketch_1,
                                 data_root_sketch_2,
                                 data_root_sketch_3,
                                 im_size=IM_SIZE_GAN,
                                 rand_crop=True)
    print('the dataset contains %d images.' % len(dataset))
    dataloader = iter(
        DataLoader(dataset,
                   BATCH_SIZE_GAN,
                   sampler=InfiniteSamplerWrapper(dataset),
                   num_workers=DATALOADER_WORKERS,
                   pin_memory=True))

    from datasets import ImageFolder
    from datasets import trans_maker_augment as trans_maker

    dataset_rgb = ImageFolder(data_root_colorful, trans_maker(512))
    dataset_skt = ImageFolder(data_root_sketch_3, trans_maker(512))

    net_ae = AE(nfc=NFC, nbr_cls=NBR_CLS)

    if PRETRAINED_AE_PATH is None:
        PRETRAINED_AE_PATH = 'train_results/' + 'AE_' + TRIAL_NAME + '/models/%d.pth' % ITERATION_AE
    else:
        from config import PRETRAINED_AE_ITER
        PRETRAINED_AE_PATH = PRETRAINED_AE_PATH + '/models/%d.pth' % PRETRAINED_AE_ITER

    net_ae.load_state_dicts(PRETRAINED_AE_PATH)
    net_ae.cuda()
    # The AE is frozen in this stage: eval mode and no optimizer below.
    net_ae.eval()

    if DATA_NAME == 'celeba':
        from models import RefineGenerator_face as RefineGenerator
    elif DATA_NAME == 'art' or DATA_NAME == 'shoe':
        from models import RefineGenerator_art as RefineGenerator
    else:
        # Previously this fell through and failed later with "'NoneType' object
        # is not callable".
        raise ValueError('unsupported DATA_NAME for GAN training: %r' % (DATA_NAME,))

    net_ig = RefineGenerator(nfc=NFC, im_size=IM_SIZE_GAN).cuda()
    # we use the patch_gan, so the im_size for D should be 512 even if training image size is 1024
    net_id = Discriminator(nc=3).cuda()

    if MULTI_GPU:
        net_ae = nn.DataParallel(net_ae)
        net_ig = nn.DataParallel(net_ig)
        net_id = nn.DataParallel(net_id)

    net_ig_ema = copy_G_params(net_ig)

    opt_ig = optim.Adam(net_ig.parameters(), lr=2e-4, betas=(0.5, 0.999))
    opt_id = optim.Adam(net_id.parameters(), lr=2e-4, betas=(0.5, 0.999))

    if GAN_CKECKPOINT is not None:
        ckpt = torch.load(GAN_CKECKPOINT)
        net_ig.load_state_dict(ckpt['ig'])
        net_id.load_state_dict(ckpt['id'])
        net_ig_ema = ckpt['ig_ema']
        opt_ig.load_state_dict(ckpt['opt_ig'])
        opt_id.load_state_dict(ckpt['opt_id'])

    ## running averages for the log
    losses_g_img = AverageMeter()
    losses_d_img = AverageMeter()
    losses_mse = AverageMeter()
    # Never updated in this loop; kept so the log line format is unchanged.
    losses_rec_s = AverageMeter()
    losses_rec_ae = AverageMeter()

    fixed_skt = fixed_rgb = fixed_perm = None

    fid = [[0, 0]]

    for epoch in range(EPOCH_GAN):
        for iteration in tqdm(range(10000)):
            rgb_img, skt_img_1, skt_img_2, skt_img_3 = next(dataloader)

            rgb_img = rgb_img.cuda()

            # randint is inclusive: sketch 3 is chosen for rd in {2, 3}.
            rd = random.randint(0, 3)
            if rd == 0:
                skt_img = skt_img_1.cuda()
            elif rd == 1:
                skt_img = skt_img_2.cuda()
            else:
                skt_img = skt_img_3.cuda()

            if iteration == 0:
                fixed_skt = skt_img_3[:8].clone().cuda()
                fixed_rgb = rgb_img[:8].clone()
                fixed_perm = true_randperm(fixed_rgb.shape[0], 'cuda')

            ### 1. train D
            # The frozen AE needs no gradients; without no_grad every
            # loss_all.backward() would also back-propagate through net_ae.
            with torch.no_grad():
                gimg_ae, style_feats = net_ae(skt_img, rgb_img)
            g_image = net_ig(gimg_ae, style_feats)

            pred_r = net_id(rgb_img)
            pred_f = net_id(g_image.detach())

            loss_d = d_hinge_loss(pred_r, pred_f)

            net_id.zero_grad()
            loss_d.backward()
            opt_id.step()

            loss_rec_ae = F.mse_loss(gimg_ae, rgb_img) + F.l1_loss(gimg_ae, rgb_img)
            losses_rec_ae.update(loss_rec_ae.item(), BATCH_SIZE_GAN)

            ### 2. train G
            pred_g = net_id(g_image)
            loss_g = g_hinge_loss(pred_g)

            if DATA_NAME == 'shoe':
                loss_mse = 10 * (F.l1_loss(g_image, rgb_img) + F.mse_loss(g_image, rgb_img))
            else:
                loss_mse = 10 * percept(
                    F.adaptive_avg_pool2d(g_image, output_size=256),
                    F.adaptive_avg_pool2d(rgb_img, output_size=256)).sum()
            losses_mse.update(loss_mse.item() / BATCH_SIZE_GAN, BATCH_SIZE_GAN)

            loss_all = loss_g + loss_mse

            if DATA_NAME == 'shoe':
                ### the grey image reconstruction
                perm = true_randperm(BATCH_SIZE_GAN)
                with torch.no_grad():
                    img_ae_perm, style_feats_perm = net_ae(skt_img, rgb_img[perm])
                gimg_grey = net_ig(img_ae_perm, style_feats_perm)
                gimg_grey = gimg_grey.mean(dim=1, keepdim=True)
                real_grey = rgb_img.mean(dim=1, keepdim=True)
                loss_rec_grey = F.mse_loss(gimg_grey, real_grey)
                loss_all += 10 * loss_rec_grey

            net_ig.zero_grad()
            loss_all.backward()
            opt_ig.step()

            # Exponential moving average of the generator weights.
            for p, avg_p in zip(net_ig.parameters(), net_ig_ema):
                avg_p.mul_(0.999).add_(p.data, alpha=0.001)

            ### 3. logging
            losses_g_img.update(pred_g.mean().item(), BATCH_SIZE_GAN)
            losses_d_img.update(pred_r.mean().item(), BATCH_SIZE_GAN)

            if iteration % SAVE_IMAGE_INTERVAL == 0:  # show the current images
                with torch.no_grad():
                    # Temporarily swap the EMA weights in for the samples.
                    backup_para_g = copy_G_params(net_ig)
                    load_params(net_ig, net_ig_ema)

                    gimg_ae, style_feats = net_ae(fixed_skt, fixed_rgb)
                    gmatch = net_ig(gimg_ae, style_feats)

                    gimg_ae_perm, style_feats = net_ae(fixed_skt, fixed_rgb[fixed_perm])
                    gmismatch = net_ig(gimg_ae_perm, style_feats)

                    gimg = torch.cat([
                        F.interpolate(fixed_rgb, IM_SIZE_GAN),
                        F.interpolate(fixed_skt.repeat(1, 3, 1, 1), IM_SIZE_GAN),
                        gmatch,
                        F.interpolate(gimg_ae, IM_SIZE_GAN),
                        gmismatch,
                        F.interpolate(gimg_ae_perm, IM_SIZE_GAN)
                    ])

                    vutils.save_image(
                        gimg,
                        f'{saved_image_folder}/img_iter_{epoch}_{iteration}.jpg',
                        normalize=True,
                        range=(-1, 1))
                    del gimg

                    make_matrix(
                        dataset_rgb, dataset_skt, net_ae, net_ig, 5,
                        f'{saved_image_folder}/img_iter_{epoch}_{iteration}_matrix.jpg')

                    load_params(net_ig, backup_para_g)

            if iteration % LOG_INTERVAL == 0:
                log_msg = 'Iter: [{0}/{1}] G: {losses_g_img.avg:.4f} D: {losses_d_img.avg:.4f} MSE: {losses_mse.avg:.4f} Rec: {losses_rec_s.avg:.5f} FID: {fid:.4f}'.format(
                    epoch,
                    iteration,
                    losses_g_img=losses_g_img,
                    losses_d_img=losses_d_img,
                    losses_mse=losses_mse,
                    losses_rec_s=losses_rec_s,
                    fid=fid[-1][0])

                print(log_msg)
                print('%.5f' % (losses_rec_ae.avg))
                append_log(log_msg)

                losses_g_img.reset()
                losses_d_img.reset()
                losses_mse.reset()
                losses_rec_s.reset()
                losses_rec_ae.reset()

            if iteration % SAVE_MODEL_INTERVAL == 0 or iteration + 1 == 10000:
                print('Saving history model')
                torch.save(
                    {
                        'ig': net_ig.state_dict(),
                        'id': net_id.state_dict(),
                        'ae': net_ae.state_dict(),
                        'ig_ema': net_ig_ema,
                        'opt_ig': opt_ig.state_dict(),
                        'opt_id': opt_id.state_dict(),
                    }, '%s/%d.pth' % (saved_model_folder, epoch))

            if iteration % FID_INTERVAL == 0 and iteration > 1:
                print("calculating FID ...")
                fid_batch_images = FID_BATCH_NBR
                if real_features is None:
                    stats_path = '%s_fid_feats.npy' % (DATA_NAME)
                    if os.path.exists(stats_path):
                        # Local cache written below; not untrusted input.
                        with open(stats_path, 'rb') as stats_file:
                            real_features = pickle.load(stats_file)
                    else:
                        feats = extract_feature_from_generator_fn(
                            real_image_loader(dataloader, n_batches=fid_batch_images),
                            inception)
                        real_features = {
                            'feats': feats,
                            'mean': np.mean(feats, 0),
                            'cov': np.cov(feats, rowvar=False),
                        }
                        with open(stats_path, 'wb') as stats_file:
                            pickle.dump(real_features, stats_file)

                sample_features = extract_feature_from_generator_fn(
                    image_generator(dataset, net_ae, net_ig, n_batches=fid_batch_images),
                    inception,
                    total=fid_batch_images)
                cur_fid = calc_fid(sample_features,
                                   real_mean=real_features['mean'],
                                   real_cov=real_features['cov'])
                sample_features_perm = extract_feature_from_generator_fn(
                    image_generator_perm(dataset, net_ae, net_ig, n_batches=fid_batch_images),
                    inception,
                    total=fid_batch_images)
                cur_fid_perm = calc_fid(sample_features_perm,
                                        real_mean=real_features['mean'],
                                        real_cov=real_features['cov'])

                fid.append([cur_fid, cur_fid_perm])
                print('fid:', fid)
                append_log('fid: %.5f, %.5f' % (fid[-1][0], fid[-1][1]))
def train():
    """Stage-1 training of the sketch-to-colour autoencoder.

    Each iteration takes two optimisation steps:

    1. A self-supervised step on ``SelfSupervisedDataset`` batches. It
       combines style classification, style/content consistency, and
       reconstruction of the original colour image from an augmented colour
       image plus a randomly chosen sketch variant.
    2. A plain reconstruction step on ``PairedMultiDataset`` batches. For
       ``shoe`` data it adds a grey-level loss on style-shuffled outputs.

    Settings come from ``config``. Log lines are appended to ``ae_log.txt``.
    Sample grids and checkpoints are saved at the configured intervals. A
    final checkpoint named ``<ITERATION_AE>.pth`` is written after the loop.
    """
    from config import IM_SIZE_AE, BATCH_SIZE_AE, NFC, NBR_CLS, DATALOADER_WORKERS, ITERATION_AE
    from config import SAVE_IMAGE_INTERVAL, SAVE_MODEL_INTERVAL, SAVE_FOLDER, TRIAL_NAME, LOG_INTERVAL
    from config import DATA_NAME
    from config import data_root_colorful, data_root_sketch_1, data_root_sketch_2, data_root_sketch_3

    dataset = PairedMultiDataset(data_root_colorful,
                                 data_root_sketch_1,
                                 data_root_sketch_2,
                                 data_root_sketch_3,
                                 im_size=IM_SIZE_AE,
                                 rand_crop=True)
    print(len(dataset))
    dataloader = iter(DataLoader(dataset, BATCH_SIZE_AE,
                                 sampler=InfiniteSamplerWrapper(dataset),
                                 num_workers=DATALOADER_WORKERS, pin_memory=True))

    dataset_ss = SelfSupervisedDataset(data_root_colorful,
                                       data_root_sketch_3,
                                       im_size=IM_SIZE_AE,
                                       nbr_cls=NBR_CLS,
                                       rand_crop=True)
    print(len(dataset_ss), len(dataset_ss.frame))
    dataloader_ss = iter(DataLoader(dataset_ss, BATCH_SIZE_AE,
                                    sampler=InfiniteSamplerWrapper(dataset_ss),
                                    num_workers=DATALOADER_WORKERS, pin_memory=True))

    style_encoder = StyleEncoder(nfc=NFC, nbr_cls=NBR_CLS).cuda()
    content_encoder = ContentEncoder(nfc=NFC).cuda()
    decoder = Decoder(nfc=NFC).cuda()

    opt_c = optim.Adam(content_encoder.parameters(), lr=2e-4, betas=(0.5, 0.999))
    opt_s = optim.Adam(style_encoder.parameters(), lr=2e-4, betas=(0.5, 0.999))
    opt_d = optim.Adam(decoder.parameters(), lr=2e-4, betas=(0.5, 0.999))

    style_encoder.reset_cls()
    style_encoder.final_cls.cuda()

    from config import PRETRAINED_AE_PATH, PRETRAINED_AE_ITER
    if PRETRAINED_AE_PATH is not None:
        PRETRAINED_AE_PATH = PRETRAINED_AE_PATH + '/models/%d.pth' % PRETRAINED_AE_ITER
        ckpt = torch.load(PRETRAINED_AE_PATH)

        print(PRETRAINED_AE_PATH)

        style_encoder.load_state_dict(ckpt['s'])
        content_encoder.load_state_dict(ckpt['c'])
        decoder.load_state_dict(ckpt['d'])

        opt_c.load_state_dict(ckpt['opt_c'])
        opt_s.load_state_dict(ckpt['opt_s'])
        opt_d.load_state_dict(ckpt['opt_d'])
        print('loaded pre-trained AE')

    style_encoder.reset_cls()
    style_encoder.final_cls.cuda()
    opt_s_cls = optim.Adam(style_encoder.final_cls.parameters(), lr=2e-4, betas=(0.5, 0.999))

    saved_image_folder, saved_model_folder = make_folders(SAVE_FOLDER, 'AE_' + TRIAL_NAME)
    log_file_path = saved_image_folder + '/../ae_log.txt'
    # Truncate any log left over from a previous run with the same name.
    with open(log_file_path, 'w'):
        pass

    def save_checkpoint(step):
        # Save all modules and optimizers to '<saved_model_folder>/<step>.pth'.
        # opt_s_cls is looked up at call time, so the current head's optimizer is saved.
        torch.save(
            {
                's': style_encoder.state_dict(),
                'd': decoder.state_dict(),
                'c': content_encoder.state_dict(),
                'opt_c': opt_c.state_dict(),
                'opt_s_cls': opt_s_cls.state_dict(),
                'opt_s': opt_s.state_dict(),
                'opt_d': opt_d.state_dict(),
            }, '%s/%d.pth' % (saved_model_folder, step))

    ## for logging
    losses_sf_consist = AverageMeter()
    losses_cf_consist = AverageMeter()
    losses_cls = AverageMeter()
    losses_rec_rd = AverageMeter()
    losses_rec_org = AverageMeter()
    losses_rec_grey = AverageMeter()

    import lpips
    percept = lpips.PerceptualLoss(model='net-lin', net='vgg', use_gpu=True)

    # Iterations between switches of the self-supervised class subset. Clamped
    # to 1 so a small NBR_CLS (relative to the batch size) cannot cause a
    # modulo-by-zero.
    cls_switch_interval = max(1, (NBR_CLS * 100) // BATCH_SIZE_AE)

    for iteration in tqdm(range(ITERATION_AE)):

        if iteration % cls_switch_interval == 0 and iteration > 1:
            dataset_ss._next_set()
            dataloader_ss = iter(
                DataLoader(dataset_ss,
                           BATCH_SIZE_AE,
                           sampler=InfiniteSamplerWrapper(dataset_ss),
                           num_workers=DATALOADER_WORKERS,
                           pin_memory=True))
            style_encoder.reset_cls()
            # Move the new classifier head to the GPU, as done after reset_cls()
            # during setup, before building its optimizer.
            style_encoder.final_cls.cuda()
            opt_s_cls = optim.Adam(style_encoder.final_cls.parameters(), lr=2e-4, betas=(0.5, 0.999))

            opt_s.param_groups[0]['lr'] = 1e-4
            opt_d.param_groups[0]['lr'] = 1e-4

        ### 1. train the encoder with self-supervision methods
        rgb_img_rd, rgb_img_org, skt_org, skt_bold, skt_erased, skt_erased_bold, img_idx = next(dataloader_ss)
        rgb_img_rd = rgb_img_rd.cuda()
        rgb_img_org = rgb_img_org.cuda()
        img_idx = img_idx.cuda()

        skt_org = F.interpolate(skt_org, size=512).cuda()
        skt_bold = F.interpolate(skt_bold, size=512).cuda()
        skt_erased = F.interpolate(skt_erased, size=512).cuda()
        skt_erased_bold = F.interpolate(skt_erased_bold, size=512).cuda()

        style_encoder.zero_grad()
        decoder.zero_grad()
        content_encoder.zero_grad()

        style_vector_rd, pred_cls_rd = style_encoder(rgb_img_rd)
        style_vector_org, pred_cls_org = style_encoder(rgb_img_org)

        content_feats = content_encoder(skt_org)
        content_feats_bold = content_encoder(skt_bold)
        content_feats_erased = content_encoder(skt_erased)
        content_feats_eb = content_encoder(skt_erased_bold)

        # Reconstruct from one of the four sketch variants (randint is
        # inclusive, so each variant is equally likely).
        rd = random.randint(0, 3)
        sketch_variants = (content_feats, content_feats_bold, content_feats_erased, content_feats_eb)
        gimg_rd = decoder(sketch_variants[rd], style_vector_rd)

        loss_cf_consist = loss_for_list_perm(F.mse_loss, content_feats_bold, content_feats) +\
                            loss_for_list_perm(F.mse_loss, content_feats_erased, content_feats) +\
                                loss_for_list_perm(F.mse_loss, content_feats_eb, content_feats)

        loss_sf_consist = 0
        for loss_idx in range(3):
            # Pull the augmented image's style towards its own original and
            # push it away from another image's style. true_randperm never maps
            # an index to itself, so the negative is never the positive pair
            # (torch.randperm could produce fixed points).
            positive = style_vector_org[loss_idx].detach()
            negative = style_vector_org[loss_idx][true_randperm(BATCH_SIZE_AE)].detach()
            loss_sf_consist += -F.cosine_similarity(style_vector_rd[loss_idx], positive).mean() + \
                                F.cosine_similarity(style_vector_rd[loss_idx], negative).mean()

        loss_cls = F.cross_entropy(pred_cls_rd, img_idx) + F.cross_entropy(pred_cls_org, img_idx)
        loss_rec_rd = F.mse_loss(gimg_rd, rgb_img_org)
        if DATA_NAME != 'shoe':
            loss_rec_rd += percept(
                F.adaptive_avg_pool2d(gimg_rd, output_size=256),
                F.adaptive_avg_pool2d(rgb_img_org, output_size=256)).sum()
        else:
            loss_rec_rd += F.l1_loss(gimg_rd, rgb_img_org)

        loss_total = loss_cls + loss_sf_consist + loss_rec_rd + loss_cf_consist  #+ loss_kl_c + loss_kl_s
        loss_total.backward()

        opt_s.step()
        opt_s_cls.step()
        opt_c.step()
        opt_d.step()

        ### 2. train as AutoEncoder
        rgb_img, skt_img_1, skt_img_2, skt_img_3 = next(dataloader)

        rgb_img = rgb_img.cuda()

        # randint is inclusive: sketch 3 is chosen for rd in {2, 3}.
        rd = random.randint(0, 3)
        if rd == 0:
            skt_img = skt_img_1
        elif rd == 1:
            skt_img = skt_img_2
        else:
            skt_img = skt_img_3

        skt_img = F.interpolate(skt_img, size=512).cuda()

        style_encoder.zero_grad()
        decoder.zero_grad()
        content_encoder.zero_grad()

        style_vector, _ = style_encoder(rgb_img)
        content_feats = content_encoder(skt_img)
        gimg = decoder(content_feats, style_vector)

        loss_rec_org = F.mse_loss(gimg, rgb_img)
        if DATA_NAME != 'shoe':
            loss_rec_org += percept(
                F.adaptive_avg_pool2d(gimg, output_size=256),
                F.adaptive_avg_pool2d(rgb_img, output_size=256)).sum()

        loss_rec = loss_rec_org
        if DATA_NAME == 'shoe':
            ### the grey image reconstruction
            perm = true_randperm(BATCH_SIZE_AE)
            gimg_perm = decoder(content_feats, [s[perm] for s in style_vector])
            gimg_grey = gimg_perm.mean(dim=1, keepdim=True)
            real_grey = rgb_img.mean(dim=1, keepdim=True)
            loss_rec_grey = F.mse_loss(gimg_grey, real_grey)
            # Out-of-place add: `loss_rec += ...` would mutate loss_rec_org too
            # (same tensor) and inflate the logged rec_org value.
            loss_rec = loss_rec_org + loss_rec_grey

        loss_rec.backward()

        opt_s.step()
        opt_d.step()
        opt_c.step()

        ### Logging
        losses_cf_consist.update(loss_cf_consist.mean().item(), BATCH_SIZE_AE)
        losses_sf_consist.update(loss_sf_consist.mean().item(), BATCH_SIZE_AE)
        losses_cls.update(loss_cls.mean().item(), BATCH_SIZE_AE)
        losses_rec_rd.update(loss_rec_rd.item(), BATCH_SIZE_AE)
        losses_rec_org.update(loss_rec_org.item(), BATCH_SIZE_AE)
        if DATA_NAME == 'shoe':
            losses_rec_grey.update(loss_rec_grey.item(), BATCH_SIZE_AE)

        if iteration % LOG_INTERVAL == 0:
            log_msg = 'Train Stage 1: AE: \nrec_rd: %.4f rec_org: %.4f cls: %.4f style_consist: %.4f content_consist: %.4f rec_grey: %.4f'%(losses_rec_rd.avg, \
                    losses_rec_org.avg, losses_cls.avg, losses_sf_consist.avg, losses_cf_consist.avg, losses_rec_grey.avg)

            print(log_msg)

            with open(log_file_path, 'a') as log_file:
                log_file.write(log_msg + '\n')

            losses_sf_consist.reset()
            losses_cls.reset()
            losses_rec_rd.reset()
            losses_rec_org.reset()
            losses_cf_consist.reset()
            losses_rec_grey.reset()

        if iteration % SAVE_IMAGE_INTERVAL == 0:
            vutils.save_image(torch.cat([
                rgb_img_rd,
                F.interpolate(skt_org.repeat(1, 3, 1, 1), size=512), gimg_rd
            ]),
                              '%s/rd_%d.jpg' % (saved_image_folder, iteration),
                              normalize=True,
                              range=(-1, 1))
            if DATA_NAME != 'shoe':
                # For shoe data gimg_perm already exists from the grey step.
                with torch.no_grad():
                    perm = true_randperm(BATCH_SIZE_AE)
                    gimg_perm = decoder(list(content_feats), [s[perm] for s in style_vector])

            vutils.save_image(torch.cat([
                rgb_img,
                F.interpolate(skt_img.repeat(1, 3, 1, 1), size=512), gimg,
                gimg_perm
            ]),
                              '%s/org_%d.jpg' % (saved_image_folder, iteration),
                              normalize=True,
                              range=(-1, 1))

        if iteration % SAVE_MODEL_INTERVAL == 0:
            print('Saving history model')
            save_checkpoint(iteration)

    save_checkpoint(ITERATION_AE)