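"""Solver for group-supervised disentanglement on the Fonts-style datasets.

A single autoencoder encodes each image into a latent code that is
partitioned into five attribute chunks (content, size, font color,
background color, style; the hard-coded offsets below imply 20 dims each
and z_dim = 100). Disentanglement is trained by swapping chunks between
images that each share exactly one attribute with a common anchor image C.
"""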
import os
import random

import torch
import torch.nn.functional as F
import torch.optim as optim
import visdom
from torch.autograd import Variable
from torchvision import transforms
from torchvision.utils import make_grid, save_image
from tqdm import tqdm

# Project-local modules; the import paths below are assumptions inferred
# from the names used in this file.
from dataset import return_data
from model import Generator_fc
from networks import get_norm_layer
from utils import DataGather, cuda, grid2gif


class Solver(object):

    def __init__(self, args):
        self.use_cuda = args.cuda and torch.cuda.is_available()
        self.max_iter = args.max_iter
        self.z_dim = args.z_dim
        self.beta = args.beta
        self.objective = args.objective
        self.model = args.model
        self.lr = args.lr
        self.beta1 = args.beta1
        self.beta2 = args.beta2

        # model params
        self.c_dim = args.c_dim
        self.image_size = args.image_size
        self.g_conv_dim = args.g_conv_dim
        self.g_repeat_num = args.g_repeat_num
        self.d_conv_dim = args.d_conv_dim
        self.d_repeat_num = args.d_repeat_num
        self.norm_layer = get_norm_layer(norm_type=args.norm)

        # Latent arrangement: one chunk per attribute domain.
        self.z_content_dim = args.z_content
        self.z_size_dim = args.z_size
        self.z_font_color_dim = args.z_font_color
        self.z_back_color_dim = args.z_back_color
        self.z_style_dim = args.z_style
        self.z_content_start_dim = 0
        self.z_size_start_dim = 20
        self.z_font_color_start_dim = 40
        self.z_back_color_start_dim = 60
        self.z_style_start_dim = 80
        self.lambda_combine = args.lambda_combine
        self.lambda_unsup = args.lambda_unsup

        dataset = args.dataset.lower()
        if dataset == 'dsprites':
            self.nc = 1
            self.decoder_dist = 'bernoulli'
        elif dataset in ('3dchairs', 'celeba', 'ilab_unsup', 'ilab_sup',
                         'ilab_unsup_unbalance', 'ilab_unsup_unbalance_free',
                         'ilab_unsup_threeswap', 'fonts_unsup_nswap'):
            self.nc = 3
            self.decoder_dist = 'gaussian'
        else:
            raise NotImplementedError

        # Keep the device consistent with the use_cuda flag so the model and
        # the tensors moved by cuda() always land on the same device.
        self.device = torch.device('cuda' if self.use_cuda else 'cpu')

        # model
        # self.Autoencoder = Generator(self.nc, self.g_conv_dim, self.g_repeat_num)
        self.Autoencoder = Generator_fc(self.nc, self.g_conv_dim,
                                        self.g_repeat_num, self.z_dim)
        # self.Autoencoder = BetaVAE_ilab(self.z_dim, self.nc)
        self.Autoencoder.to(self.device)
        self.auto_optim = optim.Adam(self.Autoencoder.parameters(),
                                     lr=self.lr,
                                     betas=(self.beta1, self.beta2))

        # A discriminator is currently unused.
        # self.netD = networks.define_D(self.nc, self.d_conv_dim, 'basic',
        #                               3, 'instance', True, 'normal', 0.02,
        #                               '0,1')

        # log
        self.log_dir = './checkpoints/' + args.viz_name
        self.model_save_dir = args.model_save_dir
        self.viz_name = args.viz_name
        self.viz_port = args.viz_port
        self.viz_on = args.viz_on
        self.win_recon = None
        self.win_combine_sup = None
        self.win_combine_unsup = None
        if self.viz_on:
            self.viz = visdom.Visdom(port=self.viz_port)

        self.resume_iters = args.resume_iters
        self.ckpt_dir = os.path.join(args.ckpt_dir, args.viz_name)
        os.makedirs(self.ckpt_dir, exist_ok=True)
        self.ckpt_name = args.ckpt_name
        # if self.ckpt_name is not None:
        #     self.load_checkpoint(self.ckpt_name)

        self.save_output = args.save_output
        self.output_dir = os.path.join(args.output_dir, args.viz_name)
        os.makedirs(self.output_dir, exist_ok=True)

        self.gather_step = args.gather_step
        self.display_step = args.display_step
        self.save_step = args.save_step
        self.dset_dir = args.dset_dir
        self.dataset = args.dataset
        self.batch_size = args.batch_size
        self.data_loader = return_data(args)

        self.gather = DataGather()

    def restore_model(self, resume_iters):
        """Restore the trained autoencoder."""
        print('Loading the trained models from step {}...'.format(resume_iters))
        Auto_path = os.path.join(self.model_save_dir, self.viz_name,
                                 '{}-Auto.ckpt'.format(resume_iters))
        self.Autoencoder.load_state_dict(
            torch.load(Auto_path, map_location=lambda storage, loc: storage))
        print("=> loaded checkpoint '{} (iter {})'".format(
            self.viz_name, resume_iters))

    def Cor_CoeLoss(self, y_pred, y_target):
        """Pearson-correlation loss; 0 when prediction and target are
        perfectly (anti-)correlated. Not called from train(); kept for
        experiments."""
        x = y_pred
        y = y_target
        x_var = x - torch.mean(x)
        y_var = y - torch.mean(y)
        r_num = torch.sum(x_var * y_var)
        r_den = torch.sqrt(torch.sum(x_var**2)) * torch.sqrt(torch.sum(y_var**2))
        r = r_num / r_den
        # return 1 - r  # best is 0
        return 1 - r**2  # absolute constraint: correlation and anti-correlation count alike

    def split_z(self, z):
        """Split a latent batch into its five attribute chunks:
        1 content, 2 size, 3 font color, 4 background color, 5 style."""
        return (z[:, :self.z_size_start_dim],
                z[:, self.z_size_start_dim:self.z_font_color_start_dim],
                z[:, self.z_font_color_start_dim:self.z_back_color_start_dim],
                z[:, self.z_back_color_start_dim:self.z_style_start_dim],
                z[:, self.z_style_start_dim:])

    def decode_z(self, z):
        """Decode a (possibly recombined) latent batch into an image batch.
        The fc decoder output is reshaped to the (256, 8, 8) feature map
        expected by the convolutional decoder."""
        mid = self.Autoencoder.fc_decoder(z)
        mid = mid.view(z.shape[0], 256, 8, 8)
        return self.Autoencoder.decoder(mid)

    def train(self):
        """Main training loop: reconstruction plus supervised and
        unsupervised latent-swap objectives."""
        out = False

        # Start training from scratch or resume training.
        self.global_iter = 0
        if self.resume_iters:
            self.global_iter = self.resume_iters
            self.restore_model(self.resume_iters)

        pbar = tqdm(total=self.max_iter)
        pbar.update(self.global_iter)
        while not out:
            for sup_package in self.data_loader:
                # Each package holds six images; C shares exactly one
                # attribute with each of A (content), B (size), D (font
                # color), E (background color) and F (style).
                A_img = sup_package['A']
                B_img = sup_package['B']
                C_img = sup_package['C']
                D_img = sup_package['D']
                E_img = sup_package['E']
                F_img = sup_package['F']
                self.global_iter += 1
                pbar.update(1)

                A_img = Variable(cuda(A_img, self.use_cuda))
                B_img = Variable(cuda(B_img, self.use_cuda))
                C_img = Variable(cuda(C_img, self.use_cuda))
                D_img = Variable(cuda(D_img, self.use_cuda))
                E_img = Variable(cuda(E_img, self.use_cuda))
                F_img = Variable(cuda(F_img, self.use_cuda))

                # 1. Encode and reconstruct each image separately.
                A_recon, A_z = self.Autoencoder(A_img)
                B_recon, B_z = self.Autoencoder(B_img)
                C_recon, C_z = self.Autoencoder(C_img)
                D_recon, D_z = self.Autoencoder(D_img)
                E_recon, E_z = self.Autoencoder(E_img)
                F_recon, F_z = self.Autoencoder(F_img)

                A_z_1, A_z_2, A_z_3, A_z_4, A_z_5 = self.split_z(A_z)
                B_z_1, B_z_2, B_z_3, B_z_4, B_z_5 = self.split_z(B_z)
                C_z_1, C_z_2, C_z_3, C_z_4, C_z_5 = self.split_z(C_z)
                D_z_1, D_z_2, D_z_3, D_z_4, D_z_5 = self.split_z(D_z)
                E_z_1, E_z_2, E_z_3, E_z_4, E_z_5 = self.split_z(E_z)
                F_z_1, F_z_2, F_z_3, F_z_4, F_z_5 = self.split_z(F_z)
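                # Latent layout implied by the hard-coded offsets (z_dim = 100
                # assumed): [0:20) content, [20:40) size, [40:60) font color,
                # [60:80) background color, [80:100) style. Swap naming below:
                # 'A1Co_2C' = A's chunk 1 plus C's other chunks, which should
                # decode to C; 'AoC1_2A' is the mirror swap and should decode
                # to A.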
                # 2. Recombine latent chunks with strong supervision.
                # C and A share content (chunk 1).
                A1Co_combine_2C = torch.cat(
                    (A_z_1, C_z_2, C_z_3, C_z_4, C_z_5), dim=1)
                A1Co_2C = self.decode_z(A1Co_combine_2C)
                AoC1_combine_2A = torch.cat(
                    (C_z_1, A_z_2, A_z_3, A_z_4, A_z_5), dim=1)
                AoC1_2A = self.decode_z(AoC1_combine_2A)

                # C and B share size (chunk 2).
                B2Co_combine_2C = torch.cat(
                    (C_z_1, B_z_2, C_z_3, C_z_4, C_z_5), dim=1)
                B2Co_2C = self.decode_z(B2Co_combine_2C)
                BoC2_combine_2B = torch.cat(
                    (B_z_1, C_z_2, B_z_3, B_z_4, B_z_5), dim=1)
                BoC2_2B = self.decode_z(BoC2_combine_2B)

                # C and D share font color (chunk 3).
                D3Co_combine_2C = torch.cat(
                    (C_z_1, C_z_2, D_z_3, C_z_4, C_z_5), dim=1)
                D3Co_2C = self.decode_z(D3Co_combine_2C)
                DoC3_combine_2D = torch.cat(
                    (D_z_1, D_z_2, C_z_3, D_z_4, D_z_5), dim=1)
                DoC3_2D = self.decode_z(DoC3_combine_2D)

                # C and E share background color (chunk 4).
                E4Co_combine_2C = torch.cat(
                    (C_z_1, C_z_2, C_z_3, E_z_4, C_z_5), dim=1)
                E4Co_2C = self.decode_z(E4Co_combine_2C)
                EoC4_combine_2E = torch.cat(
                    (E_z_1, E_z_2, E_z_3, C_z_4, E_z_5), dim=1)
                EoC4_2E = self.decode_z(EoC4_combine_2E)

                # C and F share style (chunk 5).
                F5Co_combine_2C = torch.cat(
                    (C_z_1, C_z_2, C_z_3, C_z_4, F_z_5), dim=1)
                F5Co_2C = self.decode_z(F5Co_combine_2C)
                FoC5_combine_2F = torch.cat(
                    (F_z_1, F_z_2, F_z_3, F_z_4, C_z_5), dim=1)
                FoC5_2F = self.decode_z(FoC5_combine_2F)

                # Every chunk that C shares, taken from the respective
                # partner image: should still decode to C.
                A1B2D3E4F5_combine_2C = torch.cat(
                    (A_z_1, B_z_2, D_z_3, E_z_4, F_z_5), dim=1)
                A1B2D3E4F5_2C = self.decode_z(A1B2D3E4F5_combine_2C)

                # 3. Unsupervised mixture of the non-shared chunks; there is
                # no ground-truth image for this decoding.
                A2B3D4E5F1_combine_2N = torch.cat(
                    (F_z_1, A_z_2, B_z_3, D_z_4, E_z_5), dim=1)
                A2B3D4E5F1_2N = self.decode_z(A2B3D4E5F1_combine_2N)
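                # Objective (every term is a mean absolute error):
                #   loss = recon_loss
                #        + lambda_combine * combine_sup_loss
                #        + lambda_unsup   * combine_unsup_loss
                # combine_sup compares each swapped decoding against the image
                # it should reproduce; combine_unsup re-encodes the
                # unsupervised mixture and matches each latent chunk against
                # the chunk it was assembled from.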
                # Optimize the autoencoder.
                # 1. recon_loss
                A_recon_loss = torch.mean(torch.abs(A_img - A_recon))
                B_recon_loss = torch.mean(torch.abs(B_img - B_recon))
                C_recon_loss = torch.mean(torch.abs(C_img - C_recon))
                D_recon_loss = torch.mean(torch.abs(D_img - D_recon))
                E_recon_loss = torch.mean(torch.abs(E_img - E_recon))
                F_recon_loss = torch.mean(torch.abs(F_img - F_recon))
                recon_loss = (A_recon_loss + B_recon_loss + C_recon_loss +
                              D_recon_loss + E_recon_loss + F_recon_loss)

                # 2. sup_combine_loss
                A1Co_2C_loss = torch.mean(torch.abs(C_img - A1Co_2C))
                AoC1_2A_loss = torch.mean(torch.abs(A_img - AoC1_2A))
                B2Co_2C_loss = torch.mean(torch.abs(C_img - B2Co_2C))
                BoC2_2B_loss = torch.mean(torch.abs(B_img - BoC2_2B))
                D3Co_2C_loss = torch.mean(torch.abs(C_img - D3Co_2C))
                DoC3_2D_loss = torch.mean(torch.abs(D_img - DoC3_2D))
                E4Co_2C_loss = torch.mean(torch.abs(C_img - E4Co_2C))
                EoC4_2E_loss = torch.mean(torch.abs(E_img - EoC4_2E))
                F5Co_2C_loss = torch.mean(torch.abs(C_img - F5Co_2C))
                FoC5_2F_loss = torch.mean(torch.abs(F_img - FoC5_2F))
                A1B2D3E4F5_2C_loss = torch.mean(torch.abs(C_img - A1B2D3E4F5_2C))
                combine_sup_loss = (
                    A1Co_2C_loss + AoC1_2A_loss + B2Co_2C_loss + BoC2_2B_loss +
                    D3Co_2C_loss + DoC3_2D_loss + E4Co_2C_loss + EoC4_2E_loss +
                    F5Co_2C_loss + FoC5_2F_loss + A1B2D3E4F5_2C_loss)

                # 3. unsup_combine_loss
                _, A2B3D4E5F1_z = self.Autoencoder(A2B3D4E5F1_2N)
                z_1, z_2, z_3, z_4, z_5 = self.split_z(A2B3D4E5F1_z)
                combine_unsup_loss = (torch.mean(torch.abs(F_z_1 - z_1)) +
                                      torch.mean(torch.abs(A_z_2 - z_2)) +
                                      torch.mean(torch.abs(B_z_3 - z_3)) +
                                      torch.mean(torch.abs(D_z_4 - z_4)) +
                                      torch.mean(torch.abs(E_z_5 - z_5)))

                # whole loss
                vae_unsup_loss = (recon_loss +
                                  self.lambda_combine * combine_sup_loss +
                                  self.lambda_unsup * combine_unsup_loss)
                self.auto_optim.zero_grad()
                vae_unsup_loss.backward()
                self.auto_optim.step()

                # Append to the text log.
                with open(self.log_dir + '/log.txt', 'a') as f:
                    f.write('\n[{}] recon_loss:{:.3f} combine_sup_loss:{:.3f} '
                            'combine_unsup_loss:{:.3f}'.format(
                                self.global_iter, recon_loss.data,
                                combine_sup_loss.data,
                                combine_unsup_loss.data))

                if self.viz_on and self.global_iter % self.gather_step == 0:
                    self.gather.insert(
                        iter=self.global_iter,
                        recon_loss=recon_loss.data,
                        combine_sup_loss=combine_sup_loss.data,
                        combine_unsup_loss=combine_unsup_loss.data)

                if self.global_iter % self.display_step == 0:
                    pbar.write(
                        '[{}] recon_loss:{:.3f} combine_sup_loss:{:.3f} '
                        'combine_unsup_loss:{:.3f}'.format(
                            self.global_iter, recon_loss.data,
                            combine_sup_loss.data, combine_unsup_loss.data))

                    if self.viz_on:
                        self.gather.insert(images=A_img.data)
                        self.gather.insert(images=B_img.data)
                        self.gather.insert(images=C_img.data)
                        self.gather.insert(images=D_img.data)
                        self.gather.insert(images=E_img.data)
                        self.gather.insert(images=F_img.data)
                        self.gather.insert(images=F.sigmoid(A_recon).data)
                        self.viz_reconstruction()
                        self.viz_lines()

                        # Combined (swapped) samples.
                        self.gather.insert(
                            combine_supimages=F.sigmoid(AoC1_2A).data)
                        self.gather.insert(
                            combine_supimages=F.sigmoid(BoC2_2B).data)
                        self.gather.insert(
                            combine_supimages=F.sigmoid(D3Co_2C).data)
                        self.gather.insert(
                            combine_supimages=F.sigmoid(DoC3_2D).data)
                        self.gather.insert(
                            combine_supimages=F.sigmoid(EoC4_2E).data)
                        self.gather.insert(
                            combine_supimages=F.sigmoid(FoC5_2F).data)
                        self.viz_combine_recon()

                        self.gather.insert(
                            combine_unsupimages=F.sigmoid(A1B2D3E4F5_2C).data)
                        self.gather.insert(
                            combine_unsupimages=F.sigmoid(A2B3D4E5F1_2N).data)
                        self.viz_combine_unsuprecon()
                        # self.viz_combine(x)
                        self.gather.flush()

                # Save model checkpoints.
                if self.global_iter % self.save_step == 0:
                    Auto_path = os.path.join(
                        self.model_save_dir, self.viz_name,
                        '{}-Auto.ckpt'.format(self.global_iter))
                    torch.save(self.Autoencoder.state_dict(), Auto_path)
                    print('Saved model checkpoints into {}/{}...'.format(
                        self.model_save_dir, self.viz_name))

                if self.global_iter >= self.max_iter:
                    out = True
                    break

        pbar.write("[Training Finished]")
        pbar.close()

    def save_sample_img(self, tensor, mode):
        """Save each grid in `tensor` as a PNG named after what it shows."""
        unloader = transforms.ToPILImage()
        sample_dir = os.path.join(self.model_save_dir, self.viz_name,
                                  'sample_img')
        os.makedirs(sample_dir, exist_ok=True)
        # Clone so saving never mutates the gathered tensor.
        image = tensor.cpu().clone()
        if mode == 'recon':
            names = ['A_img', 'B_img', 'C_img', 'D_img', 'E_img', 'F_img',
                     'A_img_recon']
        elif mode == 'combine_sup':
            names = ['AoC1_2A', 'BoC2_2B', 'D3Co_2C', 'DoC3_2D', 'EoC4_2E',
                     'FoC5_2F']
        elif mode == 'combine_unsup':
            names = ['A1B2D3E4F5_2C', 'A2B3D4E5F1_2N']
        else:
            raise ValueError('unknown mode: {}'.format(mode))
        for grid, name in zip(image, names):
            unloader(grid.squeeze(0)).save(
                os.path.join(sample_dir,
                             '{}-{}.png'.format(self.global_iter, name)))
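    # Visualization helpers: each method tiles the first 100 gathered images
    # into a grid (nrow=10), pushes the stacked grids to a dedicated visdom
    # environment, and writes them to disk via save_sample_img above.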
    def viz_reconstruction(self):
        x_A = self.gather.data['images'][0][:100]
        x_A = make_grid(x_A, normalize=True)
        x_B = self.gather.data['images'][1][:100]
        x_B = make_grid(x_B, normalize=True)
        x_C = self.gather.data['images'][2][:100]
        x_C = make_grid(x_C, normalize=True)
        x_D = self.gather.data['images'][3][:100]
        x_D = make_grid(x_D, normalize=True)
        x_E = self.gather.data['images'][4][:100]
        x_E = make_grid(x_E, normalize=True)
        x_F = self.gather.data['images'][5][:100]
        x_F = make_grid(x_F, normalize=True)
        x_A_recon = self.gather.data['images'][6][:100]
        x_A_recon = make_grid(x_A_recon, normalize=True)
        images = torch.stack([x_A, x_B, x_C, x_D, x_E, x_F, x_A_recon],
                             dim=0).cpu()
        self.viz.images(images,
                        env=self.viz_name + '_reconstruction',
                        opts=dict(title=str(self.global_iter)),
                        nrow=10)
        self.save_sample_img(images, 'recon')

    def viz_combine_recon(self):
        AoC1_2A = self.gather.data['combine_supimages'][0][:100]
        AoC1_2A = make_grid(AoC1_2A, normalize=True)
        BoC2_2B = self.gather.data['combine_supimages'][1][:100]
        BoC2_2B = make_grid(BoC2_2B, normalize=True)
        D3Co_2C = self.gather.data['combine_supimages'][2][:100]
        D3Co_2C = make_grid(D3Co_2C, normalize=True)
        DoC3_2D = self.gather.data['combine_supimages'][3][:100]
        DoC3_2D = make_grid(DoC3_2D, normalize=True)
        EoC4_2E = self.gather.data['combine_supimages'][4][:100]
        EoC4_2E = make_grid(EoC4_2E, normalize=True)
        FoC5_2F = self.gather.data['combine_supimages'][5][:100]
        FoC5_2F = make_grid(FoC5_2F, normalize=True)
        images = torch.stack(
            [AoC1_2A, BoC2_2B, D3Co_2C, DoC3_2D, EoC4_2E, FoC5_2F],
            dim=0).cpu()
        self.viz.images(images,
                        env=self.viz_name + '_combine_supimages',
                        opts=dict(title=str(self.global_iter)),
                        nrow=10)
        self.save_sample_img(images, 'combine_sup')

    def viz_combine_unsuprecon(self):
        A1B2D3E4F5_2C = self.gather.data['combine_unsupimages'][0][:100]
        A1B2D3E4F5_2C = make_grid(A1B2D3E4F5_2C, normalize=True)
        A2B3D4E5F1_2N = self.gather.data['combine_unsupimages'][1][:100]
        A2B3D4E5F1_2N = make_grid(A2B3D4E5F1_2N, normalize=True)
        images = torch.stack([A1B2D3E4F5_2C, A2B3D4E5F1_2N], dim=0).cpu()
        self.viz.images(images,
                        env=self.viz_name + '_combine_unsupimages',
                        opts=dict(title=str(self.global_iter)),
                        nrow=10)
        self.save_sample_img(images, 'combine_unsup')

    def viz_combine(self, x):
        # NOTE: assumes the older convolutional autoencoder whose latent is a
        # feature map split into appearance (first 250 channels) and pose
        # (remaining channels); not called from train().
        decoder = self.Autoencoder.decoder
        encoder = self.Autoencoder.encoder
        z = encoder(x)
        z_appe = z[:, 0:250, :, :]
        z_pose = z[:, 250:, :, :]
        # Pair each image's appearance with the next image's pose.
        z_rearrange_combine = torch.cat((z_appe[:-1], z_pose[1:]), dim=1)
        x_rearrange_combine = decoder(z_rearrange_combine)
        x_rearrange_combine = F.sigmoid(x_rearrange_combine).data
        x_show = make_grid(x[:-1].data, normalize=True)
        x_rearrange_combine_show = make_grid(x_rearrange_combine,
                                             normalize=True)
        images = torch.stack([x_show, x_rearrange_combine_show], dim=0).cpu()
        self.viz.images(images,
                        env=self.viz_name + '_combine',
                        opts=dict(title=str(self.global_iter)),
                        nrow=10)
    def viz_lines(self):
        recon_losses = torch.stack(self.gather.data['recon_loss']).cpu()
        combine_sup_loss = torch.stack(
            self.gather.data['combine_sup_loss']).cpu()
        combine_unsup_loss = torch.stack(
            self.gather.data['combine_unsup_loss']).cpu()
        iters = torch.Tensor(self.gather.data['iter'])

        legend = []
        for z_j in range(self.z_dim):
            legend.append('z_{}'.format(z_j))
        legend.append('mean')
        legend.append('total')

        if self.win_recon is None:
            self.win_recon = self.viz.line(X=iters,
                                           Y=recon_losses,
                                           env=self.viz_name + '_lines',
                                           opts=dict(
                                               width=400,
                                               height=400,
                                               xlabel='iteration',
                                               title='reconstruction loss',
                                           ))
        else:
            self.win_recon = self.viz.line(X=iters,
                                           Y=recon_losses,
                                           env=self.viz_name + '_lines',
                                           win=self.win_recon,
                                           update='append',
                                           opts=dict(
                                               width=400,
                                               height=400,
                                               xlabel='iteration',
                                               title='reconstruction loss',
                                           ))

        if self.win_combine_sup is None:
            self.win_combine_sup = self.viz.line(
                X=iters,
                Y=combine_sup_loss,
                env=self.viz_name + '_lines',
                opts=dict(
                    width=400,
                    height=400,
                    legend=legend[:self.z_dim],
                    xlabel='iteration',
                    title='combine_sup_loss',
                ))
        else:
            self.win_combine_sup = self.viz.line(
                X=iters,
                Y=combine_sup_loss,
                env=self.viz_name + '_lines',
                win=self.win_combine_sup,
                update='append',
                opts=dict(
                    width=400,
                    height=400,
                    legend=legend[:self.z_dim],
                    xlabel='iteration',
                    title='combine_sup_loss',
                ))

        if self.win_combine_unsup is None:
            self.win_combine_unsup = self.viz.line(
                X=iters,
                Y=combine_unsup_loss,
                env=self.viz_name + '_lines',
                opts=dict(
                    width=400,
                    height=400,
                    legend=legend[:self.z_dim],
                    xlabel='iteration',
                    title='combine_unsup_loss',
                ))
        else:
            self.win_combine_unsup = self.viz.line(
                X=iters,
                Y=combine_unsup_loss,
                env=self.viz_name + '_lines',
                win=self.win_combine_unsup,
                update='append',
                opts=dict(
                    width=400,
                    height=400,
                    legend=legend[:self.z_dim],
                    xlabel='iteration',
                    title='combine_unsup_loss',
                ))
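    # Latent traversal (Beta-VAE style): sweep one latent dimension at a time
    # across [-limit, limit] and render the decodings. NOTE: this assumes the
    # encoder maps an image directly to an (N, z_dim) code; with Generator_fc,
    # whose decoder consumes an (N, 256, 8, 8) feature map (see decode_z), an
    # extra flatten/fc stage would be needed for this to run as written.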
    def viz_traverse(self, limit=3, inter=2 / 3, loc=-1):
        self.net_mode(train=False)
        decoder = self.Autoencoder.decoder
        encoder = self.Autoencoder.encoder
        interpolation = torch.arange(-limit, limit + 0.1, inter)

        n_dsets = len(self.data_loader.dataset)
        rand_idx = random.randint(1, n_dsets - 1)
        random_img = self.data_loader.dataset.__getitem__(rand_idx)
        random_img = Variable(cuda(random_img, self.use_cuda),
                              volatile=True).unsqueeze(0)
        random_img_z = encoder(random_img)[:, :self.z_dim]

        random_z = Variable(cuda(torch.rand(1, self.z_dim), self.use_cuda),
                            volatile=True)

        if self.dataset == 'dsprites':
            fixed_idx1 = 87040   # square
            fixed_idx2 = 332800  # ellipse
            fixed_idx3 = 578560  # heart

            fixed_img1 = self.data_loader.dataset.__getitem__(fixed_idx1)
            fixed_img1 = Variable(cuda(fixed_img1, self.use_cuda),
                                  volatile=True).unsqueeze(0)
            fixed_img_z1 = encoder(fixed_img1)[:, :self.z_dim]

            fixed_img2 = self.data_loader.dataset.__getitem__(fixed_idx2)
            fixed_img2 = Variable(cuda(fixed_img2, self.use_cuda),
                                  volatile=True).unsqueeze(0)
            fixed_img_z2 = encoder(fixed_img2)[:, :self.z_dim]

            fixed_img3 = self.data_loader.dataset.__getitem__(fixed_idx3)
            fixed_img3 = Variable(cuda(fixed_img3, self.use_cuda),
                                  volatile=True).unsqueeze(0)
            fixed_img_z3 = encoder(fixed_img3)[:, :self.z_dim]

            Z = {
                'fixed_square': fixed_img_z1,
                'fixed_ellipse': fixed_img_z2,
                'fixed_heart': fixed_img_z3,
                'random_img': random_img_z
            }
        else:
            fixed_idx = 0
            fixed_img = self.data_loader.dataset.__getitem__(fixed_idx)
            fixed_img = Variable(cuda(fixed_img, self.use_cuda),
                                 volatile=True).unsqueeze(0)
            fixed_img_z = encoder(fixed_img)[:, :self.z_dim]

            Z = {
                'fixed_img': fixed_img_z,
                'random_img': random_img_z,
                'random_z': random_z
            }

        gifs = []
        for key in Z.keys():
            z_ori = Z[key]
            samples = []
            for row in range(self.z_dim):
                if loc != -1 and row != loc:
                    continue
                z = z_ori.clone()
                for val in interpolation:
                    z[:, row] = val
                    sample = F.sigmoid(decoder(z)).data
                    samples.append(sample)
                    gifs.append(sample)
            samples = torch.cat(samples, dim=0).cpu()
            title = '{}_latent_traversal(iter:{})'.format(
                key, self.global_iter)

            if self.viz_on:
                self.viz.images(samples,
                                env=self.viz_name + '_traverse',
                                opts=dict(title=title),
                                nrow=len(interpolation))

        if self.save_output:
            output_dir = os.path.join(self.output_dir, str(self.global_iter))
            os.makedirs(output_dir, exist_ok=True)
            gifs = torch.cat(gifs)
            gifs = gifs.view(len(Z), self.z_dim, len(interpolation), self.nc,
                             64, 64).transpose(1, 2)
            for i, key in enumerate(Z.keys()):
                for j, val in enumerate(interpolation):
                    save_image(tensor=gifs[i][j].cpu(),
                               filename=os.path.join(
                                   output_dir, '{}_{}.jpg'.format(key, j)),
                               nrow=self.z_dim,
                               pad_value=1)
                grid2gif(os.path.join(output_dir, key + '*.jpg'),
                         os.path.join(output_dir, key + '.gif'),
                         delay=10)

        self.net_mode(train=True)

    def net_mode(self, train):
        if not isinstance(train, bool):
            raise TypeError('Only bool type is supported: True or False.')
        if train:
            self.Autoencoder.train()
        else:
            self.Autoencoder.eval()

    def save_checkpoint(self, filename, silent=True):
        model_states = {
            'net': self.Autoencoder.state_dict(),
        }
        optim_states = {
            'optim': self.auto_optim.state_dict(),
        }
        win_states = {
            'recon': self.win_recon,
            'combine_sup': self.win_combine_sup,
            'combine_unsup': self.win_combine_unsup,
        }
        states = {
            'iter': self.global_iter,
            'win_states': win_states,
            'model_states': model_states,
            'optim_states': optim_states
        }
        file_path = os.path.join(self.ckpt_dir, filename)
        with open(file_path, mode='wb+') as f:
            torch.save(states, f)
        if not silent:
            print("=> saved checkpoint '{}' (iter {})".format(
                file_path, self.global_iter))

    def load_checkpoint(self, filename):
        file_path = os.path.join(self.ckpt_dir, filename)
        if os.path.isfile(file_path):
            checkpoint = torch.load(file_path)
            self.global_iter = checkpoint['iter']
            self.win_recon = checkpoint['win_states']['recon']
            self.win_combine_sup = checkpoint['win_states']['combine_sup']
            self.win_combine_unsup = checkpoint['win_states']['combine_unsup']
            self.Autoencoder.load_state_dict(checkpoint['model_states']['net'])
            self.auto_optim.load_state_dict(checkpoint['optim_states']['optim'])
            print("=> loaded checkpoint '{} (iter {})'".format(
                file_path, self.global_iter))
        else:
            print("=> no checkpoint found at '{}'".format(file_path))
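
# A minimal usage sketch with hypothetical values; the repo's real entry
# point defines the authoritative flags and defaults. It only illustrates
# which fields the Solver expects on its `args` namespace.
if __name__ == '__main__':
    from argparse import Namespace

    args = Namespace(
        cuda=True, max_iter=300000, z_dim=100, beta=4, objective='H',
        model='autoencoder', lr=1e-4, beta1=0.9, beta2=0.999,
        c_dim=5, image_size=128, g_conv_dim=64, g_repeat_num=6,
        d_conv_dim=64, d_repeat_num=6, norm='instance',
        z_content=20, z_size=20, z_font_color=20, z_back_color=20, z_style=20,
        lambda_combine=1.0, lambda_unsup=1.0,
        dataset='fonts_unsup_nswap', dset_dir='./data', batch_size=16,
        viz_name='fonts', viz_port=8097, viz_on=True,
        model_save_dir='./checkpoints', ckpt_dir='./checkpoints',
        ckpt_name=None, resume_iters=0,
        save_output=True, output_dir='./outputs',
        gather_step=1000, display_step=1000, save_step=10000,
    )
    solver = Solver(args)
    solver.train()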