class Trainer(object):
    """Trainer for a conditional GAN supporting 'hinge', 'wgan_gp' and 'dcgan' losses.

    Builds G/D, the dataloader and output directories from ``config``, then runs
    the alternating D/G optimization loop with linearly-annealed instance noise.
    """

    def __init__(self, config):
        # Images data path & Output path
        self.dataset = config.dataset
        self.data_path = config.data_path
        self.save_path = os.path.join(config.save_path, config.name)

        # Training settings
        self.batch_size = config.batch_size
        self.total_step = config.total_step
        self.d_steps_per_iter = config.d_steps_per_iter
        self.g_steps_per_iter = config.g_steps_per_iter
        self.d_lr = config.d_lr
        self.g_lr = config.g_lr
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.inst_noise_sigma = config.inst_noise_sigma
        self.inst_noise_sigma_iters = config.inst_noise_sigma_iters
        self.start = 0  # Unless using pre-trained model

        # Image transforms
        self.shuffle = config.shuffle
        self.drop_last = config.drop_last
        self.resize = config.resize
        self.imsize = config.imsize
        self.centercrop = config.centercrop
        self.centercrop_size = config.centercrop_size
        self.tanh_scale = config.tanh_scale
        self.normalize = config.normalize

        # Step size
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step
        self.save_n_images = config.save_n_images
        self.max_frames_per_gif = config.max_frames_per_gif

        # Pretrained model
        self.pretrained_model = config.pretrained_model

        # Misc
        self.manual_seed = config.manual_seed
        self.disable_cuda = config.disable_cuda
        self.parallel = config.parallel
        self.dataloader_args = config.dataloader_args

        # Output paths
        self.model_weights_path = os.path.join(self.save_path,
                                               config.model_weights_dir)
        self.sample_path = os.path.join(self.save_path, config.sample_dir)

        # Model hyper-parameters
        self.adv_loss = config.adv_loss
        self.z_dim = config.z_dim
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.lambda_gp = config.lambda_gp

        # Model name
        self.name = config.name

        # Create directories if not exist
        utils.make_folder(self.save_path)
        utils.make_folder(self.model_weights_path)
        utils.make_folder(self.sample_path)

        # Copy files
        utils.write_config_to_file(config, self.save_path)
        utils.copy_scripts(self.save_path)

        # Check for CUDA (sets self.device)
        utils.check_for_CUDA(self)

        # Make dataloader
        self.dataloader, self.num_of_classes = utils.make_dataloader(
            self.batch_size, self.dataset, self.data_path, self.shuffle,
            self.drop_last, self.dataloader_args, self.resize, self.imsize,
            self.centercrop, self.centercrop_size)

        # Data iterator
        self.data_iter = iter(self.dataloader)

        # Build G and D
        self.build_models()

        # Start with pretrained model (if it exists)
        if self.pretrained_model != '':
            utils.load_pretrained_model(self)

        if self.adv_loss == 'dcgan':
            self.criterion = nn.BCELoss()

    def train(self):
        """Run the full alternating D/G training loop with logging and sampling."""
        # Seed
        np.random.seed(self.manual_seed)
        random.seed(self.manual_seed)
        torch.manual_seed(self.manual_seed)

        # For fast training
        cudnn.benchmark = True

        # For BatchNorm
        self.G.train()
        self.D.train()

        # Fixed noise for sampling from G; labels cycle over the classes when
        # there are fewer classes than batch entries.
        fixed_noise = torch.randn(self.batch_size, self.z_dim,
                                  device=self.device)
        if self.num_of_classes < self.batch_size:
            fixed_labels = torch.from_numpy(
                np.tile(np.arange(self.num_of_classes),
                        self.batch_size // self.num_of_classes +
                        1)[:self.batch_size]).to(self.device)
        else:
            fixed_labels = torch.from_numpy(np.arange(self.batch_size)).to(
                self.device)

        # For gan loss.
        # BUGFIX: fill values must be floats — with an int fill, modern
        # torch.full infers an integer dtype, which breaks BCELoss and
        # the hinge-loss arithmetic below.
        label = torch.full((self.batch_size, ), 1.0, device=self.device)
        ones = torch.full((self.batch_size, ), 1.0, device=self.device)

        # Losses file
        log_file_name = os.path.join(self.save_path, 'log.txt')
        log_file = open(log_file_name, "wt")

        # Init
        start_time = time.time()
        G_losses = []
        D_losses_real = []
        D_losses_fake = []
        D_losses = []
        D_xs = []
        D_Gz_trainDs = []
        D_Gz_trainGs = []

        # Instance noise - make random noise mean (0) and std for injecting
        # (float fills for the same dtype reason as `label`/`ones` above).
        inst_noise_mean = torch.full(
            (self.batch_size, 3, self.imsize, self.imsize), 0.0,
            device=self.device)
        inst_noise_std = torch.full(
            (self.batch_size, 3, self.imsize, self.imsize),
            float(self.inst_noise_sigma), device=self.device)

        # Start training
        for self.step in range(self.start, self.total_step):

            # Instance noise std is linearly annealed from
            # self.inst_noise_sigma to 0 thru self.inst_noise_sigma_iters
            inst_noise_sigma_curr = 0 if self.step > self.inst_noise_sigma_iters else (
                1 - self.step / self.inst_noise_sigma_iters) * self.inst_noise_sigma
            inst_noise_std.fill_(inst_noise_sigma_curr)

            # ================== TRAIN D ================== #
            for _ in range(self.d_steps_per_iter):
                # Zero grad
                self.reset_grad()

                # TRAIN with REAL

                # Get real images & real labels
                real_images, real_labels = self.get_real_samples()

                # Get D output for real images & real labels
                inst_noise = torch.normal(mean=inst_noise_mean,
                                          std=inst_noise_std).to(self.device)
                d_out_real = self.D(real_images + inst_noise, real_labels)

                # Compute D loss with real images & real labels
                if self.adv_loss == 'hinge':
                    d_loss_real = torch.nn.ReLU()(ones - d_out_real).mean()
                elif self.adv_loss == 'wgan_gp':
                    d_loss_real = -d_out_real.mean()
                else:
                    label.fill_(1)
                    d_loss_real = self.criterion(d_out_real, label)

                # Backward
                d_loss_real.backward()

                # TRAIN with FAKE

                # Create random noise
                z = torch.randn(self.batch_size, self.z_dim,
                                device=self.device)

                # Generate fake images for same real labels
                fake_images = self.G(z, real_labels)

                # Get D output for fake images & same real labels
                inst_noise = torch.normal(mean=inst_noise_mean,
                                          std=inst_noise_std).to(self.device)
                d_out_fake = self.D(fake_images.detach() + inst_noise,
                                    real_labels)

                # Compute D loss with fake images & real labels
                if self.adv_loss == 'hinge':
                    d_loss_fake = torch.nn.ReLU()(ones + d_out_fake).mean()
                elif self.adv_loss == 'dcgan':
                    label.fill_(0)
                    d_loss_fake = self.criterion(d_out_fake, label)
                else:
                    d_loss_fake = d_out_fake.mean()

                # Backward
                d_loss_fake.backward()

                # If WGAN_GP, compute GP and add to D loss
                if self.adv_loss == 'wgan_gp':
                    d_loss_gp = self.lambda_gp * self.compute_gradient_penalty(
                        real_images, real_labels, fake_images.detach())
                    d_loss_gp.backward()

                # Optimize
                self.D_optimizer.step()

            # ================== TRAIN G ================== #
            for _ in range(self.g_steps_per_iter):
                # Zero grad
                self.reset_grad()

                # Get real images & real labels (only need real labels)
                real_images, real_labels = self.get_real_samples()

                # Create random noise
                z = torch.randn(self.batch_size, self.z_dim).to(self.device)

                # Generate fake images for same real labels
                fake_images = self.G(z, real_labels)

                # Get D output for fake images & same real labels
                inst_noise = torch.normal(mean=inst_noise_mean,
                                          std=inst_noise_std).to(self.device)
                g_out_fake = self.D(fake_images + inst_noise, real_labels)

                # Compute G loss with fake images & real labels
                if self.adv_loss == 'dcgan':
                    label.fill_(1)
                    g_loss = self.criterion(g_out_fake, label)
                else:
                    g_loss = -g_out_fake.mean()

                # Backward + Optimize
                g_loss.backward()
                self.G_optimizer.step()

            # Print out log info
            if self.step % self.log_step == 0:
                G_losses.append(g_loss.mean().item())
                D_losses_real.append(d_loss_real.mean().item())
                D_losses_fake.append(d_loss_fake.mean().item())
                D_loss = D_losses_real[-1] + D_losses_fake[-1]
                if self.adv_loss == 'wgan_gp':
                    D_loss += d_loss_gp.mean().item()
                D_losses.append(D_loss)
                D_xs.append(d_out_real.mean().item())
                D_Gz_trainDs.append(d_out_fake.mean().item())
                D_Gz_trainGs.append(g_out_fake.mean().item())
                curr_time = time.time()
                curr_time_str = datetime.datetime.fromtimestamp(
                    curr_time).strftime('%Y-%m-%d %H:%M:%S')
                elapsed = str(
                    datetime.timedelta(seconds=(curr_time - start_time)))
                log = (
                    "[{}] : Elapsed [{}], Iter [{} / {}], G_loss: {:.4f}, D_loss: {:.4f}, D_loss_real: {:.4f}, D_loss_fake: {:.4f}, D(x): {:.4f}, D(G(z))_trainD: {:.4f}, D(G(z))_trainG: {:.4f}\n"
                    .format(curr_time_str, elapsed, self.step, self.total_step,
                            G_losses[-1], D_losses[-1], D_losses_real[-1],
                            D_losses_fake[-1], D_xs[-1], D_Gz_trainDs[-1],
                            D_Gz_trainGs[-1]))
                print(log)
                log_file.write(log)
                log_file.flush()
                utils.make_plots(G_losses, D_losses, D_losses_real,
                                 D_losses_fake, D_xs, D_Gz_trainDs,
                                 D_Gz_trainGs, self.log_step, self.save_path)

            # Sample images
            if self.step % self.sample_step == 0:
                self.G.eval()
                fake_images = self.G(fixed_noise, fixed_labels)
                self.G.train()
                sample_images = utils.denorm(
                    fake_images.detach()[:self.save_n_images])
                # Save batch images
                vutils.save_image(
                    sample_images,
                    os.path.join(self.sample_path,
                                 'fake_{:05d}.png'.format(self.step)))
                # Save gif
                utils.make_gif(
                    sample_images[0].cpu().numpy().transpose(1, 2, 0) * 255,
                    self.step,
                    self.sample_path,
                    self.name,
                    max_frames_per_gif=self.max_frames_per_gif)

            # Save model
            if self.step % self.model_save_step == 0:
                utils.save_ckpt(self)

    def build_models(self):
        """Instantiate G and D (optionally DataParallel) and their Adam optimizers."""
        self.G = Generator(self.z_dim, self.g_conv_dim,
                           self.num_of_classes).to(self.device)
        self.D = Discriminator(self.d_conv_dim,
                               self.num_of_classes).to(self.device)
        if 'cuda' in self.device.type and self.parallel and torch.cuda.device_count(
        ) > 1:
            self.G = nn.DataParallel(self.G)
            self.D = nn.DataParallel(self.D)

        # Loss and optimizer (only trainable parameters are given to Adam)
        self.G_optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.G.parameters()), self.g_lr,
            [self.beta1, self.beta2])
        self.D_optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.D.parameters()), self.d_lr,
            [self.beta1, self.beta2])

        # print networks
        print(self.G)
        print(self.D)

    def reset_grad(self):
        """Zero both optimizers' gradients before a new backward pass."""
        self.G_optimizer.zero_grad()
        self.D_optimizer.zero_grad()

    def get_real_samples(self):
        """Return the next (images, labels) batch, restarting the loader at epoch end."""
        try:
            real_images, real_labels = next(self.data_iter)
        except StopIteration:
            # Loader exhausted — start a new epoch.
            # BUGFIX: was a bare `except:` that also swallowed unrelated errors.
            self.data_iter = iter(self.dataloader)
            real_images, real_labels = next(self.data_iter)

        real_images, real_labels = real_images.to(self.device), real_labels.to(
            self.device)
        return real_images, real_labels

    def compute_gradient_penalty(self, real_images, real_labels, fake_images):
        """WGAN-GP penalty: E[(||grad_x D(x_hat)|| - 1)^2] on random interpolates.

        ``fake_images`` is expected to be detached by the caller.
        """
        # Per-sample interpolation coefficient, broadcast over C/H/W.
        # BUGFIX: was `.to(device)` with `device` undefined (NameError) —
        # must be `self.device`.
        alpha = torch.rand(real_images.size(0), 1, 1,
                           1).expand_as(real_images).to(self.device)
        # BUGFIX: was `torch.tensor(tensor_expr, requires_grad=True)`, which is
        # deprecated for tensor inputs; build the interpolate and flag it directly.
        interpolated = (alpha * real_images +
                        (1 - alpha) * fake_images).detach().requires_grad_(True)
        out = self.D(interpolated, real_labels)
        exp_grad = torch.ones(out.size()).to(self.device)
        grad = torch.autograd.grad(outputs=out,
                                   inputs=interpolated,
                                   grad_outputs=exp_grad,
                                   retain_graph=True,
                                   create_graph=True,
                                   only_inputs=True)[0]
        grad = grad.view(grad.size(0), -1)
        grad_l2norm = torch.sqrt(torch.sum(grad**2, dim=1))
        d_loss_gp = torch.mean((grad_l2norm - 1)**2)
        return d_loss_gp
class Tester(object):
    """Loads a (pre)trained G/D pair and writes grids of generated samples to disk."""

    def __init__(self, data_loader, config):
        # Data loader
        self.data_loader = data_loader

        # exact model and loss
        self.model = config.model
        self.adv_loss = config.adv_loss

        # Model hyper-parameters
        self.imsize = config.imsize
        self.g_num = config.g_num
        self.z_dim = config.z_dim
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.parallel = config.parallel

        self.lambda_gp = config.lambda_gp
        self.total_step = config.total_step
        self.d_iters = config.d_iters
        self.batch_size = config.batch_size
        self.num_workers = config.num_workers
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.lr_decay = config.lr_decay
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.pretrained_model = config.pretrained_model

        self.dataset = config.dataset
        self.use_tensorboard = config.use_tensorboard
        self.image_path = config.image_path
        self.log_path = config.log_path
        self.model_save_path = config.model_save_path
        self.sample_path = config.sample_path
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step
        self.version = config.version

        # Path (version-specific subdirectories)
        self.log_path = os.path.join(config.log_path, self.version)
        self.sample_path = os.path.join(config.sample_path, self.version)
        self.model_save_path = os.path.join(config.model_save_path,
                                            self.version)
        self.test_store_path = os.path.join(config.test_store_path,
                                            self.version)

        self.build_model()

        if self.use_tensorboard:
            self.build_tensorboard()

        # Start with trained model
        if self.pretrained_model:
            self.load_pretrained_model()

    def test(self):
        """Generate 500 batches from random z and save each batch as an image grid."""
        # BUGFIX: loop variable was named `iter`, shadowing the builtin.
        for idx in range(500):
            fixed_z = tensor2var(torch.randn(self.batch_size, self.z_dim))
            fake_images, _, _ = self.G(fixed_z)
            # NCHW -> NHWC for plotting
            fakeimage = np.transpose(var2numpy(fake_images.data), (0, 2, 3, 1))
            # BUGFIX: output_fig appends '.png' itself; passing a name that
            # already ended in '.png' produced files like '001_image.png.png'.
            self.output_fig(
                fakeimage,
                os.path.join(self.test_store_path,
                             '{:03d}_image'.format(idx + 1)))

    def output_fig(self, images_array, file_name):
        """Save *images_array* as a square grid PNG at ``file_name + '.png'``."""
        plt.figure(figsize=(6, 6), dpi=100)
        plt.imshow(helper.images_square_grid(images_array))
        plt.axis("off")
        plt.savefig(file_name + '.png', bbox_inches='tight', pad_inches=0)

    def build_model(self):
        """Instantiate G and D on GPU (optionally DataParallel) with Adam optimizers."""
        self.G = Generator(self.batch_size, self.imsize, self.z_dim,
                           self.g_conv_dim).cuda()
        self.D = Discriminator(self.batch_size, self.imsize,
                               self.d_conv_dim).cuda()
        if self.parallel:
            self.G = nn.DataParallel(self.G)
            self.D = nn.DataParallel(self.D)

        # Loss and optimizer (only trainable parameters are given to Adam)
        self.g_optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.G.parameters()), self.g_lr,
            [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.D.parameters()), self.d_lr,
            [self.beta1, self.beta2])

        self.c_loss = torch.nn.CrossEntropyLoss()

        # print networks
        print(self.G)
        print(self.D)

    def build_tensorboard(self):
        """Create the tensorboard-style logger under self.log_path."""
        from logger import Logger
        self.logger = Logger(self.log_path)

    def load_pretrained_model(self):
        """Load G and D state dicts saved at step ``self.pretrained_model``."""
        self.G.load_state_dict(
            torch.load(
                os.path.join(self.model_save_path,
                             '{}_G.pth'.format(self.pretrained_model))))
        self.D.load_state_dict(
            torch.load(
                os.path.join(self.model_save_path,
                             '{}_D.pth'.format(self.pretrained_model))))
        print('loaded trained models (step: {})..!'.format(
            self.pretrained_model))
class Trainer(object):
    """SAGAN-style trainer: self-attention G/D with 'wgan-gp' or 'hinge' losses."""

    def __init__(self, data_loader, config):
        # Data loader
        self.data_loader = data_loader

        # exact model and loss
        self.model = config.model
        self.adv_loss = config.adv_loss

        # Model hyper-parameters
        self.imsize = config.imsize
        self.imchan = config.imchan
        self.g_num = config.g_num
        self.z_dim = config.z_dim
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.parallel = config.parallel

        self.lambda_gp = config.lambda_gp
        self.total_step = config.total_step
        self.d_iters = config.d_iters
        self.batch_size = config.batch_size
        self.num_workers = config.num_workers
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.lr_decay = config.lr_decay
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.pretrained_model = config.pretrained_model

        self.dataset = config.dataset
        self.use_tensorboard = config.use_tensorboard
        self.image_path = config.image_path
        self.log_path = config.log_path
        self.model_save_path = config.model_save_path
        self.sample_path = config.sample_path
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step
        self.version = config.version

        # Path (version-specific subdirectories)
        self.log_path = os.path.join(config.log_path, self.version)
        self.sample_path = os.path.join(config.sample_path, self.version)
        self.model_save_path = os.path.join(config.model_save_path,
                                            self.version)

        self.build_model()

        if self.use_tensorboard:
            self.build_tensorboard()

        # Start with trained model
        if self.pretrained_model:
            self.load_pretrained_model()

    def train(self):
        """Alternate one D step and one G step per iteration, with logging/sampling."""
        # Data iterator
        data_iter = iter(self.data_loader)
        step_per_epoch = len(self.data_loader)
        # model_save_step is given in epochs; convert to iterations
        model_save_step = int(self.model_save_step * step_per_epoch)

        # Fixed input for debugging
        fixed_z = tensor2var(torch.randn(self.batch_size, self.z_dim))

        # Start with trained model
        if self.pretrained_model:
            start = self.pretrained_model + 1
        else:
            start = 0

        # Start time
        start_time = time.time()
        for step in range(start, self.total_step):

            # ================== Train D ================== #
            self.D.train()
            self.G.train()

            try:
                real_images, _ = next(data_iter)
            except StopIteration:
                # Loader exhausted — start a new epoch.
                # BUGFIX: was a bare `except:` that also swallowed unrelated errors.
                data_iter = iter(self.data_loader)
                real_images, _ = next(data_iter)

            # Compute loss with real images
            # dr1, dr2, df1, df2, gf1, gf2 are attention scores
            real_images = tensor2var(real_images)
            d_out_real, dr1, dr2 = self.D(real_images)
            if self.adv_loss == 'wgan-gp':
                d_loss_real = -torch.mean(d_out_real)
            elif self.adv_loss == 'hinge':
                d_loss_real = torch.nn.ReLU()(1.0 - d_out_real).mean()

            # apply Gumbel Softmax
            z = tensor2var(torch.randn(real_images.size(0), self.z_dim))
            fake_images, gf1, gf2 = self.G(z)
            d_out_fake, df1, df2 = self.D(fake_images)

            if self.adv_loss == 'wgan-gp':
                d_loss_fake = d_out_fake.mean()
            elif self.adv_loss == 'hinge':
                d_loss_fake = torch.nn.ReLU()(1.0 + d_out_fake).mean()

            # Backward + Optimize
            d_loss = d_loss_real + d_loss_fake
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()

            if self.adv_loss == 'wgan-gp':
                # Compute gradient penalty on random interpolates and apply it
                # as a second, separate D update.
                alpha = torch.rand(real_images.size(0), 1, 1,
                                   1).cuda().expand_as(real_images)
                interpolated = Variable(alpha * real_images.data +
                                        (1 - alpha) * fake_images.data,
                                        requires_grad=True)
                out, _, _ = self.D(interpolated)

                grad = torch.autograd.grad(outputs=out,
                                           inputs=interpolated,
                                           grad_outputs=torch.ones(
                                               out.size()).cuda(),
                                           retain_graph=True,
                                           create_graph=True,
                                           only_inputs=True)[0]

                grad = grad.view(grad.size(0), -1)
                grad_l2norm = torch.sqrt(torch.sum(grad**2, dim=1))
                d_loss_gp = torch.mean((grad_l2norm - 1)**2)

                # Backward + Optimize
                d_loss = self.lambda_gp * d_loss_gp
                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()

            # ================== Train G and gumbel ================== #
            # Create random noise
            z = tensor2var(torch.randn(real_images.size(0), self.z_dim))
            fake_images, _, _ = self.G(z)

            # Compute loss with fake images
            g_out_fake, _, _ = self.D(fake_images)  # batch x n
            if self.adv_loss == 'wgan-gp':
                g_loss_fake = -g_out_fake.mean()
            elif self.adv_loss == 'hinge':
                g_loss_fake = -g_out_fake.mean()

            self.reset_grad()
            g_loss_fake.backward()
            self.g_optimizer.step()

            # Print out log info
            if (step + 1) % self.log_step == 0:
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))
                # NOTE(review): `if self.G.attn2` relies on the module's
                # truthiness; for a plain nn.Module this is always True unless
                # attn2 is None — confirm attn2 is Optional on the Generator.
                if self.G.attn2:
                    ave_gamma_l4 = "{:.4f}".format(
                        self.G.attn2.gamma.mean().item())
                else:
                    ave_gamma_l4 = "n/a"
                # BUGFIX: the printed value is d_loss_real, not d_out_real —
                # label corrected accordingly.
                print("Elapsed [{}], Step [{}/{}], d_loss_real: {:.4f}, "
                      " ave_gamma_l3: {:.4f}, ave_gamma_l4: {}".format(
                          elapsed, step + 1, self.total_step,
                          d_loss_real.item(),
                          self.G.attn1.gamma.mean().item(), ave_gamma_l4))

            # Sample images
            if (step + 1) % self.sample_step == 0:
                fake_images, _, _ = self.G(fixed_z)
                save_image(
                    denorm(fake_images.data),
                    os.path.join(self.sample_path,
                                 '{}_fake.png'.format(step + 1)))

            if (step + 1) % model_save_step == 0:
                torch.save(
                    self.G.state_dict(),
                    os.path.join(self.model_save_path,
                                 '{}_G.pth'.format(step + 1)))
                torch.save(
                    self.D.state_dict(),
                    os.path.join(self.model_save_path,
                                 '{}_D.pth'.format(step + 1)))

    def build_model(self):
        """Instantiate G and D on GPU (optionally DataParallel) with Adam optimizers."""
        self.G = Generator(batch_size=self.batch_size,
                           image_size=self.imsize,
                           z_dim=self.z_dim,
                           conv_dim=self.g_conv_dim,
                           image_channels=self.imchan).cuda()
        self.D = Discriminator(batch_size=self.batch_size,
                               image_size=self.imsize,
                               conv_dim=self.d_conv_dim,
                               image_channels=self.imchan).cuda()
        if self.parallel:
            self.G = nn.DataParallel(self.G)
            self.D = nn.DataParallel(self.D)

        # Loss and optimizer (only trainable parameters are given to Adam)
        self.g_optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.G.parameters()), self.g_lr,
            [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.D.parameters()), self.d_lr,
            [self.beta1, self.beta2])

        self.c_loss = torch.nn.CrossEntropyLoss()

        # print networks
        print(self.G)
        print(self.D)

    def build_tensorboard(self):
        """Create the tensorboard-style logger under self.log_path."""
        from logger import Logger
        self.logger = Logger(self.log_path)

    def load_pretrained_model(self):
        """Load G and D state dicts saved at step ``self.pretrained_model``."""
        self.G.load_state_dict(
            torch.load(
                os.path.join(self.model_save_path,
                             '{}_G.pth'.format(self.pretrained_model))))
        self.D.load_state_dict(
            torch.load(
                os.path.join(self.model_save_path,
                             '{}_D.pth'.format(self.pretrained_model))))
        print('loaded trained models (step: {})..!'.format(
            self.pretrained_model))

    def reset_grad(self):
        """Zero both optimizers' gradients before a new backward pass."""
        self.d_optimizer.zero_grad()
        self.g_optimizer.zero_grad()

    def save_sample(self, data_iter):
        """Save one batch of real images for visual reference."""
        real_images, _ = next(data_iter)
        save_image(denorm(real_images),
                   os.path.join(self.sample_path, 'real.png'))
class Trainer(object):
    """BiGAN/ALI-style trainer: jointly trains a Generator G, an Encoder E and a
    Discriminator D that scores (image, latent) pairs.

    NOTE(review): several helpers used here (tensor2var, denorm, save_image,
    log, SummaryWriter, Variable, itertools) must be imported elsewhere in
    this file; they are not visible in this block.
    """

    def __init__(self, data_loader, config):
        # Data loader
        self.data_loader = data_loader

        # exact model and loss
        self.model = config.model
        self.adv_loss = config.adv_loss

        # Model hyper-parameters
        self.imsize = config.imsize
        self.g_num = config.g_num
        self.z_dim = config.z_dim
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.parallel = config.parallel

        self.lambda_gp = config.lambda_gp
        self.total_step = config.total_step
        self.d_iters = config.d_iters
        self.batch_size = config.batch_size
        self.num_workers = config.num_workers
        # Shared learning rate for the joint G+E optimizer
        self.ge_lr = config.ge_lr
        self.d_lr = config.d_lr
        self.lr_decay = config.lr_decay
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.pretrained_model = config.pretrained_model

        self.dataset = config.dataset
        # MURA dataset selectors — presumably body-part class and study type;
        # verify against the config definition.
        self.mura_class = config.mura_class
        self.mura_type = config.mura_type
        self.use_tensorboard = config.use_tensorboard
        self.image_path = config.image_path
        self.log_path = config.log_path
        self.model_save_path = config.model_save_path
        self.sample_path = config.sample_path
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step
        self.version = config.version

        # Path (version-specific subdirectories)
        self.log_path = os.path.join(config.log_path, self.version)
        self.sample_path = os.path.join(config.sample_path, self.version)
        self.model_save_path = os.path.join(config.model_save_path,
                                            self.version)

        if self.use_tensorboard:
            self.build_tensorboard()

        self.build_model()

        # Start with trained model
        if self.pretrained_model:
            self.load_pretrained_model()

    def train(self):
        """Run the joint D vs (G, E) adversarial training loop.

        Each iteration: sample a real batch and a prior z, score the real pair
        (x, E(x)) and the fake pair (G(z), z) with D, update D, optionally apply
        a WGAN gradient penalty as a second D step, then update G and E together.
        """
        # Data iterator
        print('inside the train')
        data_iter = iter(self.data_loader)
        step_per_epoch = len(self.data_loader)
        # model_save_step is given in epochs; convert to iterations
        model_save_step = int(self.model_save_step * step_per_epoch)

        # Fixed input for debugging: a real batch (for reconstruction samples)
        # and a fixed prior z (for generation samples).
        fixed_img, _ = next(data_iter)
        fixed_z = tensor2var(torch.randn(self.batch_size, self.z_dim))
        if self.use_tensorboard:
            self.writer.add_image('img/fixed_img', denorm(fixed_img.data), 0)
        else:
            save_image(denorm(fixed_img.data),
                       os.path.join(self.sample_path, 'fixed_img.png'))

        # Start with trained model
        if self.pretrained_model:
            start = self.pretrained_model + 1
        else:
            start = 0

        self.D.train()
        self.E.train()
        self.G.train()

        # Start time
        start_time = time.time()
        for step in range(start, self.total_step):
            self.reset_grad()

            # Sample from data and prior
            try:
                real_images, _ = next(data_iter)
            except:
                # Loader exhausted — start a new epoch.
                # NOTE(review): bare except also swallows non-StopIteration
                # errors from the loader; consider `except StopIteration`.
                data_iter = iter(self.data_loader)
                real_images, _ = next(data_iter)
            real_images = tensor2var(real_images)
            fake_z = tensor2var(torch.randn(real_images.size(0), self.z_dim))

            # Instance noise added to D's inputs.
            # NOTE(review): (step + 1 - self.total_step) is negative for all
            # steps before the last one, so the std passed to normal_() is
            # negative — confirm the intended annealing schedule; some torch
            # versions raise on a negative std.
            noise1 = torch.Tensor(real_images.size()).normal_(
                0, 0.01 * (step + 1 - self.total_step) / (step + 1))
            noise2 = torch.Tensor(real_images.size()).normal_(
                0, 0.01 * (step + 1 - self.total_step) / (step + 1))

            # Sample from condition: encode reals, generate fakes, and score
            # both (image, latent) pairs with D.
            real_z, _, _ = self.E(real_images)
            fake_images, gf1, gf2 = self.G(fake_z)
            dr, dr5, dr4, dr3, drz, dra2, dra1 = self.D(
                real_images + noise1, real_z)
            df, df5, df4, df3, dfz, dfa2, dfa1 = self.D(
                fake_images + noise2, fake_z)

            # Compute loss with real and fake images
            # dr1, dr2, df1, df2, gf1, gf2 are attention scores
            # NOTE(review): `log` is not defined in this block — presumably
            # torch.log imported at module level; verify. The 'hinge' branch
            # here is actually a log (non-saturating) loss, and 'hinge1' is
            # the true hinge loss — the names look swapped; confirm intent.
            if self.adv_loss == 'wgan-gp':
                d_loss_real = -torch.mean(dr)
                d_loss_fake = df.mean()
                g_loss_fake = -df.mean()
                e_loss_real = -dr.mean()
            elif self.adv_loss == 'hinge1':
                d_loss_real = torch.nn.ReLU()(1.0 - dr).mean()
                d_loss_fake = torch.nn.ReLU()(1.0 + df).mean()
                g_loss_fake = -df.mean()
                e_loss_real = -dr.mean()
            elif self.adv_loss == 'hinge':
                d_loss_real = -log(dr).mean()
                d_loss_fake = -log(1.0 - df).mean()
                g_loss_fake = -log(df).mean()
                e_loss_real = -log(1.0 - dr).mean()
            elif self.adv_loss == 'inverse':
                d_loss_real = -log(1.0 - dr).mean()
                d_loss_fake = -log(df).mean()
                g_loss_fake = -log(1.0 - df).mean()
                e_loss_real = -log(dr).mean()

            # ================== Train D ================== #
            # retain_graph=True: the same forward graph is reused below for
            # the G/E update (ge_loss shares dr/df).
            d_loss = d_loss_real + d_loss_fake
            d_loss.backward(retain_graph=True)
            self.d_optimizer.step()

            if self.adv_loss == 'wgan-gp':
                # Compute gradient penalty on random interpolates and apply it
                # as a second, separate D update.
                alpha = torch.rand(real_images.size(0), 1, 1,
                                   1).expand_as(real_images)
                interpolated = Variable(alpha * real_images.data +
                                        (1 - alpha) * fake_images.data,
                                        requires_grad=True)
                # NOTE(review): D is called with a single argument here but
                # with (images, z) above, and unpacked into 3 values instead
                # of 7 — confirm D supports this call signature.
                out, _, _ = self.D(interpolated)

                grad = torch.autograd.grad(outputs=out,
                                           inputs=interpolated,
                                           grad_outputs=torch.ones(out.size()),
                                           retain_graph=True,
                                           create_graph=True,
                                           only_inputs=True)[0]

                grad = grad.view(grad.size(0), -1)
                grad_l2norm = torch.sqrt(torch.sum(grad**2, dim=1))
                d_loss_gp = torch.mean((grad_l2norm - 1)**2)

                # Backward + Optimize
                d_loss = self.lambda_gp * d_loss_gp
                d_loss.backward()
                self.d_optimizer.step()

            # ================== Train G and E ================== #
            # Joint update: G minimizes g_loss_fake, E minimizes e_loss_real,
            # through the single combined ge_optimizer.
            ge_loss = g_loss_fake + e_loss_real
            ge_loss.backward()
            self.ge_optimizer.step()

            # Print out log info
            if (step + 1) % self.log_step == 0:
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))
                print(
                    f"Elapsed: [{elapsed}], step: [{step+1}/{self.total_step}], d_loss: {d_loss}, ge_loss: {ge_loss}"
                )
                if self.use_tensorboard:
                    self.writer.add_scalar('d/loss_real', d_loss_real.data,
                                           step + 1)
                    self.writer.add_scalar('d/loss_fake', d_loss_fake.data,
                                           step + 1)
                    self.writer.add_scalar('d/loss', d_loss.data, step + 1)
                    self.writer.add_scalar('ge/loss_real', e_loss_real.data,
                                           step + 1)
                    self.writer.add_scalar('ge/loss_fake', g_loss_fake.data,
                                           step + 1)
                    self.writer.add_scalar('ge/loss', ge_loss.data, step + 1)
                    # Self-attention mixing weights of the generator layers
                    self.writer.add_scalar('ave_gamma/l3',
                                           self.G.attn1.gamma.mean().data,
                                           step + 1)
                    self.writer.add_scalar('ave_gamma/l4',
                                           self.G.attn2.gamma.mean().data,
                                           step + 1)

            # Sample images: generations from fixed_z and reconstructions
            # G(E(fixed_img)) of the fixed real batch.
            if (step + 1) % self.sample_step == 0:
                img_from_z, _, _ = self.G(fixed_z)
                z_from_img, _, _ = self.E(tensor2var(fixed_img))
                reimg_from_z, _, _ = self.G(z_from_img)
                if self.use_tensorboard:
                    self.writer.add_image('img/reimg_from_z',
                                          denorm(reimg_from_z.data), step + 1)
                    self.writer.add_image('img/img_from_z',
                                          denorm(img_from_z.data), step + 1)
                else:
                    save_image(
                        denorm(img_from_z.data),
                        os.path.join(self.sample_path,
                                     '{}_img_from_z.png'.format(step + 1)))
                    save_image(
                        denorm(reimg_from_z.data),
                        os.path.join(self.sample_path,
                                     '{}_reimg_from_z.png'.format(step + 1)))

            if (step + 1) % model_save_step == 0:
                torch.save(
                    self.G.state_dict(),
                    os.path.join(self.model_save_path,
                                 '{}_G.pth'.format(step + 1)))
                torch.save(
                    self.E.state_dict(),
                    os.path.join(self.model_save_path,
                                 '{}_E.pth'.format(step + 1)))
                torch.save(
                    self.D.state_dict(),
                    os.path.join(self.model_save_path,
                                 '{}_D.pth'.format(step + 1)))

    def build_model(self):
        """Instantiate G, E, D (optionally DataParallel) and the two optimizers.

        NOTE(review): unlike the other trainers in this file, no `.cuda()` is
        applied here — confirm device placement happens elsewhere.
        """
        self.G = Generator(self.batch_size, self.imsize, self.z_dim,
                           self.g_conv_dim)
        self.E = Encoder(self.batch_size, self.imsize, self.z_dim,
                         self.d_conv_dim)
        self.D = Discriminator(self.batch_size, self.imsize, self.z_dim,
                               self.d_conv_dim)
        if self.parallel:
            self.G = nn.DataParallel(self.G)
            self.E = nn.DataParallel(self.E)
            self.D = nn.DataParallel(self.D)

        # Loss and optimizer: G and E share one Adam optimizer (joint update),
        # D gets its own.
        self.ge_optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad,
                   itertools.chain(self.G.parameters(), self.E.parameters())),
            self.ge_lr, [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.D.parameters()), self.d_lr,
            [self.beta1, self.beta2])

        self.c_loss = torch.nn.CrossEntropyLoss()

        # print networks
        # print(self.G)
        # print(self.E)
        # print(self.D)

    def build_tensorboard(self):
        '''Initialize the tensorboard writer.'''
        self.writer = SummaryWriter(self.log_path)

    def load_pretrained_model(self):
        """Load G, E and D state dicts saved at step ``self.pretrained_model``."""
        self.G.load_state_dict(
            torch.load(
                os.path.join(self.model_save_path,
                             '{}_G.pth'.format(self.pretrained_model))))
        self.E.load_state_dict(
            torch.load(
                os.path.join(self.model_save_path,
                             '{}_E.pth'.format(self.pretrained_model))))
        self.D.load_state_dict(
            torch.load(
                os.path.join(self.model_save_path,
                             '{}_D.pth'.format(self.pretrained_model))))
        print('loaded trained models (step: {})..!'.format(
            self.pretrained_model))

    def reset_grad(self):
        """Zero both optimizers' gradients before a new backward pass."""
        self.d_optimizer.zero_grad()
        self.ge_optimizer.zero_grad()

    def save_sample(self, data_iter):
        """Save one batch of real images for visual reference."""
        real_images, _ = next(data_iter)
        save_image(denorm(real_images),
                   os.path.join(self.sample_path, 'real.png'))
class Trainer(object):
    """GAN trainer with optional SVRG variance reduction and extragradient
    (t + 1/2 half-step) updates, plus uniform-average and EMA generator copies.

    Relies on module-level helpers visible elsewhere in this file:
    ``tensor2var``, ``denorm``, ``setup_logger``, ``Generator``,
    ``Discriminator``, and the imports ``torch``, ``nn``, ``copy``, ``os``,
    ``time``, ``datetime``, ``bernoulli``, ``save_image``, ``Variable``.
    """

    def __init__(self, data_loader, config):
        """Store all hyper-parameters from ``config``, build the networks, and
        (when SVRG is enabled) set up snapshot nets and the Bernoulli sampler
        that decides when to refresh the SVRG statistics.

        :param data_loader: torch DataLoader yielding (images, labels) batches
        :param config: argparse-style namespace of hyper-parameters
        """
        # Seed both CPU and CUDA RNGs for reproducibility.
        torch.manual_seed(config.seed)
        torch.cuda.manual_seed(config.seed)
        self.data_loader = data_loader
        self.model = config.model
        self.adv_loss = config.adv_loss
        self.imsize = config.imsize
        self.g_num = config.g_num
        self.z_dim = config.z_dim
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.parallel = config.parallel
        self.extra = config.extra
        self.lambda_gp = config.lambda_gp
        self.total_step = config.total_step
        self.d_iters = config.d_iters
        self.batch_size = config.batch_size
        self.num_workers = config.num_workers
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.lr_scheduler = config.lr_scheduler
        self.g_beta1 = config.g_beta1
        self.d_beta1 = config.d_beta1
        self.beta2 = config.beta2
        self.dataset = config.dataset
        self.log_path = config.log_path
        self.model_save_path = config.model_save_path
        self.sample_path = config.sample_path
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step
        self.version = config.version
        self.backup_freq = config.backup_freq
        self.bup_path = config.bup_path
        # Path
        self.optim = config.optim
        self.svrg = config.svrg
        self.avg_start = config.avg_start
        self.build_model()
        if self.svrg:
            # mu_g / mu_d hold the full-dataset average gradients; the
            # snapshots are frozen copies of G/D taken at the same time.
            self.mu_g = []
            self.mu_d = []
            self.g_snapshot = copy.deepcopy(self.G)
            self.d_snapshot = copy.deepcopy(self.D)
            # On average, refresh SVRG stats about once per epoch.
            self.svrg_freq_sampler = bernoulli.Bernoulli(torch.tensor([1 / len(self.data_loader)]))
        self.info_logger = setup_logger(self.log_path)
        self.info_logger.info(config)
        self.cont = config.cont

    def train(self):
        """Main training loop: optional SVRG refresh, one paired G/D update per
        step, then periodic logging, sampling, checkpointing, and backup."""
        self.data_gen = self._data_gen()
        # Fixed latent vector so sampled images are comparable across steps.
        fixed_z = tensor2var(torch.randn(self.batch_size, self.z_dim))
        if self.cont:
            start = self.load_backup()
        else:
            start = 0
        start_time = time.time()
        if self.svrg:
            self.update_svrg_stats()
        for step in range(start, self.total_step):
            # =================== SVRG =================== #
            if self.svrg and self.svrg_freq_sampler.sample() == 1:
                # ================= Update Avg ================= #
                if self.avg_start >= 0 and step > 0 and step >= self.avg_start:
                    self.update_avg_nets()
                    # Occasionally restart training from the averaged nets.
                    # NOTE(review): avg_freq_restart_sampler only exists when
                    # avg_start >= 0 (see build_model) — consistent with this guard.
                    if self.avg_freq_restart_sampler.sample() == 1:
                        self.G.load_state_dict(self.avg_g.state_dict())
                        self.D.load_state_dict(self.avg_d.state_dict())
                        self.avg_step = 1
                        self.info_logger.info('Params updated with avg-nets at %d-th step.' % step)
                self.update_svrg_stats()
                self.info_logger.info("SVRG stats updated at %d-th step." % step)
            # ================= Train pair ================= #
            d_loss_real = self._update_pair(step)
            # --- storing stuff ---
            if (step + 1) % self.log_step == 0:
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))
                print("Elapsed [{}], Step [{}/{}]".format(elapsed, step + 1, self.total_step))
            if (step + 1) % self.sample_step == 0:
                # Sample from the raw, averaged, and EMA generators.
                save_image(denorm(self.G(fixed_z).data),
                           os.path.join(self.sample_path, 'gen', 'iter%08d.png' % step))
                save_image(denorm(self.G_avg(fixed_z).data),
                           os.path.join(self.sample_path, 'gen_avg', 'iter%08d.png' % step))
                save_image(denorm(self.G_ema(fixed_z).data),
                           os.path.join(self.sample_path, 'gen_ema', 'iter%08d.png' % step))
            if self.model_save_step > 0 and (step+1) % self.model_save_step == 0:
                torch.save(self.G.state_dict(),
                           os.path.join(self.model_save_path, 'gen', 'iter%08d.pth' % step))
                torch.save(self.G_avg.state_dict(),
                           os.path.join(self.model_save_path, 'gen_avg', 'iter%08d.pth' % step))
                torch.save(self.G_ema.state_dict(),
                           os.path.join(self.model_save_path, 'gen_ema', 'iter%08d.pth' % step))
            if self.backup_freq > 0 and (step+1) % self.backup_freq == 0:
                self.backup(step)

    def _data_gen(self):
        """Infinite generator over real image batches; restarts the underlying
        DataLoader iterator whenever it is exhausted.

        :return: generator yielding image tensors (labels are discarded)
        """
        data_iter = iter(self.data_loader)
        while True:
            try:
                real_images, _ = next(data_iter)
            except StopIteration:
                # Epoch finished — start a fresh pass over the dataset.
                data_iter = iter(self.data_loader)
                real_images, _ = next(data_iter)
            yield real_images

    def _update_pair(self, step):
        """One paired D/G update; with ``self.extra`` set, performs the
        extragradient half-step (t + 1/2) first, then the full step against
        the half-step opponents.

        :param step: [int] current global step (drives LR-scheduler ticks)
        :return: [float] real-image discriminator loss of the full step
        """
        # Step the LR schedulers once per epoch (when step divides dataset length).
        _lr_scheduler = self.lr_scheduler > 0 and step > 0 and step % len(self.data_loader) == 0
        self.D.train()
        self.G.train()
        real_images = tensor2var(next(self.data_gen))
        self._extra_sync_nets()
        if self.extra:
            # ================== Train D @ t + 1/2 ================== #
            self._backprop_disc(D=self.D_extra, G=self.G, real_images=real_images,
                                d_optim=self.d_optimizer_extra, svrg=self.svrg,
                                scheduler_d=self.scheduler_d_extra if _lr_scheduler else None)
            # ================== Train G @ t + 1/2 ================== #
            self._backprop_gen(G=self.G_extra, D=self.D, bsize=real_images.size(0),
                               g_optim=self.g_optimizer_extra, svrg=self.svrg,
                               scheduler_g=self.scheduler_g_extra if _lr_scheduler else None)
            real_images = tensor2var(next(self.data_gen))  # Re-sample
        # ================== Train D @ t + 1 ================== #
        # The full step plays against the half-step opponent (G_extra / D_extra).
        d_loss_real = self._backprop_disc(G=self.G_extra, D=self.D, real_images=real_images,
                                          d_optim=self.d_optimizer, svrg=self.svrg,
                                          scheduler_d=self.scheduler_d if _lr_scheduler else None)
        # ================== Train G and gumbel @ t + 1 ================== #
        self._backprop_gen(G=self.G, D=self.D_extra, bsize=real_images.size(0),
                           g_optim=self.g_optimizer, svrg=self.svrg,
                           scheduler_g=self.scheduler_g if _lr_scheduler else None)
        # === Moving avg Generator-nets ===
        self._update_avg_gen(step)
        self._update_ema_gen()
        return d_loss_real

    def _normalize_acc_grads(self, net):
        """Divides accumulated gradients with len(self.data_loader)"""
        for _param in filter(lambda p: p.requires_grad, net.parameters()):
            _param.grad.data.div_(len(self.data_loader))

    def update_svrg_stats(self):
        """Recompute the SVRG full-dataset average gradients (mu_d, mu_g) by
        accumulating per-batch gradients with ``d_optim=None`` / ``g_optim=None``
        (backprop only, no optimizer step), then refresh the G/D snapshots."""
        self.mu_g, self.mu_d = [], []
        # Update mu_d ####################
        self.d_optimizer.zero_grad()
        for _, _data in enumerate(self.data_loader):
            real_images = tensor2var(_data[0])
            # d_optim=None -> gradients accumulate across the whole dataset.
            self._backprop_disc(self.G, self.D, real_images, d_optim=None, svrg=False)
        self._normalize_acc_grads(self.D)
        for _param in filter(lambda p: p.requires_grad, self.D.parameters()):
            self.mu_d.append(_param.grad.data.clone())
        # Update mu_g ####################
        self.g_optimizer.zero_grad()
        for _ in range(len(self.data_loader)):
            self._backprop_gen(self.G, self.D, self.batch_size, g_optim=None, svrg=False)
        self._normalize_acc_grads(self.G)
        for _param in filter(lambda p: p.requires_grad, self.G.parameters()):
            self.mu_g.append(_param.grad.data.clone())
        # Update snapshots ###############
        self.g_snapshot.load_state_dict(self.G.state_dict())
        self.d_snapshot.load_state_dict(self.D.state_dict())

    @staticmethod
    def _update_grads_svrg(params, snapshot_params, mu):
        """Helper function which updates the accumulated gradients of params by
        subtracting those of snapshot and adding mu. Operates in-place.
        See line 12 & 14 of Algo. 3 in SVRG-GAN.

        Raises:
            ValueError if the inputs have different lengths (the length
            corresponds to the number of layers in the network)

        :param params: [list of torch.nn.parameter.Parameter]
        :param snapshot_params: [torch.nn.parameter.Parameter]
        :param mu: [list of torch(.cuda).FloatTensor]
        :return: [None]
        """
        if not len(params) == len(snapshot_params) == len(mu):
            raise ValueError("Expected input of identical length. "
                             "Got {}, {}, {}".format(len(params),
                                                     len(snapshot_params),
                                                     len(mu)))
        for i in range(len(mu)):
            # SVRG correction: grad <- grad - snapshot_grad + mu
            params[i].grad.data.sub_(snapshot_params[i].grad.data)
            params[i].grad.data.add_(mu[i])

    def _backprop_disc(self, G, D, real_images, d_optim=None, svrg=False,
                       scheduler_d=None):
        """Updates D (Vs. G).

        :param G: generator producing the fake batch
        :param D: discriminator being updated
        :param real_images: batch of real images
        :param d_optim: if None, only backprop (gradients accumulate; used by
            update_svrg_stats); otherwise zero, backprop, correct, and step
        :param svrg: apply the SVRG gradient correction before stepping
        :param scheduler_d: optional LR scheduler to tick after the step
        :return: [float] d_loss_real value
        """
        d_out_real = D(real_images)
        if self.adv_loss == 'wgan-gp':
            d_loss_real = - torch.mean(d_out_real)
        elif self.adv_loss == 'hinge':
            d_loss_real = torch.nn.ReLU()(1.0 - d_out_real).mean()
        else:
            raise NotImplementedError
        z = tensor2var(torch.randn(real_images.size(0), self.z_dim))
        fake_images = G(z)
        d_out_fake = D(fake_images)
        if self.adv_loss == 'wgan-gp':
            d_loss_fake = d_out_fake.mean()
        elif self.adv_loss == 'hinge':
            d_loss_fake = torch.nn.ReLU()(1.0 + d_out_fake).mean()
        else:
            raise NotImplementedError
        # Backward + Optimize
        d_loss = d_loss_real + d_loss_fake
        if d_optim is not None:
            d_optim.zero_grad()
        # backward() runs even when d_optim is None so gradients accumulate.
        d_loss.backward()
        if d_optim is not None:
            if svrg:
                # d_snapshot Vs. g_snapshot: recompute the same loss on the
                # frozen snapshot pair to form the SVRG control variate.
                d_out_real = self.d_snapshot(real_images)
                d_out_fake = self.d_snapshot(self.g_snapshot(z))
                if self.adv_loss == 'wgan-gp':
                    d_s_loss_real = - torch.mean(d_out_real)
                    d_loss_fake = d_out_fake.mean()
                elif self.adv_loss == 'hinge':
                    d_s_loss_real = torch.nn.ReLU()(1.0 - d_out_real).mean()
                    d_loss_fake = torch.nn.ReLU()(1.0 + d_out_fake).mean()
                else:
                    raise NotImplementedError
                d_loss = d_s_loss_real + d_loss_fake
                self.d_snapshot.zero_grad()
                d_loss.backward()
                self._update_grads_svrg(
                    list(filter(lambda p: p.requires_grad, D.parameters())),
                    list(filter(lambda p: p.requires_grad, self.d_snapshot.parameters())),
                    self.mu_d)
            d_optim.step()
            if scheduler_d is not None:
                scheduler_d.step()
        if self.adv_loss == 'wgan-gp':
            # Todo: add SVRG for wgan-gp
            raise NotImplementedError('SVRG-WGAN-gp is not implemented yet')
            # NOTE(review): everything below is unreachable dead code (after the
            # raise). It also contains two bugs if ever resurrected:
            # d_optim.reset_grad() is not a torch optimizer method (should be
            # zero_grad()), and self.d_optimizer.step() should be d_optim.step().
            # Compute gradient penalty
            alpha = torch.rand(real_images.size(0), 1, 1, 1).cuda().expand_as(real_images)
            interpolated = Variable(alpha * real_images.data + (1 - alpha) * fake_images.data,
                                    requires_grad=True)
            out = D(interpolated)
            grad = torch.autograd.grad(outputs=out, inputs=interpolated,
                                       grad_outputs=torch.ones(out.size()).cuda(),
                                       retain_graph=True, create_graph=True,
                                       only_inputs=True)[0]
            grad = grad.view(grad.size(0), -1)
            grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
            d_loss_gp = torch.mean((grad_l2norm - 1) ** 2)
            # Backward + Optimize
            d_loss = self.lambda_gp * d_loss_gp
            if d_optim is not None:
                d_optim.reset_grad()
            d_loss.backward()
            if d_optim is not None:
                self.d_optimizer.step()
        return d_loss_real.data.item()

    def _backprop_gen(self, G, D, bsize, g_optim=True, svrg=False,
                      scheduler_g=None):
        """Updates G (Vs. D).

        :param G: generator being updated
        :param D: discriminator scoring the fakes
        :param bsize: batch size for the sampled noise
        :param g_optim: if None only backprop (gradients accumulate); otherwise
            zero, backprop, correct, and step.
            NOTE(review): default is True, not an optimizer — every caller in
            this file passes g_optim explicitly; confirm before relying on it.
        :param svrg: apply the SVRG gradient correction before stepping
        :param scheduler_g: optional LR scheduler to tick after the step
        :return: [float] g_loss_fake value
        """
        z = tensor2var(torch.randn(bsize, self.z_dim))
        fake_images = G(z)
        g_out_fake = D(fake_images)  # batch x n
        if self.adv_loss == 'wgan-gp' or self.adv_loss == 'hinge':
            g_loss_fake = - g_out_fake.mean()
        if g_optim is not None:
            g_optim.zero_grad()
        g_loss_fake.backward()
        if g_optim is not None:
            if svrg:
                # G_snapshot Vs. D_snapshot
                self.g_snapshot.zero_grad()
                if self.adv_loss == 'wgan-gp' or self.adv_loss == 'hinge':
                    (- self.d_snapshot(self.g_snapshot(z)).mean()).backward()
                else:
                    raise NotImplementedError
                self._update_grads_svrg(
                    list(filter(lambda p: p.requires_grad, G.parameters())),
                    list(filter(lambda p: p.requires_grad, self.g_snapshot.parameters())),
                    self.mu_g)
            g_optim.step()
            if scheduler_g is not None:
                scheduler_g.step()
        return g_loss_fake.data.item()

    def build_model(self):
        """Instantiate G/D (plus extragradient copies and optional averaged
        nets), wrap in DataParallel if requested, and create optimizers,
        gradient-statistics loggers, and LR schedulers."""
        # Models ###################################################################
        self.G = Generator(self.batch_size, self.imsize, self.z_dim, self.g_conv_dim).cuda()
        self.D = Discriminator(self.batch_size, self.imsize, self.d_conv_dim).cuda()
        # Todo: do not allocate unnecessary GPU mem for G_extra and D_extra if self.extra == False
        self.G_extra = Generator(self.batch_size, self.imsize, self.z_dim, self.g_conv_dim).cuda()
        self.D_extra = Discriminator(self.batch_size, self.imsize, self.d_conv_dim).cuda()
        if self.avg_start >= 0:
            self.avg_g = copy.deepcopy(self.G)
            self.avg_d = copy.deepcopy(self.D)
            self._requires_grad(self.avg_g, False)
            self._requires_grad(self.avg_d, False)
            self.avg_g.eval()
            self.avg_d.eval()
            self.avg_step = 1
            self.avg_freq_restart_sampler = bernoulli.Bernoulli(.1)
        if self.parallel:
            self.G = nn.DataParallel(self.G)
            self.D = nn.DataParallel(self.D)
            self.G_extra = nn.DataParallel(self.G_extra)
            self.D_extra = nn.DataParallel(self.D_extra)
            if self.avg_start >= 0:
                self.avg_g = nn.DataParallel(self.avg_g)
                self.avg_d = nn.DataParallel(self.avg_d)
        self.G_extra.train()
        self.D_extra.train()
        self.G_avg = copy.deepcopy(self.G)
        self.G_ema = copy.deepcopy(self.G)
        self._requires_grad(self.G_avg, False)
        self._requires_grad(self.G_ema, False)
        # Logs, Loss & optimizers ###################################################################
        grad_var_logger_g = setup_logger(self.log_path, 'log_grad_var_g.log')
        grad_var_logger_d = setup_logger(self.log_path, 'log_grad_var_d.log')
        grad_mean_logger_g = setup_logger(self.log_path, 'log_grad_mean_g.log')
        grad_mean_logger_d = setup_logger(self.log_path, 'log_grad_mean_d.log')
        # NOTE(review): logger_mean/logger_var kwargs (and torch.optim.SvrgAdam
        # below) are not part of stock torch.optim — this presumably relies on a
        # patched/custom torch.optim; confirm against the project's environment.
        if self.optim == 'sgd':
            self.g_optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                                      self.G.parameters()),
                                               self.g_lr,
                                               logger_mean=grad_mean_logger_g,
                                               logger_var=grad_var_logger_g)
            self.d_optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                                      self.D.parameters()),
                                               self.d_lr,
                                               logger_mean=grad_mean_logger_d,
                                               logger_var=grad_var_logger_d)
            self.g_optimizer_extra = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                                            self.G_extra.parameters()),
                                                     self.g_lr)
            self.d_optimizer_extra = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                                            self.D_extra.parameters()),
                                                     self.d_lr)
        elif self.optim == 'adam':
            self.g_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                                       self.G.parameters()),
                                                self.g_lr, [self.g_beta1, self.beta2],
                                                logger_mean=grad_mean_logger_g,
                                                logger_var=grad_var_logger_g)
            self.d_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                                       self.D.parameters()),
                                                self.d_lr, [self.d_beta1, self.beta2],
                                                logger_mean=grad_mean_logger_d,
                                                logger_var=grad_var_logger_d)
            self.g_optimizer_extra = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                                             self.G_extra.parameters()),
                                                      self.g_lr, [self.g_beta1, self.beta2])
            self.d_optimizer_extra = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                                             self.D_extra.parameters()),
                                                      self.d_lr, [self.d_beta1, self.beta2])
        elif self.optim == 'svrgadam':
            self.g_optimizer = torch.optim.SvrgAdam(filter(lambda p: p.requires_grad,
                                                           self.G.parameters()),
                                                    self.g_lr, [self.g_beta1, self.beta2],
                                                    logger_mean=grad_mean_logger_g,
                                                    logger_var=grad_var_logger_g)
            self.d_optimizer = torch.optim.SvrgAdam(filter(lambda p: p.requires_grad,
                                                           self.D.parameters()),
                                                    self.d_lr, [self.d_beta1, self.beta2],
                                                    logger_mean=grad_mean_logger_d,
                                                    logger_var=grad_var_logger_d)
            self.g_optimizer_extra = torch.optim.SvrgAdam(filter(lambda p: p.requires_grad,
                                                                 self.G_extra.parameters()),
                                                          self.g_lr, [self.g_beta1, self.beta2])
            self.d_optimizer_extra = torch.optim.SvrgAdam(filter(lambda p: p.requires_grad,
                                                                 self.D_extra.parameters()),
                                                          self.d_lr, [self.d_beta1, self.beta2])
        else:
            # NOTE(review): message mentions Adadelta but the accepted values
            # are 'sgd', 'adam', 'svrgadam' — message looks stale.
            raise NotImplementedError('Supported optimizers: SGD, Adam, Adadelta')
        if self.lr_scheduler > 0:
            # Exponentially decaying learning rate
            self.scheduler_g = torch.optim.lr_scheduler.ExponentialLR(self.g_optimizer,
                                                                      gamma=self.lr_scheduler)
            self.scheduler_d = torch.optim.lr_scheduler.ExponentialLR(self.d_optimizer,
                                                                      gamma=self.lr_scheduler)
            self.scheduler_g_extra = torch.optim.lr_scheduler.ExponentialLR(self.g_optimizer_extra,
                                                                            gamma=self.lr_scheduler)
            self.scheduler_d_extra = torch.optim.lr_scheduler.ExponentialLR(self.d_optimizer_extra,
                                                                            gamma=self.lr_scheduler)
        print(self.G)
        print(self.D)

    def _extra_sync_nets(self):
        """ Helper function. Copies the current parameters to the t+1/2
        parameters, stored as 'net' and 'extra_net', respectively.

        :return: [None]
        """
        self.G_extra.load_state_dict(self.G.state_dict())
        self.D_extra.load_state_dict(self.D.state_dict())

    @staticmethod
    def _update_avg(avg_net, net, avg_step):
        """Updates average network."""
        # Todo: input val
        # Incremental uniform average: avg <- avg*(k-1)/k + net/k
        net_param = list(net.parameters())
        for i, p in enumerate(avg_net.parameters()):
            p.mul_((avg_step - 1) / avg_step)
            p.add_(net_param[i].div(avg_step))

    @staticmethod
    def _requires_grad(_net, _bool=True):
        """Helper function which sets the requires_grad of _net to _bool.

        Raises:
            TypeError: _net is given but is not derived from nn.Module, or
                       _bool is not boolean

        :param _net: [nn.Module]
        :param _bool: [bool, optional] Default: True
        :return: [None]
        """
        if _net and not isinstance(_net, torch.nn.Module):
            raise TypeError("Expected torch.nn.Module. Got: {}".format(type(_net)))
        if not isinstance(_bool, bool):
            raise TypeError("Expected bool. Got: {}".format(type(_bool)))
        if _net is not None:
            for _w in _net.parameters():
                _w.requires_grad = _bool

    def update_avg_nets(self):
        """Fold the current G/D into the running uniform averages."""
        self._update_avg(self.avg_g, self.G, self.avg_step)
        self._update_avg(self.avg_d, self.D, self.avg_step)
        self.avg_step += 1

    def save_sample(self, data_iter):
        """Save one batch of real images as a reference grid."""
        real_images, _ = next(data_iter)
        save_image(denorm(real_images), os.path.join(self.sample_path, 'real.png'))

    def backup(self, iteration):
        """Back-ups the networks & optimizers' states.

        Note: self.g_extra & self.d_extra are not stored, as these are copied
        from self.G & self.D at the beginning of each iteration. However, the
        optimizers are backed up.

        :param iteration: [int]
        :return: [None]
        """
        torch.save(self.G.state_dict(), os.path.join(self.bup_path, 'gen.pth'))
        torch.save(self.D.state_dict(), os.path.join(self.bup_path, 'disc.pth'))
        torch.save(self.G_avg.state_dict(), os.path.join(self.bup_path, 'gen_avg.pth'))
        torch.save(self.G_ema.state_dict(), os.path.join(self.bup_path, 'gen_ema.pth'))
        torch.save(self.g_optimizer.state_dict(), os.path.join(self.bup_path, 'gen_optim.pth'))
        torch.save(self.d_optimizer.state_dict(), os.path.join(self.bup_path, 'disc_optim.pth'))
        torch.save(self.g_optimizer_extra.state_dict(), os.path.join(self.bup_path, 'gen_extra_optim.pth'))
        torch.save(self.d_optimizer_extra.state_dict(), os.path.join(self.bup_path, 'disc_extra_optim.pth'))
        with open(os.path.join(self.bup_path, "timestamp.txt"), "w") as fff:
            fff.write("%d" % iteration)

    def load_backup(self):
        """Loads the Backed-up networks & optimizers' states.

        Note: self.g_extra & self.d_extra are not stored, as these are copied
        from self.G & self.D at the beginning of each iteration. However, the
        optimizers are backed up.

        :return: [int] timestamp to continue from
        """
        if not os.path.exists(self.bup_path):
            raise ValueError('Cannot load back-up. Directory {} '
                             'does not exist.'.format(self.bup_path))
        self.G.load_state_dict(torch.load(os.path.join(self.bup_path, 'gen.pth')))
        self.D.load_state_dict(torch.load(os.path.join(self.bup_path, 'disc.pth')))
        self.G_avg.load_state_dict(torch.load(os.path.join(self.bup_path, 'gen_avg.pth')))
        self.G_ema.load_state_dict(torch.load(os.path.join(self.bup_path, 'gen_ema.pth')))
        self.g_optimizer.load_state_dict(torch.load(os.path.join(self.bup_path, 'gen_optim.pth')))
        self.d_optimizer.load_state_dict(torch.load(os.path.join(self.bup_path, 'disc_optim.pth')))
        self.g_optimizer_extra.load_state_dict(torch.load(os.path.join(self.bup_path, 'gen_extra_optim.pth')))
        self.d_optimizer_extra.load_state_dict(torch.load(os.path.join(self.bup_path, 'disc_extra_optim.pth')))
        with open(os.path.join(self.bup_path, "timestamp.txt"), "r") as fff:
            timestamp = [int(x) for x in next(fff).split()]  # read first line
        if not len(timestamp) == 1:
            raise ValueError('Could not determine timestamp of the backed-up models.')
        timestamp = int(timestamp[0]) + 1
        self.info_logger.info("Loaded models from %s, at timestamp %d." %
                              (self.bup_path, timestamp))
        return timestamp

    def _update_avg_gen(self, n_gen_update):
        """ Updates the uniform average generator. """
        l_param = list(self.G.parameters())
        l_avg_param = list(self.G_avg.parameters())
        if len(l_param) != len(l_avg_param):
            raise ValueError("Got different lengths: {}, {}".format(len(l_param), len(l_avg_param)))
        for i in range(len(l_param)):
            # avg <- avg * k/(k+1) + param/(k+1)
            l_avg_param[i].data.copy_(l_avg_param[i].data.mul(n_gen_update).div(n_gen_update + 1.).add(
                l_param[i].data.div(n_gen_update + 1.)))

    def _update_ema_gen(self, beta_ema=0.9999):
        """ Updates the exponential moving average generator. """
        l_param = list(self.G.parameters())
        l_ema_param = list(self.G_ema.parameters())
        if len(l_param) != len(l_ema_param):
            raise ValueError("Got different lengths: {}, {}".format(len(l_param), len(l_ema_param)))
        for i in range(len(l_param)):
            # ema <- ema * beta + param * (1 - beta)
            l_ema_param[i].data.copy_(l_ema_param[i].data.mul(beta_ema).add(
                l_param[i].data.mul(1-beta_ema)))
class Trainer(object):
    """Trainer for a class-conditional self-attention GAN: the generator is fed
    images concatenated with a (randomly chosen) target-class plane, and the
    discriminator is trained with hinge or softmax (cross-entropy) losses.

    Relies on module-level helpers visible elsewhere in this file:
    ``tensor2var``, ``denorm``, ``Generator``, ``UpDownConvolutionalGenerator``,
    ``Discriminator``, plus ``torch``, ``nn``, ``np``, ``F``, ``os``, ``time``,
    ``datetime``, ``Variable``, ``save_image``.
    """

    def __init__(self, data_loader, config):
        """Copy hyper-parameters from ``config``, resolve versioned output
        paths, build models, and optionally restore a pretrained checkpoint.

        :param data_loader: torch DataLoader yielding (image, label-plane) pairs
        :param config: argparse-style namespace of hyper-parameters
        """
        # Data loader
        self.data_loader = data_loader
        # exact model and loss
        self.model = config.model
        self.adv_loss = config.adv_loss
        self.conv_G = config.conv_G
        # Model hyper-parameters
        self.imsize = config.imsize
        self.g_num = config.g_num
        self.z_dim = config.z_dim
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.parallel = config.parallel
        self.lambda_gp = config.lambda_gp
        self.total_step = config.total_step
        self.d_iters = config.d_iters
        self.batch_size = config.batch_size
        self.num_workers = config.num_workers
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.lr_decay = config.lr_decay
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.pretrained_model = config.pretrained_model
        self.dataset = config.dataset
        self.use_tensorboard = config.use_tensorboard
        self.image_path = config.image_path
        self.log_path = config.log_path
        self.model_save_path = config.model_save_path
        self.sample_path = config.sample_path
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step
        self.version = config.version
        # Path — the three *_path attributes above are immediately overridden
        # with version-suffixed directories.
        self.log_path = os.path.join(config.log_path, self.version)
        self.sample_path = os.path.join(config.sample_path, self.version)
        self.model_save_path = os.path.join(config.model_save_path, self.version)
        self.cuda = torch.cuda.is_available() #and cuda
        print("Using cuda:", self.cuda)
        self.build_model()
        #self.use_tensorboard = True
        if self.use_tensorboard:
            self.build_tensorboard()
        # Start with trained model
        if self.pretrained_model:
            self.load_pretrained_model()

    def train(self):
        """Alternating D/G training loop with periodic logging, image sampling,
        and checkpointing (every ``model_save_step`` fraction of an epoch)."""
        # Data iterator
        data_iter = iter(self.data_loader)
        step_per_epoch = len(self.data_loader)
        model_save_step = int(self.model_save_step * step_per_epoch)
        # Fixed input for debugging
        fixed_z = tensor2var(torch.randn(self.batch_size, self.z_dim))
        # Start with trained model
        if self.pretrained_model:
            start = self.pretrained_model + 1
        else:
            start = 0
        # Start time
        start_time = time.time()
        for step in range(start, self.total_step):
            # ================== Train D ================== #
            self.D.train()
            self.G.train()
            try:
                items = next(data_iter)
            # NOTE(review): bare except restarts the iterator on *any* error,
            # not just StopIteration — can mask real data-loading failures.
            except:
                data_iter = iter(self.data_loader)
                items = next(data_iter)
            X, Y = items
            # Random target class in [0, 6), broadcast to the label-plane shape.
            fake_class = torch.Tensor(np.ones(Y.shape)* np.random.randint(0, 6, size=(Y.shape[0], 1, 1, 1)))
            X, Y = X.type(torch.FloatTensor), Y.type(torch.FloatTensor)
            #X, Y = Variable(X.cuda()), Variable(Y.cuda())
            X, Y = Variable(X), Variable(Y)
            if self.cuda:
                X = X.cuda()
                Y = Y.cuda()
            # Y is a constant-valued plane per sample; one scalar is the label.
            class_label = Y[:,0,0,0]
            #class_one_hot = torch.zeros(Y.shape[0], 6)
            #for i, elem in enumerate(class_label):
            #class_one_hot[i, int(elem.item())] = 1.0
            # NOTE(review): class_one_hot is actually class *indices* (LongTensor),
            # as F.cross_entropy expects — and it stays on CPU; if self.cuda is
            # True the 'softmax' branches may mix devices. Confirm on GPU runs.
            class_one_hot = class_label.type(torch.LongTensor)
            #FRITS: the real_disc_in consists of the images X and the desired class
            #desired class chosen randomly, different from real class Y
            real_disc_in = X#,torch.cat((X,Y), dim = 1)
            generator_in = torch.cat((X,fake_class), dim = 1)
            # Compute loss with real images
            # dr1, dr2, df1, df2, gf1, gf2 are attention scores
            #real_images = tensor2var(real_images)
            #Frits TODO: why feed the real_disc_in to D?
            d_out_real,dr1,dr2 = self.D(real_disc_in)
            # NOTE(review): debug prints left in the training loop.
            print(d_out_real)
            print(class_one_hot)
            if self.adv_loss == 'wgan-gp':
                d_loss_real = - torch.mean(d_out_real)
            elif self.adv_loss == 'hinge':
                d_loss_real = torch.nn.ReLU()(1.0 - d_out_real).mean()
            elif self.adv_loss == 'softmax':
                d_loss_real = F.cross_entropy(d_out_real, class_one_hot).mean()
            # apply Gumbel Softmax
            #Changed to input both image and class
            fake_images,gf1,gf2 = self.G(torch.cat((X,Y), dim = 1))
            fake_disc_in = fake_images#torch.cat((fake_images, Y), dim = 1)
            d_out_fake,df1,df2 = self.D(fake_disc_in)
            # if self.adv_loss == 'wgan-gp':
            # d_loss_fake = d_out_fake.mean()
            # NOTE(review): no 'wgan-gp' branch here — d_loss_fake is undefined
            # (NameError) if adv_loss == 'wgan-gp'; only 'hinge'/'softmax' work.
            if self.adv_loss == 'hinge':
                d_loss_fake = torch.nn.ReLU()(1.0 + d_out_fake).mean()
            elif self.adv_loss == 'softmax':
                d_loss_fake = F.cross_entropy(d_out_fake, class_one_hot).mean()
            #elif self.adv_loss == 'softmax':
            # Backward + Optimize
            d_loss = d_loss_real + d_loss_fake
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()
            # ================== Train G and gumbel ================== #
            # Create random noise
            #z = tensor2var(torch.randn(real_images.size(0), self.z_dim)) #TODO Fritz: Do we need this?
            # G translates X toward the randomly chosen target class.
            fake_images,_,_ = self.G(generator_in)
            fake_disc_in = fake_images#torch.cat((fake_images, Y), dim = 1)
            # Compute loss with fake images
            g_out_fake,_,_ = self.D(fake_disc_in)  # batch x n
            if self.adv_loss == 'wgan-gp':
                g_loss_fake = - g_out_fake.mean()
            elif self.adv_loss == 'hinge':
                g_loss_fake = - g_out_fake.mean()
            elif self.adv_loss == 'softmax':
                g_loss_fake = - F.cross_entropy(g_out_fake, class_one_hot).mean()
            self.reset_grad()
            g_loss_fake.backward()
            self.g_optimizer.step()
            # Print out log info
            if (step + 1) % self.log_step == 0:
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))
                print("Elapsed [{}], G_step [{}/{}], D_step[{}/{}], d_loss: {:.4f}, g_loss {:.4f}"
                      " ave_gamma_l3: {:.4f}, ave_gamma_l4: {:.4f}".
                      format(elapsed, step + 1, self.total_step, (step + 1),
                             self.total_step , d_loss.item(), g_loss_fake.item(),
                             self.G.attn1.gamma.mean().item(), self.G.attn2.gamma.mean().item() ))
                # format(elapsed, step + 1, self.total_step, (step + 1),
                # self.total_step , d_loss.data[0], g_loss_fake.data[0],
                # self.G.attn1.gamma.mean().data[0], self.G.attn2.gamma.mean().data[0] ))
                with open('log_info.txt', 'a') as f:
                    # f.write("Step {}, D Loss {}, G Loss {}\n".format(step + 1, d_loss.data[0], g_loss_fake.data[0]))
                    f.write("Step {}, D Loss {}, G Loss {}\n".format(step + 1, d_loss.item(), g_loss_fake.item()))
            # Sample images
            if (step + 1) % self.sample_step == 0:
                fake_images,_,_= self.G(generator_in)
                # Stack input / generated / target-label images vertically.
                result = torch.cat((X, fake_images, Y), dim = 2)
                save_image(denorm(result.data),
                           os.path.join(self.sample_path, '{}_fake.png'.format(step + 1)))
                with open(os.path.join(self.sample_path,'step_{}.txt'.format(step+1)), 'a') as f:
                    # f.write("Step {}, D Loss {}, G Loss {}\n".format(step + 1, d_loss.data[0], g_loss_fake.data[0]))
                    real_labels = Y[:, 0, 0, 0]
                    fake_labels = fake_class[:, 0, 0, 0]
                    print(fake_labels)
                    f.write("Step {}, Real Labels: {}, Target(Fake) Labels {}\n".format(step + 1, real_labels, fake_labels))
            if (step+1) % model_save_step==0:
                torch.save(self.G.state_dict(),
                           os.path.join(self.model_save_path, '{}_G.pth'.format(step + 1)))
                torch.save(self.D.state_dict(),
                           os.path.join(self.model_save_path, '{}_D.pth'.format(step + 1)))

    def build_model(self):
        """Instantiate G (conv or plain variant, per ``conv_G``) and D, move
        them to GPU when available, and create the Adam optimizers."""
        self.G = None
        if self.conv_G:
            self.G = UpDownConvolutionalGenerator(self.batch_size,self.imsize, self.z_dim, self.g_conv_dim)
            if self.cuda:
                self.G = self.G.cuda()
        else:
            self.G = Generator(self.batch_size,self.imsize, self.z_dim, self.g_conv_dim)
            if self.cuda:
                self.G = self.G.cuda()
        self.D = Discriminator(self.batch_size,self.imsize, self.d_conv_dim)
        if self.cuda:
            self.D = self.D.cuda()
        if self.parallel:
            self.G = nn.DataParallel(self.G)
            self.D = nn.DataParallel(self.D)
        # Loss and optimizer
        # self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])
        self.g_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.G.parameters()),
                                            self.g_lr, [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.D.parameters()),
                                            self.d_lr, [self.beta1, self.beta2])
        self.c_loss = torch.nn.CrossEntropyLoss()
        # print networks
        print(self.G)
        print(self.D)

    def build_tensorboard(self):
        """Create the project-local tensorboard-style logger."""
        from logger import Logger
        self.logger = Logger(self.log_path)

    def load_pretrained_model(self):
        """Load G/D weights from step ``self.pretrained_model`` checkpoints."""
        self.G.load_state_dict(torch.load(os.path.join(
            self.model_save_path, '{}_G.pth'.format(self.pretrained_model))))
        self.D.load_state_dict(torch.load(os.path.join(
            self.model_save_path, '{}_D.pth'.format(self.pretrained_model))))
        print('loaded trained models (step: {})..!'.format(self.pretrained_model))

    def reset_grad(self):
        """Zero both optimizers' gradient buffers."""
        self.d_optimizer.zero_grad()
        self.g_optimizer.zero_grad()

    def save_sample(self, data_iter):
        """Save one batch of real images as a reference grid."""
        real_images, _ = next(data_iter)
        save_image(denorm(real_images), os.path.join(self.sample_path, 'real.png'))
class Tester(object):
    """GAN harness that, despite the name, runs a training loop (``train``) and
    periodically decodes fixed-z generator outputs into point clouds via an
    external ``model_decoder`` for visual validation.

    Relies on module-level helpers visible elsewhere in this file:
    ``tensor2var``, ``denorm``, ``Generator``, ``Discriminator``, plus
    ``torch``, ``nn``, ``np``, ``os``, ``time``, ``datetime``, ``Variable``,
    ``OrderedDict``, ``save_image``.
    """

    def __init__(self, data_loader, args, train_loader, model_decoder,
                 chamfer, vis_Valida):
        """Store hyper-parameters and collaborators, build models, and
        optionally restore a pretrained checkpoint.

        :param data_loader: unused here (train_loader is used instead — see TODOs)
        :param args: argparse namespace of hyper-parameters
        :param train_loader: DataLoader yielding image batches (no labels)
        :param model_decoder: latent-to-point-cloud decoder network
        :param chamfer: chamfer-distance module (stored, not used in this view)
        :param vis_Valida: list of visualizers, indexed per validation sample
        """
        # decoder settings
        self.model_decoder = model_decoder
        self.chamfer = chamfer
        self.vis = vis_Valida
        self.j = 0
        # Data loader
        # self.data_loader = data_loader
        self.train_loader = train_loader  # TODO
        # exact model and loss
        self.model = args.model
        self.adv_loss = args.adv_loss
        # Model hyper-parameters
        self.imsize = args.imsize
        self.g_num = args.g_num
        self.z_dim = args.z_dim
        self.g_conv_dim = args.g_conv_dim
        self.d_conv_dim = args.d_conv_dim
        self.parallel = args.parallel
        self.lambda_gp = args.lambda_gp
        self.total_step = args.total_step
        self.d_iters = args.d_iters
        self.batch_size = args.batch_size
        self.num_workers = args.num_workers
        self.g_lr = args.g_lr
        self.d_lr = args.d_lr
        self.lr_decay = args.lr_decay
        self.beta1 = args.beta1
        self.beta2 = args.beta2
        self.pretrained_model = args.pretrained_model
        self.dataset = args.dataset
        self.use_tensorboard = args.use_tensorboard
        self.image_path = args.image_path
        self.log_path = args.log_path
        self.model_save_path = args.model_save_path
        self.sample_path = args.sample_path
        self.log_step = args.log_step
        self.sample_step = args.sample_step
        self.model_save_step = args.model_save_step
        self.version = args.version
        # Path — override the three paths above with version-suffixed dirs.
        self.log_path = os.path.join(args.log_path, self.version)
        self.sample_path = os.path.join(args.sample_path, self.version)
        self.model_save_path = os.path.join(args.model_save_path, self.version)
        self.build_model()
        if self.use_tensorboard:
            self.build_tensorboard()
        # Start with trained model
        if self.pretrained_model:
            self.load_pretrained_model()

    def train(self):
        """Alternating D/G training loop (hinge or WGAN-GP), with periodic
        logging, point-cloud decoding/visualization of fixed-z samples, and
        checkpointing."""
        # Data iterator
        # data_iter = iter(self.data_loader)
        train_iter = iter(self.train_loader)  # TODO
        # step_per_epoch = len(self.data_loader)
        train_step_per_epoch = len(self.train_loader)  # TODO
        # model_save_step = int(self.model_save_step * step_per_epoch)
        model_save_step = int(self.model_save_step * train_step_per_epoch)  # TODO
        # Fixed input for debugging
        fixed_z = tensor2var(torch.randn(self.batch_size, self.z_dim))
        # Start with trained model
        if self.pretrained_model:
            start = self.pretrained_model + 1
        else:
            start = 0
        # Start time
        start_time = time.time()
        for step in range(start, self.total_step):
            # ================== Train D ================== #
            self.D.train()
            self.G.train()
            try:
                # real_images, _ = next(data_iter)
                real_images = next(train_iter)  # TODO
            # NOTE(review): bare except restarts the iterator on *any* error,
            # not just StopIteration — can mask real data-loading failures.
            except:
                # data_iter = iter(self.data_loader)
                train_iter = iter(self.train_loader)  # TODO
                # real_images, _ = next(data_iter)
                real_images = next(train_iter)
            # Compute loss with real images
            # dr1, dr2, df1, df2, gf1, gf2 are attention scores
            real_images = tensor2var(real_images)
            d_out_real, dr1 = self.D(real_images)  #,dr2
            if self.adv_loss == 'wgan-gp':
                d_loss_real = -torch.mean(d_out_real)
            elif self.adv_loss == 'hinge':
                d_loss_real = torch.nn.ReLU()(1.0 - d_out_real).mean()
            # apply Gumbel Softmax
            z = tensor2var(torch.randn(real_images.size(0), self.z_dim))
            fake_images, gf1 = self.G(z)  #,gf2
            d_out_fake, df1 = self.D(fake_images)  #,df2
            if self.adv_loss == 'wgan-gp':
                d_loss_fake = d_out_fake.mean()
            elif self.adv_loss == 'hinge':
                d_loss_fake = torch.nn.ReLU()(1.0 + d_out_fake).mean()
            # Backward + Optimize
            d_loss = d_loss_real + d_loss_fake
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()
            if self.adv_loss == 'wgan-gp':
                # Compute gradient penalty on points interpolated between real
                # and fake batches, then apply it as a second D update.
                alpha = torch.rand(real_images.size(0), 1, 1, 1).cuda().expand_as(real_images)
                interpolated = Variable(alpha * real_images.data + (1 - alpha) * fake_images.data,
                                        requires_grad=True)
                out, _ = self.D(interpolated)  # TODO "_"
                grad = torch.autograd.grad(outputs=out, inputs=interpolated,
                                           grad_outputs=torch.ones(
                                               out.size()).cuda(),
                                           retain_graph=True,
                                           create_graph=True,
                                           only_inputs=True)[0]
                grad = grad.view(grad.size(0), -1)
                grad_l2norm = torch.sqrt(torch.sum(grad**2, dim=1))
                d_loss_gp = torch.mean((grad_l2norm - 1)**2)
                # Backward + Optimize
                d_loss = self.lambda_gp * d_loss_gp
                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()
            # ================== Train G and gumbel ================== #
            # Create random noise
            z = tensor2var(torch.randn(real_images.size(0), self.z_dim))
            fake_images, _ = self.G(z)  # _
            # Compute loss with fake images
            g_out_fake, _ = self.D(fake_images)  # batch x n TODO "_"
            if self.adv_loss == 'wgan-gp':
                g_loss_fake = -g_out_fake.mean()
            elif self.adv_loss == 'hinge':
                g_loss_fake = -g_out_fake.mean()
            self.reset_grad()
            g_loss_fake.backward()
            self.g_optimizer.step()
            # Print out log info
            if (step + 1) % self.log_step == 0:
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))
                # NOTE(review): tensor.data[0] is the deprecated pre-0.4
                # indexing idiom — .item() elsewhere in this file is the
                # modern equivalent.
                print(
                    "Elapsed [{}], G_step [{}/{}], D_step[{}/{}], d_out_real: {:.4f}, "
                    " ave_gamma_l3: {:.4f}, ave_gamma_l4: {:.4f}".format(
                        elapsed, step + 1, self.total_step, (step + 1),
                        self.total_step, d_loss_real.data[0],
                        self.G.attn1.gamma.mean().data[0],
                        self.G.attn2.gamma.mean().data[0]))
            # Sample images
            if (step + 1) % self.sample_step == 0:
                fake_images, _ = self.G(fixed_z)  #TODO "_"
                # NOTE(review): hard-coded (64, 128) assumes batch_size == 64
                # and a 128-dim latent code — confirm against args.
                encoded = fake_images.contiguous().view(64, 128)
                pc_1 = self.model_decoder(encoded)
                #pc_1_temp = pc_1[0, :, :]
                epoch = 0
                # NOTE(review): loop variable is self.j (mutates instance
                # state); np.asscalar is deprecated in favor of .item().
                for self.j in range(0, 10):
                    pc_1_temp = pc_1[self.j, :, :]
                    test = fixed_z.detach().cpu().numpy()
                    test1 = np.asscalar(test[self.j, 0])
                    visuals = OrderedDict([('Validation Predicted_pc',
                                            pc_1_temp.detach().cpu().numpy())])
                    self.vis[self.j].display_current_results(visuals, epoch, step, z=test1)
                save_image(
                    denorm(fake_images.data),
                    os.path.join(self.sample_path, '{}_fake.png'.format(step + 1)))
            if (step + 1) % model_save_step == 0:
                torch.save(
                    self.G.state_dict(),
                    os.path.join(self.model_save_path, '{}_G.pth'.format(step + 1)))
                torch.save(
                    self.D.state_dict(),
                    os.path.join(self.model_save_path, '{}_D.pth'.format(step + 1)))

    def build_model(self):
        """Instantiate G and D on GPU, wrap in DataParallel if requested, and
        create the Adam optimizers."""
        self.G = Generator(self.batch_size, self.imsize, self.z_dim,
                           self.g_conv_dim).cuda()
        self.D = Discriminator(self.batch_size, self.imsize,
                               self.d_conv_dim).cuda()
        if self.parallel:
            self.G = nn.DataParallel(self.G)
            self.D = nn.DataParallel(self.D)
        # Loss and optimizer
        # self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])
        self.g_optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.G.parameters()), self.g_lr,
            [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.D.parameters()), self.d_lr,
            [self.beta1, self.beta2])
        self.c_loss = torch.nn.CrossEntropyLoss()
        # print networks
        print(self.G)
        print(self.D)

    def build_tensorboard(self):
        """Create the project-local tensorboard-style logger."""
        from logger import Logger
        self.logger = Logger(self.log_path)

    def load_pretrained_model(self):
        """Load G/D weights from step ``self.pretrained_model`` checkpoints."""
        self.G.load_state_dict(
            torch.load(
                os.path.join(self.model_save_path,
                             '{}_G.pth'.format(self.pretrained_model))))
        self.D.load_state_dict(
            torch.load(
                os.path.join(self.model_save_path,
                             '{}_D.pth'.format(self.pretrained_model))))
        print('loaded trained models (step: {})..!'.format(
            self.pretrained_model))

    def reset_grad(self):
        """Zero both optimizers' gradient buffers."""
        self.d_optimizer.zero_grad()
        self.g_optimizer.zero_grad()

    def save_sample(self, data_iter):
        """Save one batch of real images as a reference grid."""
        real_images, _ = next(data_iter)
        save_image(denorm(real_images),
                   os.path.join(self.sample_path, 'real.png'))
class Solver(object):
    """Solver for training and testing StarGAN (single-dataset variant).

    The generator returns (image, gfd1, gfd2, gfu1, gfu2) and the
    discriminator returns (src, cls, a1, a2) — the extra values are
    attention scores from the self-attention layers.

    Fixes vs. the original:
      * the data-iterator restart now catches ``StopIteration`` instead of a
        bare ``except``, so genuine loader errors are no longer swallowed;
      * ``classification_loss`` uses ``reduction='sum'`` instead of the
        deprecated ``size_average=False`` (identical semantics).
    """

    def __init__(self, celeba_loader, rafd_loader, config):
        """Initialize configurations."""

        # Data loader.
        self.celeba_loader = celeba_loader
        self.rafd_loader = rafd_loader

        # Model configurations.
        self.c_dim = config.c_dim
        self.c2_dim = config.c2_dim
        self.image_size = config.image_size
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.g_repeat_num = config.g_repeat_num
        self.d_repeat_num = config.d_repeat_num
        self.lambda_cls = config.lambda_cls
        self.lambda_rec = config.lambda_rec
        self.lambda_gp = config.lambda_gp

        # Training configurations.
        self.dataset = config.dataset
        self.batch_size = config.batch_size
        self.num_iters = config.num_iters
        self.num_iters_decay = config.num_iters_decay
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.n_critic = config.n_critic
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.resume_iters = config.resume_iters
        self.selected_attrs = config.selected_attrs

        # Test configurations.
        self.test_iters = config.test_iters

        # Miscellaneous.
        self.use_tensorboard = config.use_tensorboard
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')

        # Directories.
        self.log_dir = config.log_dir
        self.sample_dir = config.sample_dir
        self.model_save_dir = config.model_save_dir
        self.result_dir = config.result_dir

        # Step size.
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step
        self.lr_update_step = config.lr_update_step

        # Build the model and tensorboard.
        self.build_model()
        if self.use_tensorboard:
            self.build_tensorboard()

    def build_model(self):
        """Create a generator and a discriminator."""
        if self.dataset in ['CelebA', 'RaFD']:
            self.G = Generator(self.batch_size, self.image_size, self.c_dim,
                               self.g_conv_dim).cuda()
            self.D = Discriminator(self.batch_size, self.image_size,
                                   self.c_dim, self.d_conv_dim).cuda()
            # TODO add config: self.parallel (see line 195 in sagan/trainer.py)

        self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr,
                                            [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr,
                                            [self.beta1, self.beta2])
        self.print_network(self.G, 'G')
        self.print_network(self.D, 'D')

        self.G.to(self.device)
        self.D.to(self.device)

    def print_network(self, model, name):
        """Print out the network information."""
        num_params = 0
        for p in model.parameters():
            num_params += p.numel()
        print(model)
        print(name)
        print("The number of parameters: {}".format(num_params))

    def restore_model(self, resume_iters):
        """Restore the trained generator and discriminator."""
        print(
            'Loading the trained models from step {}...'.format(resume_iters))
        G_path = os.path.join(self.model_save_dir,
                              '{}-G.ckpt'.format(resume_iters))
        D_path = os.path.join(self.model_save_dir,
                              '{}-D.ckpt'.format(resume_iters))
        # map_location keeps CPU-only restores working for GPU checkpoints.
        self.G.load_state_dict(
            torch.load(G_path, map_location=lambda storage, loc: storage))
        self.D.load_state_dict(
            torch.load(D_path, map_location=lambda storage, loc: storage))

    def build_tensorboard(self):
        """Build a tensorboard logger."""
        from logger import Logger
        self.logger = Logger(self.log_dir)

    def update_lr(self, g_lr, d_lr):
        """Decay learning rates of the generator and discriminator."""
        for param_group in self.g_optimizer.param_groups:
            param_group['lr'] = g_lr
        for param_group in self.d_optimizer.param_groups:
            param_group['lr'] = d_lr

    def reset_grad(self):
        """Reset the gradient buffers."""
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()

    def denorm(self, x):
        """Convert the range from [-1, 1] to [0, 1]."""
        out = (x + 1) / 2
        return out.clamp_(0, 1)

    def gradient_penalty(self, y, x):
        """Compute gradient penalty: (L2_norm(dy/dx) - 1)**2."""
        weight = torch.ones(y.size()).to(self.device)
        dydx = torch.autograd.grad(outputs=y,
                                   inputs=x,
                                   grad_outputs=weight,
                                   retain_graph=True,
                                   create_graph=True,
                                   only_inputs=True)[0]

        dydx = dydx.view(dydx.size(0), -1)
        dydx_l2norm = torch.sqrt(torch.sum(dydx**2, dim=1))
        return torch.mean((dydx_l2norm - 1)**2)

    def label2onehot(self, labels, dim):
        """Convert label indices to one-hot vectors."""
        batch_size = labels.size(0)
        out = torch.zeros(batch_size, dim)
        out[np.arange(batch_size), labels.long()] = 1
        return out

    def create_labels(self,
                      c_org,
                      c_dim=5,
                      dataset='CelebA',
                      selected_attrs=None):
        """Generate target domain labels for debugging and testing.

        Returns a list of c_dim label tensors (on self.device), one per
        target domain. For CelebA, hair colors are mutually exclusive and
        every other attribute is simply flipped.
        """
        # Get hair color indices.
        if dataset == 'CelebA':
            hair_color_indices = []
            for i, attr_name in enumerate(selected_attrs):
                if attr_name in [
                        'Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair'
                ]:
                    hair_color_indices.append(i)

        c_trg_list = []
        for i in range(c_dim):
            if dataset == 'CelebA':
                c_trg = c_org.clone()
                if i in hair_color_indices:
                    # Set one hair color to 1 and the rest to 0.
                    c_trg[:, i] = 1
                    for j in hair_color_indices:
                        if j != i:
                            c_trg[:, j] = 0
                else:
                    c_trg[:, i] = (c_trg[:, i] == 0)  # Reverse attribute value.
            elif dataset == 'RaFD':
                c_trg = self.label2onehot(
                    torch.ones(c_org.size(0)) * i, c_dim)

            c_trg_list.append(c_trg.to(self.device))
        return c_trg_list

    def classification_loss(self, logit, target, dataset='CelebA'):
        """Compute binary or softmax cross entropy loss.

        CelebA uses multi-label BCE (summed, then averaged per sample);
        RaFD uses single-label softmax cross entropy.
        """
        if dataset == 'CelebA':
            # Fix: `size_average=False` is deprecated; `reduction='sum'` is
            # the exact modern equivalent.
            return F.binary_cross_entropy_with_logits(
                logit, target, reduction='sum') / logit.size(0)
        elif dataset == 'RaFD':
            return F.cross_entropy(logit, target)

    def train(self):
        """Train StarGAN within a single dataset."""
        # Set data loader.
        if self.dataset == 'CelebA':
            data_loader = self.celeba_loader
        elif self.dataset == 'RaFD':
            data_loader = self.rafd_loader

        # Fetch fixed inputs for debugging.
        data_iter = iter(data_loader)
        x_fixed, c_org = next(data_iter)
        x_fixed = x_fixed.to(self.device)
        c_fixed_list = self.create_labels(c_org, self.c_dim, self.dataset,
                                          self.selected_attrs)

        # Learning rate cache for decaying.
        g_lr = self.g_lr
        d_lr = self.d_lr

        # Start training from scratch or resume training.
        start_iters = 0
        if self.resume_iters:
            start_iters = self.resume_iters
            self.restore_model(self.resume_iters)

        # Start training.
        print('Start training...')
        start_time = time.time()
        for i in range(start_iters, self.num_iters):

            # =============== 1. Preprocess input data =============== #

            # Fetch real images and labels; restart the iterator at the end
            # of an epoch. Fix: catch StopIteration specifically — the
            # original bare `except:` also hid unrelated loader errors.
            try:
                x_real, label_org = next(data_iter)
            except StopIteration:
                data_iter = iter(data_loader)
                x_real, label_org = next(data_iter)

            # Generate target domain labels randomly.
            rand_idx = torch.randperm(label_org.size(0))
            label_trg = label_org[rand_idx]

            if self.dataset == 'CelebA':
                c_org = label_org.clone()
                c_trg = label_trg.clone()
            elif self.dataset == 'RaFD':
                c_org = self.label2onehot(label_org, self.c_dim)
                c_trg = self.label2onehot(label_trg, self.c_dim)

            x_real = x_real.to(self.device)  # Input images.
            c_org = c_org.to(self.device)  # Original domain labels.
            c_trg = c_trg.to(self.device)  # Target domain labels.
            label_org = label_org.to(
                self.device)  # Labels for computing classification loss.
            label_trg = label_trg.to(
                self.device)  # Labels for computing classification loss.

            # =============== 2. Train the discriminator =============== #
            # TODO: hinge loss (see line 107 & 117 in sagan/trainer.py)

            # Compute loss with real images.
            # dr1, dr2, df1, df2, gfd1, gfd2, gfu1, gfu2 are attention scores
            out_src, out_cls, dr1, dr2 = self.D(x_real)
            d_loss_real = -torch.mean(out_src)  # TODO: flip labels
            d_loss_cls = self.classification_loss(out_cls, label_org,
                                                  self.dataset)

            # Compute loss with fake images (detached: D update must not
            # backprop into G).
            x_fake, gfd1, gfd2, gfu1, gfu2 = self.G(x_real, c_trg)
            out_src, out_cls, df1, df2 = self.D(x_fake.detach())
            d_loss_fake = torch.mean(out_src)

            # Compute loss for gradient penalty on random interpolates.
            alpha = torch.rand(x_real.size(0), 1, 1, 1).to(self.device)
            x_hat = (alpha * x_real.data +
                     (1 - alpha) * x_fake.data).requires_grad_(True)
            out_src, _, _, _ = self.D(x_hat)
            d_loss_gp = self.gradient_penalty(out_src, x_hat)

            # Backward and optimize.
            d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls + self.lambda_gp * d_loss_gp
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()

            # Logging.
            loss = {}
            loss['D/loss_real'] = d_loss_real.item()
            loss['D/loss_fake'] = d_loss_fake.item()
            loss['D/loss_cls'] = d_loss_cls.item()
            loss['D/loss_gp'] = d_loss_gp.item()

            # =============== 3. Train the generator =============== #
            #if (i+1) % self.n_critic == 0: ## SA-GAN: Every time

            # Original-to-target domain.
            x_fake, _, _, _, _ = self.G(x_real, c_trg)
            out_src, out_cls, _, _ = self.D(x_fake)
            g_loss_fake = -torch.mean(out_src)
            g_loss_cls = self.classification_loss(out_cls, label_trg,
                                                  self.dataset)

            # Target-to-original domain (cycle-consistency reconstruction).
            x_reconst, _, _, _, _ = self.G(x_fake, c_org)
            g_loss_rec = torch.mean(torch.abs(x_real - x_reconst))

            # Backward and optimize.
            g_loss = g_loss_fake + self.lambda_rec * g_loss_rec + self.lambda_cls * g_loss_cls
            self.reset_grad()
            g_loss.backward()
            self.g_optimizer.step()

            # Logging.
            loss['G/loss_fake'] = g_loss_fake.item()
            loss['G/loss_rec'] = g_loss_rec.item()
            loss['G/loss_cls'] = g_loss_cls.item()

            # =============== 4. Miscellaneous =============== #

            # Print out training information.
            if (i + 1) % self.log_step == 0:
                et = time.time() - start_time
                et = str(datetime.timedelta(seconds=et))[:-7]
                log = "Elapsed [{}], Iteration [{}/{}]".format(
                    et, i + 1, self.num_iters)
                for tag, value in loss.items():
                    log += ", {}: {:.4f}".format(tag, value)
                print(log)

                if self.use_tensorboard:
                    for tag, value in loss.items():
                        self.logger.scalar_summary(tag, value, i + 1)

            # Translate fixed images for debugging.
            if (i + 1) % self.sample_step == 0:
                with torch.no_grad():
                    x_fake_list = [x_fixed]
                    for c_fixed in c_fixed_list:
                        x_fake, _, _, _, _ = self.G(x_fixed, c_fixed)
                        x_fake_list.append(x_fake)
                    x_concat = torch.cat(x_fake_list, dim=3)
                    sample_path = os.path.join(self.sample_dir,
                                               '{}-images.jpg'.format(i + 1))
                    save_image(self.denorm(x_concat.data.cpu()),
                               sample_path,
                               nrow=1,
                               padding=0)
                    print('Saved real and fake images into {}...'.format(
                        sample_path))

            # Save model checkpoints.
            if (i + 1) % self.model_save_step == 0:
                G_path = os.path.join(self.model_save_dir,
                                      '{}-G.ckpt'.format(i + 1))
                D_path = os.path.join(self.model_save_dir,
                                      '{}-D.ckpt'.format(i + 1))
                torch.save(self.G.state_dict(), G_path)
                torch.save(self.D.state_dict(), D_path)
                print('Saved model checkpoints into {}...'.format(
                    self.model_save_dir))

            # Decay learning rates (linear decay over the last
            # num_iters_decay iterations).
            if (i + 1) % self.lr_update_step == 0 and (i + 1) > (
                    self.num_iters - self.num_iters_decay):
                g_lr -= (self.g_lr / float(self.num_iters_decay))
                d_lr -= (self.d_lr / float(self.num_iters_decay))
                self.update_lr(g_lr, d_lr)
                print('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(
                    g_lr, d_lr))

    def test(self):
        """Translate images using StarGAN trained on a single dataset."""
        # Load the trained generator.
        self.restore_model(self.test_iters)

        # Set data loader.
        if self.dataset == 'CelebA':
            data_loader = self.celeba_loader
        elif self.dataset == 'RaFD':
            data_loader = self.rafd_loader

        with torch.no_grad():
            for i, (x_real, c_org) in enumerate(data_loader):

                # Prepare input images and target domain labels.
                x_real = x_real.to(self.device)
                c_trg_list = self.create_labels(c_org, self.c_dim,
                                                self.dataset,
                                                self.selected_attrs)

                # Translate images.
                x_fake_list = [x_real]
                for c_trg in c_trg_list:
                    x_fake, _, _, _, _ = self.G(x_real, c_trg)
                    x_fake_list.append(x_fake)

                # Save the translated images.
                x_concat = torch.cat(x_fake_list, dim=3)
                result_path = os.path.join(self.result_dir,
                                           '{}-images.jpg'.format(i + 1))
                save_image(self.denorm(x_concat.data.cpu()),
                           result_path,
                           nrow=1,
                           padding=0)
                print('Saved real and fake images into {}...'.format(
                    result_path))
class Tester(object):
    """Loads a pretrained SAGAN generator and writes grids of generated
    samples to disk (no training)."""

    def __init__(self, data_loader, config):
        self.data_loader = data_loader

        # Model hyper-parameters
        self.imsize = config.imsize
        self.g_num = config.g_num
        self.z_dim = config.z_dim
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.parallel = config.parallel

        # Training/optimizer settings (kept for model construction parity
        # with the Trainer, even though no training happens here).
        self.lambda_gp = config.lambda_gp
        self.total_step = config.total_step
        self.d_iters = config.d_iters
        self.batch_size = config.batch_size
        self.num_workers = config.num_workers
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.lr_decay = config.lr_decay
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.pretrained_model = config.pretrained_model

        self.dataset = config.dataset
        self.use_tensorboard = config.use_tensorboard
        self.image_path = config.image_path
        self.log_path = config.log_path
        self.model_save_path = config.model_save_path
        self.sample_path = config.sample_path
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step
        self.version = config.version

        # NOTE(review): model_save_path and test_path are each assigned
        # twice — the bare config value is immediately replaced by the
        # version-suffixed join; the first assignments are dead.
        self.model_save_path = os.path.join(config.model_save_path,
                                            self.version)
        self.test_path = config.test_path
        self.test_path = os.path.join(config.test_path, self.version)

        self.build_model()
        self.load_pretrained_model()

    def build_model(self):
        """Instantiate G and D on CUDA and create their optimizers.

        NOTE(review): the optimizers are unused in this Tester; they are
        only built to mirror the Trainer's construction path.
        """
        self.G = Generator(self.batch_size, self.imsize, self.z_dim,
                           self.g_conv_dim).cuda()
        self.D = Discriminator(self.batch_size, self.imsize,
                               self.d_conv_dim).cuda()
        if self.parallel:
            self.G = nn.DataParallel(self.G)
            self.D = nn.DataParallel(self.D)

        # Loss and optimizer
        # self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])
        self.g_optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.G.parameters()), self.g_lr,
            [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.D.parameters()), self.d_lr,
            [self.beta1, self.beta2])

    def load_pretrained_model(self):
        """Load G/D weights saved at step `self.pretrained_model`."""
        self.G.load_state_dict(
            torch.load(
                os.path.join(self.model_save_path,
                             '{}_G.pth'.format(self.pretrained_model))))
        self.D.load_state_dict(
            torch.load(
                os.path.join(self.model_save_path,
                             '{}_D.pth'.format(self.pretrained_model))))
        print('loaded trained models (step: {})..!'.format(
            self.pretrained_model))

    def test(self):
        """Generate 500 batches of 9 samples each and save each batch as a
        3x3 grid image under self.test_path."""
        num_of_images = 9
        for i in range(500):
            z = tensor2var(torch.randn(num_of_images, self.z_dim))
            fake_images, _, _ = self.G(z)
            #print(fake_images.data)
            #(9,3,w,h) -> (9,w,h,3)
            #print(fake_images.data.shape)
            #print(fake_images.data[0])
            #save_image(denorm(fake_images.data),
            #           os.path.join(self.test_path, '{}_fake.png'.format(i+1)),
            #           nrow=3)
            #file_name = os.path.join(self.test_path, '{}_fake_class_format.png'.format(i+1))
            #self.output_fig(denorm(fake_images.data), file_name)
            # Convert NCHW -> NHWC for matplotlib-based grid plotting.
            transpose_image = np.transpose(var2numpy(fake_images.data),
                                           (0, 2, 3, 1))
            self.output_fig(
                transpose_image,
                os.path.join(self.test_path,
                             '{}_fake_class_format.png'.format(i + 1)))

    def output_fig(self, images_array, file_name):
        """Render a square grid of 9 images and save it to `file_name`."""
        # the shape of your images_array should be (9, width, height, 3), 28 <= width, height <= 112
        plt.figure(figsize=(6, 6), dpi=100)
        plt.imshow(helper.images_square_grid(images_array))
        plt.axis("off")
        plt.savefig(file_name, bbox_inches='tight', pad_inches=0)
class Trainer(object):
    """SAGAN trainer (hinge / wgan-gp) that periodically evaluates the
    generator with FID against a reference image folder."""

    def __init__(self, data_loader, config):

        # Data loader
        self.data_loader = data_loader

        # exact model and loss
        self.model = config.model
        self.adv_loss = config.adv_loss

        # Model hyper-parameters
        self.imsize = config.imsize
        self.g_num = config.g_num
        self.z_dim = config.z_dim
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.parallel = config.parallel

        self.lambda_gp = config.lambda_gp
        self.total_step = config.total_step
        self.d_iters = config.d_iters
        self.batch_size = config.batch_size
        self.num_workers = config.num_workers
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.lr_decay = config.lr_decay
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.pretrained_model = config.pretrained_model

        self.dataset = config.dataset
        self.use_tensorboard = config.use_tensorboard
        self.image_path = config.image_path
        self.log_path = config.log_path
        self.model_save_path = config.model_save_path
        self.sample_path = config.sample_path
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step
        self.version = config.version
        # Non-empty string means "use GPU" in the FID calls below.
        self.gpu = 'gpu'

        # Path
        self.log_path = os.path.join(config.log_path, self.version)
        self.sample_path = os.path.join(config.sample_path, self.version)
        self.model_save_path = os.path.join(config.model_save_path,
                                            self.version)
        # path1/path2: generated-images dir and reference dir for FID.
        self.path1 = config.path1
        self.path2 = config.path2
        self.dims = config.dims

        self.build_model()

        if self.use_tensorboard:
            self.build_tensorboard()

        # Start with trained model
        if self.pretrained_model:
            self.load_pretrained_model()

    def mytest(self, step):
        """Generate 200 samples, resize them to 178x218, and compute an
        averaged FID over 5 random 40k-image subsets of self.path1.

        NOTE(review): paths are hard-coded to one machine, `sum` shadows the
        builtin, and the FID.txt handle is opened without a context manager —
        worth cleaning up if this method is kept.
        """
        num = 200
        z = tensor2var(torch.randn(num, self.z_dim))  # 32*128
        fake_images, gf1, gf2 = self.G(z)  # 1*3*64*64
        fake_images = fake_images.data
        # inception_score(fake_images)
        # fake_images = fake_images.resize_((1, 3, 218, 178))
        for n in range(num):
            save_image(fake_images[n],
                       '/home/xinzi/dataset_k40/test_celeba/%d.jpg' % n)
            str0 = '/home/xinzi/dataset_k40/test_celeba/' + str(n) + '.jpg'
            # Re-open and resize to the CelebA aligned resolution.
            im = Image.open(str0)
            im = im.resize((178, 218))
            im.save(str0)
        '''
    path = '/home/jingjie/xinzi/dataset/test'
    a = os.listdir(path)
    #a.sort()
    for file in a:
        file_path = os.path.join(path, file)
        if os.path.splitext(file_path)[1] == '.png':
            im = Image.open(file_path)
    '''
        #args = parser.parse_args()
        #os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
        # fid_value = calculate_fid_given_paths(self.path1, self.path2, 1, self.gpu != '', self.dims)
        # print('FID: ', fid_value)
        # a = 0
        str1 = '/home/xinzi/dataset_k40/celebAtemp'
        times = 5
        sum = 0
        for i in range(times):
            # Fresh random subset of path1 for each FID repetition.
            shutil.rmtree(str1)
            os.mkdir(str1)
            random_copyfile(self.path1, str1, 40000)
            # args.path0 = str1
            """ dims = 64 """
            fid_value = calculate_fid_given_paths(str1, self.path2, 100,
                                                  self.gpu != '', self.dims)
            sum = sum + fid_value
            print('FID: ', fid_value)
        print(float(sum / times))
        f = open('FID.txt', 'a')
        f.write('\n')
        f.write(str(float(sum / times)))
        f.close()
        #d_out_fake, df1, df2 = self.D(fake_images)

    def train(self):
        """Alternating D/G training loop with periodic logging, sampling,
        checkpointing, and FID evaluation (every 400 steps after 3000)."""

        # Data iterator
        data_iter = iter(self.data_loader)
        step_per_epoch = len(self.data_loader)
        # model_save_step is interpreted as a fraction/multiple of an epoch.
        model_save_step = int(self.model_save_step * step_per_epoch)  # ?????

        # Fixed input for debugging
        fixed_z = tensor2var(torch.randn(self.batch_size, self.z_dim))

        # Start with trained model
        if self.pretrained_model:
            start = self.pretrained_model + 1
        else:
            start = 1

        # Start time
        start_time = time.time()
        for step in range(start, self.total_step):
            # print(step)
            # ================== Train D ================== #
            self.D.train()
            self.G.train()

            # try/except: when the statements inside `try` raise (iterator
            # exhausted at the end of an epoch), the `except` branch runs
            # and restarts the data iterator.
            try:
                real_images, _ = next(data_iter)
            except:
                data_iter = iter(self.data_loader)
                real_images, _ = next(data_iter)

            # Compute loss with real images
            # dr1, dr2, df1, df2, gf1, gf2 are attention scores
            # Move real_images onto CUDA and wrap them in a Variable.
            real_images = tensor2var(real_images)
            d_out_real, dr1, dr2 = self.D(real_images)
            if self.adv_loss == 'wgan-gp':
                d_loss_real = -torch.mean(d_out_real)  # mean() averages over the batch
            elif self.adv_loss == 'hinge':
                d_loss_real = torch.nn.ReLU()(1.0 - d_out_real).mean()

            # apply Gumbel Softmax
            self.temp = real_images.size(0)
            z = tensor2var(torch.randn(real_images.size(0), self.z_dim))
            fake_images, gf1, gf2 = self.G(z)
            # NOTE(review): why is this not fake_images.detach()? As written,
            # the D loss also backpropagates into G during the D update.
            d_out_fake, df1, df2 = self.D(fake_images)

            if self.adv_loss == 'wgan-gp':
                d_loss_fake = d_out_fake.mean()
            elif self.adv_loss == 'hinge':
                d_loss_fake = torch.nn.ReLU()(1.0 + d_out_fake).mean()

            # Backward + Optimize
            d_loss = d_loss_real + d_loss_fake
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()

            if self.adv_loss == 'wgan-gp':
                # Compute gradient penalty on random interpolates.
                alpha = torch.rand(real_images.size(0), 1, 1,
                                   1).cuda().expand_as(real_images)
                interpolated = Variable(alpha * real_images.data +
                                        (1 - alpha) * fake_images.data,
                                        requires_grad=True)
                out, _, _ = self.D(interpolated)

                grad = torch.autograd.grad(outputs=out,
                                           inputs=interpolated,
                                           grad_outputs=torch.ones(
                                               out.size()).cuda(),
                                           retain_graph=True,
                                           create_graph=True,
                                           only_inputs=True)[0]

                grad = grad.view(grad.size(0), -1)
                grad_l2norm = torch.sqrt(torch.sum(grad**2, dim=1))
                d_loss_gp = torch.mean((grad_l2norm - 1)**2)

                # Backward + Optimize (separate step for the penalty term).
                d_loss = self.lambda_gp * d_loss_gp
                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()

            # ================== Train G and gumbel ================== #
            # Create random noise
            z = tensor2var(torch.randn(real_images.size(0), self.z_dim))
            fake_images, _, _ = self.G(z)

            # Compute loss with fake images
            g_out_fake, _, _ = self.D(fake_images)  # batch x n
            if self.adv_loss == 'wgan-gp':
                g_loss_fake = -g_out_fake.mean()
            elif self.adv_loss == 'hinge':
                # Same as wgan-gp: the hinge G loss is also -E[D(G(z))].
                g_loss_fake = -g_out_fake.mean()

            self.reset_grad()
            g_loss_fake.backward()
            self.g_optimizer.step()

            # Print out log info
            if (step + 1) % self.log_step == 0:
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))
                # NOTE(review): `.data[0]` is pre-0.4 PyTorch; `.item()` is
                # the modern equivalent on current versions.
                print(
                    "Elapsed [{}], G_step [{}/{}], D_step[{}/{}], d_out_real: {:.4f}, "
                    " ave_gamma_l3: , ave_gamma_l4: ".format(
                        elapsed, step + 1, self.total_step, (step + 1),
                        self.total_step, d_loss_real.data[0]))

            # Sample images
            if (step + 1) % self.sample_step == 0:
                fake_images, _, _ = self.G(fixed_z)
                save_image(
                    denorm(fake_images.data),
                    os.path.join(self.sample_path,
                                 '{}_fake.png'.format(step + 1)))

            if (step + 1) % model_save_step == 0:
                torch.save(
                    self.G.state_dict(),
                    os.path.join(self.model_save_path,
                                 '{}_G.pth'.format(step + 1)))
                torch.save(
                    self.D.state_dict(),
                    os.path.join(self.model_save_path,
                                 '{}_D.pth'.format(step + 1)))

            # Periodic FID evaluation once training has warmed up.
            if step >= 3000:
                if step % 400 == 0:
                    print('====================testing====================')
                    self.mytest(step)

    def build_model(self):
        """Instantiate G and D on CUDA, optionally wrap in DataParallel,
        and create Adam optimizers over trainable parameters."""
        self.G = Generator(self.batch_size, self.imsize, self.z_dim,
                           self.g_conv_dim).cuda()
        self.D = Discriminator(self.batch_size, self.imsize,
                               self.d_conv_dim).cuda()
        if self.parallel:
            self.G = nn.DataParallel(self.G)
            self.D = nn.DataParallel(self.D)

        # Loss and optimizer
        # self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])
        self.g_optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.G.parameters()), self.g_lr,
            [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.D.parameters()), self.d_lr,
            [self.beta1, self.beta2])

        self.c_loss = torch.nn.CrossEntropyLoss()
        # print networks
        print(self.G)
        print(self.D)

    def build_tensorboard(self):
        """Create the tensorboard-style logger writing to self.log_path."""
        from logger import Logger
        self.logger = Logger(self.log_path)

    def load_pretrained_model(self):
        """Load G/D weights saved at step `self.pretrained_model`."""
        self.G.load_state_dict(
            torch.load(
                os.path.join(self.model_save_path,
                             '{}_G.pth'.format(self.pretrained_model))))
        self.D.load_state_dict(
            torch.load(
                os.path.join(self.model_save_path,
                             '{}_D.pth'.format(self.pretrained_model))))
        print('loaded trained models (step: {})..!'.format(
            self.pretrained_model))

    def reset_grad(self):
        """Zero gradients of both optimizers before a backward pass."""
        self.d_optimizer.zero_grad()
        self.g_optimizer.zero_grad()

    def save_sample(self, data_iter):
        """Save one batch of real images (de-normalized) for reference."""
        real_images, _ = next(data_iter)
        save_image(denorm(real_images),
                   os.path.join(self.sample_path, 'real.png'))
class Trainer(object):
    """VAE-GAN trainer over paired camera views (DoubleViewPairDataset):
    combines an adversarial loss, a cross-view VAE reconstruction loss, and
    a triplet loss on the encoder means, conditioned on one-hot camera
    pitch/yaw/distance and task-progress labels."""

    def __init__(self, loader, config, data_loader_val=None):
        # Data loader (train loader and validation loader come in as a pair).
        data_loader, data_loader_val = loader
        self.data_loader = data_loader
        self.data_loader_val = data_loader_val

        # exact model and loss
        self.model = config.model
        self.adv_loss = config.adv_loss

        # Model hyper-parameters
        self.imsize = config.imsize
        self.g_num = config.g_num
        self.z_dim = config.z_dim
        # Conditioning width: (pitch 20 + yaw 40 + distance 10 bins) for each
        # of 2 views, plus 5 task-progress bins (checked by an assert below).
        self.cam_view_z = (20 + 40 + 10) * 2 + 5
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.parallel = config.parallel

        self.lambda_gp = config.lambda_gp
        self.total_step = config.total_step
        self.d_iters = config.d_iters
        self.batch_size = config.batch_size
        self.num_workers = config.num_workers
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.lr_decay = config.lr_decay
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.pretrained_model = config.pretrained_model

        self.dataset = config.dataset
        self.use_tensorboard = config.use_tensorboard
        self.image_path = config.image_path
        self.log_path = config.log_path
        self.model_save_path = config.model_save_path
        self.sample_path = config.sample_path
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step
        self.version = config.version

        # Path
        self.log_path = os.path.join(config.log_path, self.version)
        self.sample_path = os.path.join(config.sample_path, self.version)
        self.vae_rec_path = os.path.join(config.sample_path, "vae_rec")
        os.makedirs(self.vae_rec_path, exist_ok=True)  # TODO
        print('vae_rec_path: {}'.format(self.vae_rec_path))
        self.model_save_path = os.path.join(config.model_save_path,
                                            self.version)
        # Scale factor applied to VAE/triplet losses before the G update.
        self.num_pixels = self.imsize * 2 * 3

        self.build_model()

        if self.use_tensorboard:
            self.build_tensorboard()

        # Start with trained model
        if self.pretrained_model:
            self.load_pretrained_model()

    def train(self):
        """Main training loop: D update, optional gradient penalty, VAE
        cross-view reconstruction + triplet loss, then the G update."""

        def cycle(iterable):
            # Endless pass over the loader.
            # WARNING: like itertools.cycle, this does not re-shuffle the
            # data after each iteration over the loader.
            while True:
                for x in iterable:
                    yield x

        # Data iterator
        data_iter = iter(cycle(self.data_loader))
        self.loader_val_iter = iter(cycle(self.data_loader_val))
        step_per_epoch = len(self.data_loader)
        model_save_step = int(self.model_save_step * step_per_epoch)

        # Fixed input for debugging (captured from the first batch below).
        fixed_z = None

        # Start with trained model
        if self.pretrained_model:
            start = self.pretrained_model + 1
        else:
            start = 0

        # Start time
        start_time = time.time()
        num_views = 2
        key_views = ["frames views {}".format(i) for i in range(num_views)]
        lable_keys_cam_view_info = []  # list with keys for view 0 and view 1
        for view_i in range(num_views):
            lable_keys_cam_view_info.append([
                "cam_pitch_view_{}".format(view_i),
                "cam_yaw_view_{}".format(view_i),
                "cam_distance_view_{}".format(view_i)
            ])
        mapping_cam_info_lable = OrderedDict()
        mapping_cam_info_one_hot = OrderedDict()
        # create a different mapping for each setting
        n_classes = []
        for cam_info_view in lable_keys_cam_view_info:
            for cam_inf in cam_info_view:
                # Bin ranges per camera parameter; values outside are
                # presumably not expected from the dataset — TODO confirm.
                if "pitch" in cam_inf:
                    min_val, max_val = -50, -35.
                    n_bins = 20
                elif "yaw" in cam_inf:
                    min_val, max_val = -60., 210.
                    n_bins = 40
                elif "distance" in cam_inf:
                    min_val, max_val = 0.7, 1.
                    n_bins = 10
                to_l, to_hot_l = create_lable_func(min_val, max_val, n_bins)
                mapping_cam_info_lable[cam_inf] = to_l
                mapping_cam_info_one_hot[cam_inf] = to_hot_l
                if "view_0" in cam_inf:
                    n_classes.append(n_bins)
        print('cam view one hot infputs {}'.format(n_classes))
        task_progess_bins = 5
        _, task_progress_hot_func = create_lable_func(0,
                                                      115,
                                                      n_bins=task_progess_bins,
                                                      clip=True)
        # Conditioning layout must match the width declared in __init__.
        assert sum(n_classes) * 2 + task_progess_bins == self.cam_view_z

        def changing_factor(start, end, steps):
            # Linearly anneal a scalar from `start` to `end` over `steps`.
            for i in range(steps):
                yield i / (steps / (end - start)) + start

        cycle_factor_gen = changing_factor(0.5, 1., self.total_step)
        triplet_factor_gen = changing_factor(0.1, 1., self.total_step)

        for step in range(start, self.total_step):

            # ================== Train D ================== #
            self.D.train()
            self.G.train()

            if isinstance(self.data_loader.dataset, DoubleViewPairDataset):
                data = next(data_iter)
                # Randomly swap which view plays view-0 vs view-1 this step.
                key_views, lable_keys_cam_view_info = shuffle(
                    key_views, lable_keys_cam_view_info)
                # real_images = torch.cat([data[key_views[0]], data[key_views[1]]])
                # for now only view 0
                real_images = data[key_views[0]]
                real_images_view1 = data[key_views[1]]
                label_c = OrderedDict()
                label_c_hot_in = OrderedDict()
                for key_l, lable_func in mapping_cam_info_lable.items():
                    # continuous cam values to labels
                    label_c[key_l] = torch.tensor(lable_func(
                        data[key_l])).cuda()
                    label_c_hot_in[key_l] = torch.tensor(
                        mapping_cam_info_one_hot[key_l](data[key_l]),
                        dtype=torch.float32).cuda()
                d_one_hot_view0 = [
                    label_c_hot_in[l] for l in lable_keys_cam_view_info[0]
                ]
                d_one_hot_view1 = [
                    label_c_hot_in[l] for l in lable_keys_cam_view_info[1]
                ]
                d_task_progress = torch.tensor(task_progress_hot_func(
                    data['frame index']),
                                               dtype=torch.float32).cuda()
            else:
                real_images, _ = next(data_iter)

            # Compute loss with real images
            # dr1, dr2, df1, df2, gf1, gf2 are attention scores
            real_images = tensor2var(real_images)
            real_images_view1 = tensor2var(real_images_view1)
            d_out_real, dr1, dr2 = self.D(real_images)
            if self.adv_loss == 'wgan-gp':
                d_loss_real = -torch.mean(d_out_real)
            elif self.adv_loss == 'hinge':
                d_loss_real = torch.nn.ReLU()(1.0 - d_out_real).mean()

            # apply Gumbel Softmax
            encoded = self.G.encoder(real_images)
            sampled = self.G.encoder.sampler(encoded)
            # NOTE(review): this random z is immediately overwritten by the
            # conditioned concatenation on the next line — dead assignment.
            z = torch.randn(real_images.size(0), self.z_dim).cuda()
            z = torch.cat(
                [*d_one_hot_view0, *d_one_hot_view1, d_task_progress, sampled],
                dim=1)  # add view info from to
            if fixed_z is None:
                fixed_z = tensor2var(
                    torch.cat([
                        *d_one_hot_view0, *d_one_hot_view1, d_task_progress,
                        sampled
                    ],
                              dim=1))  # add view info
            z = tensor2var(z)
            fake_images, gf1, gf2 = self.G(z)
            d_out_fake, df1, df2 = self.D(fake_images)

            if self.adv_loss == 'wgan-gp':
                d_loss_fake = d_out_fake.mean()
            elif self.adv_loss == 'hinge':
                d_loss_fake = torch.nn.ReLU()(1.0 + d_out_fake).mean()

            # Backward + Optimize
            d_loss = d_loss_real + d_loss_fake
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()

            if self.adv_loss == 'wgan-gp':
                # Compute gradient penalty on random interpolates.
                alpha = torch.rand(real_images.size(0), 1, 1,
                                   1).cuda().expand_as(real_images)
                interpolated = Variable(alpha * real_images.data +
                                        (1 - alpha) * fake_images.data,
                                        requires_grad=True)
                out, _, _ = self.D(interpolated)

                grad = torch.autograd.grad(outputs=out,
                                           inputs=interpolated,
                                           grad_outputs=torch.ones(
                                               out.size()).cuda(),
                                           retain_graph=True,
                                           create_graph=True,
                                           only_inputs=True)[0]

                grad = grad.view(grad.size(0), -1)
                grad_l2norm = torch.sqrt(torch.sum(grad**2, dim=1))
                d_loss_gp = torch.mean((grad_l2norm - 1)**2)

                # Backward + Optimize (separate step for the penalty term).
                d_loss = self.lambda_gp * d_loss_gp
                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()

            # ================== Train VAE ================== #
            # KL divergence of the encoder posterior for the real view-0 batch.
            encoded = self.G.encoder(real_images)
            mu_0 = encoded[0]
            logvar = encoded[1]
            KLD_element = mu_0.pow(2).add_(
                logvar.exp()).mul_(-1).add_(1).add_(logvar)
            KLD = torch.sum(KLD_element).mul_(-0.5)
            # save_image(denorm(real_images[::2]), os.path.join(self.sample_path, "ancor.png"))
            # save_image(denorm(real_images[1::2]), os.path.join(self.sample_path, "neg.png"))
            # save_image(denorm(real_images_view1[::2]), os.path.join(self.sample_path, "pos.png"))
            # Decode view-0 encoding conditioned as view0->view1 and ask the
            # reconstruction to match the OTHER view (cross-view VAE).
            sampled = self.G.encoder.sampler(encoded)
            z = torch.cat(
                [*d_one_hot_view0, *d_one_hot_view1, d_task_progress, sampled],
                dim=1)  # add view info 0
            z = tensor2var(z)
            fake_images_0, _, _ = self.G(z)
            MSEerr = self.MSECriterion(fake_images_0, real_images_view1)
            rec = fake_images_0
            VAEerr = MSEerr + KLD * 0.1
            # encode the fake view and recon loss to view1
            encoded = self.G.encoder(fake_images_0)
            mu_1 = encoded[0]
            logvar = encoded[1]
            KLD_element = mu_1.pow(2).add_(
                logvar.exp()).mul_(-1).add_(1).add_(logvar)
            KLD = torch.sum(KLD_element).mul_(-0.5)
            sampled = self.G.encoder.sampler(encoded)
            z = torch.cat(
                [*d_one_hot_view1, *d_one_hot_view0, d_task_progress, sampled],
                dim=1)  # add view info 1
            z = tensor2var(z)
            fake_images_view1, _, _ = self.G(z)
            rec_fake = fake_images_view1
            MSEerr = self.MSECriterion(fake_images_view1, real_images)
            # Second reconstruction term is annealed in over training.
            VAEerr += (KLD * 0.1 + MSEerr) * next(
                cycle_factor_gen)  # (KLD + MSEerr)  # *0.5
            # Triplet over encoder means: even-index anchors, odd-index
            # positives (real), even-index fakes as negatives.
            triplet_loss = self.triplet_loss(anchor=mu_0[::2],
                                             positive=mu_0[1::2],
                                             negative=mu_1[::2])

            # ================== Train G and gumbel ================== #
            # Create random noise
            # z = tensor2var(torch.randn(real_images.size(0), self.z_dim))
            # fake_images, _, _ = self.G(z)

            # Compute loss with fake images
            # fake_images = torch.cat([fake_images_0, fake_images_view1])  # rm triplets
            fake_images = torch.cat(
                [fake_images_0[::2], fake_images_view1[::2]])  # rm triplets
            g_out_fake, _, _ = self.D(fake_images)  # batch x n
            if self.adv_loss == 'wgan-gp':
                g_loss_fake = -g_out_fake.mean()
            elif self.adv_loss == 'hinge':
                g_loss_fake = -g_out_fake.mean()

            self.reset_grad()
            # Scale VAE/triplet terms up to the pixel count so they are not
            # dominated by the adversarial term.
            VAEerr *= self.num_pixels
            triplet_loss *= self.num_pixels
            loss = g_loss_fake * 4. + VAEerr + triplet_loss * next(
                triplet_factor_gen)
            loss.backward()
            self.g_optimizer.step()

            # Print out log info
            if (step + 1) % self.log_step == 0:
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))
                print(
                    "Elapsed [{}], G_step [{}/{}], D_step[{}/{}], d_out_real: {:.4f}, "
                    " ave_gamma_l3: {:.4f}, ave_gamma_l4: {:.4f},vae {:.4f}".
                    format(elapsed, step + 1, self.total_step, (step + 1),
                           self.total_step, d_loss_real,
                           self.G.attn1.gamma.mean(),
                           self.G.attn2.gamma.mean(), VAEerr))
                if vis is not None:
                    # First call creates each visdom window; later calls
                    # append to it.
                    kw_update_vis = None
                    if self.d_plot is not None:
                        kw_update_vis = 'append'
                        # kw_update_vis["update"] = 'append'
                    self.d_plot = vis.line(np.array(
                        [d_loss_real.detach().cpu().numpy()]),
                                           X=np.array([step]),
                                           win=self.d_plot,
                                           update=kw_update_vis,
                                           opts=dict(title="d_loss_real",
                                                     xlabel='Timestep',
                                                     ylabel='loss'))
                    self.d_plot_fake = vis.line(np.array(
                        [d_loss_fake.detach().cpu().numpy()]),
                                                X=np.array([step]),
                                                win=self.d_plot_fake,
                                                update=kw_update_vis,
                                                opts=dict(title="d_loss_fake",
                                                          xlabel='Timestep',
                                                          ylabel='loss'))
                    self.d_plot_vae = vis.line(np.array(
                        [VAEerr.detach().cpu().numpy()]),
                                               X=np.array([step]),
                                               win=self.d_plot_vae,
                                               update=kw_update_vis,
                                               opts=dict(title="VAEerr",
                                                         xlabel='Timestep',
                                                         ylabel='loss'))
                    self.d_plot_triplet_loss = vis.line(
                        np.array([triplet_loss.detach().cpu().numpy()]),
                        X=np.array([step]),
                        win=self.d_plot_triplet_loss,
                        update=kw_update_vis,
                        opts=dict(title="triplet_loss",
                                  xlabel='Timestep',
                                  ylabel='loss'))

            # Sample images
            if (step + 1) % self.sample_step == 0:
                fake_images, _, _ = self.G(fixed_z)
                fake_images = denorm(fake_images)
                save_image(
                    fake_images.data,
                    os.path.join(self.sample_path,
                                 '{}_fake.png'.format(step + 1)))
                n = 8
                imgs = denorm(torch.cat([real_images.data[:n], rec.data[:n]]))
                imgs_rec_fake = denorm(
                    torch.cat([real_images_view1.data[:n], rec_fake.data[:n]]))
                title = '{}_var_rec_real'.format(step + 1)
                title_rec_fake = '{}_var_rec_fake'.format(step + 1)
                title_fixed = '{}_fixed'.format(step + 1)
                save_image(imgs,
                           os.path.join(self.vae_rec_path, title + ".png"),
                           nrow=n)
                distance_pos, product_pos, distance_neg, product_neg = self._get_view_pair_distances(
                )
                print("distance_pos {:.3}, neg {:.3},dot pos {:.3} neg {:.3}".
                      format(distance_pos, distance_neg, product_pos,
                             product_neg))
                if vis is not None:
                    self.rec_win = vis.images(
                        imgs,
                        win=self.rec_win,
                        opts=dict(title=title, width=64 * n, height=64 * 2),
                    )
                    self.rec_fake_win = vis.images(
                        imgs_rec_fake,
                        win=self.rec_fake_win,
                        opts=dict(title=title_rec_fake,
                                  width=64 * n,
                                  height=64 * 2),
                    )
                    self.fixed_win = vis.images(
                        fake_images,
                        win=self.fixed_win,
                        opts=dict(title=title_fixed,
                                  width=64 * n,
                                  height=64 * 4),
                    )
                    kw_update_vis = None
                    if self.d_plot_distance_pos is not None:
                        kw_update_vis = 'append'
                    self.d_plot_distance_pos = vis.line(
                        np.array([distance_pos]),
                        X=np.array([step]),
                        win=self.d_plot_distance_pos,
                        update=kw_update_vis,
                        opts=dict(title="distance pos",
                                  xlabel='Timestep',
                                  ylabel='dist'))
                    self.d_plot_distance_neg = vis.line(
                        np.array([distance_neg]),
                        X=np.array([step]),
                        win=self.d_plot_distance_neg,
                        update=kw_update_vis,
                        opts=dict(title="distance neg",
                                  xlabel='Timestep',
                                  ylabel='dist'))

            if (step + 1) % model_save_step == 0:
                torch.save(
                    self.G.state_dict(),
                    os.path.join(self.model_save_path,
                                 '{}_G.pth'.format(step + 1)))
                torch.save(
                    self.D.state_dict(),
                    os.path.join(self.model_save_path,
                                 '{}_D.pth'.format(step + 1)))

    def _get_view_pair_distances(self):
        """Compute embedding distances/dot-products between the two views of
        a validation batch (positive pairs)."""

        def encode(x):
            # Embed with the encoder mean only.
            encoded = self.G.encoder(x)
            mu = encoded[0]
            return mu

        # dot product are mean free
        key_views = ["frames views {}".format(i) for i in range(2)]
        sample_batched = next(self.loader_val_iter)
        anchor_emb = encode(sample_batched[key_views[0]].cuda())
        positive_emb = encode(sample_batched[key_views[1]].cuda())
        distance_pos = np.linalg.norm(anchor_emb.data.cpu().numpy() -
                                      positive_emb.data.cpu().numpy(),
                                      axis=1).mean()
        dots = []
        for e1, e2 in zip(anchor_emb.data.cpu().numpy(),
                          positive_emb.data.cpu().numpy()):
            dots.append(np.dot(e1 - e1.mean(), e2 - e2.mean()))
        product_pos = 
np.mean(dots) n = len(anchor_emb) emb_pos = anchor_emb.data.cpu().numpy() emb_neg = positive_emb.data.cpu().numpy() cnt, distance_neg, product_neg = 0., 0., 0. for i in range(n): for j in range(n): if j != i: d_negative = np.linalg.norm(emb_pos[i] - emb_neg[j]) distance_neg += d_negative product_neg += np.dot(emb_pos[i] - emb_pos[i].mean(), emb_neg[j] - emb_neg[j].mean()) cnt += 1 distance_neg /= cnt product_neg /= cnt # distance_pos = np.asscalar(distance_pos) # product_pos = np.asscalar(product_pos) # distance_neg = np.asscalar(distance_neg) # product_neg = np.asscalar(product_neg) return distance_pos, product_pos, distance_neg, product_neg def build_model(self): self.rec_win = None self.rec_fake_win = None self.fixed_win = None self.d_plot = None self.d_plot_fake = None self.d_plot_vae = None self.d_plot_triplet_loss = None self.d_plot_distance_neg = None self.d_plot_distance_pos = None self.G = Generator(self.batch_size, self.imsize, self.z_dim, self.cam_view_z, self.g_conv_dim).cuda() self.D = Discriminator(self.batch_size, self.imsize, self.d_conv_dim).cuda() if self.parallel: self.G = nn.DataParallel(self.G) self.D = nn.DataParallel(self.D) self.MSECriterion = nn.MSELoss() self.triplet_loss = nn.TripletMarginLoss() # Loss and optimizer # self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, [self.beta1, self.beta2]) self.g_optimizer = torch.optim.Adam( filter(lambda p: p.requires_grad, self.G.parameters()), self.g_lr, [self.beta1, self.beta2]) self.d_optimizer = torch.optim.Adam( filter(lambda p: p.requires_grad, self.D.parameters()), self.d_lr, [self.beta1, self.beta2]) self.c_loss = torch.nn.CrossEntropyLoss() # print networks print(self.G) print(self.D) def build_tensorboard(self): from logger import Logger self.logger = Logger(self.log_path) def load_pretrained_model(self): self.G.load_state_dict( torch.load( os.path.join(self.model_save_path, '{}_G.pth'.format(self.pretrained_model)))) self.D.load_state_dict( torch.load( 
os.path.join(self.model_save_path, '{}_D.pth'.format(self.pretrained_model)))) print('loaded trained models (step: {})..!'.format( self.pretrained_model)) def reset_grad(self): self.d_optimizer.zero_grad() self.g_optimizer.zero_grad() def save_sample(self, data_iter): real_images, _ = next(data_iter) save_image(denorm(real_images), os.path.join(self.sample_path, 'real.png')) def save_sample(self, data_iter): real_images, _ = next(data_iter) save_image(denorm(real_images), os.path.join(self.sample_path, 'real.png'))
class Trainer(object):
    """SAGAN-style trainer: WGAN-GP / hinge adversarial training with
    TensorBoard-style logging and self-attention activation-map dumps.
    """

    def __init__(self, data_loader, config):
        # Data loader
        self.data_loader = data_loader

        # exact model and loss
        self.model = config.model
        self.adv_loss = config.adv_loss

        # Model hyper-parameters
        self.imsize = config.imsize
        self.g_num = config.g_num
        self.z_dim = config.z_dim
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.parallel = config.parallel
        self.lambda_gp = config.lambda_gp
        self.total_step = config.total_step
        self.d_iters = config.d_iters
        self.batch_size = config.batch_size
        self.num_workers = config.num_workers
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.lr_decay = config.lr_decay
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.pretrained_model = config.pretrained_model
        self.dataset = config.dataset
        self.use_tensorboard = config.use_tensorboard
        self.image_path = config.image_path
        self.log_path = config.log_path
        self.model_save_path = config.model_save_path
        self.sample_path = config.sample_path
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step
        self.version = config.version

        # Path — these deliberately overwrite the raw config paths assigned
        # above with version-suffixed directories.
        self.log_path = os.path.join(config.log_path, self.version)
        self.sample_path = os.path.join(config.sample_path, self.version)
        self.model_save_path = os.path.join(config.model_save_path,
                                            self.version)

        self.build_model()

        if self.use_tensorboard:
            self.build_tensorboard()

        # Start with trained model
        if self.pretrained_model:
            self.load_pretrained_model()

    def train(self):
        """Run the adversarial training loop for self.total_step iterations."""

        # Data iterator
        data_iter = iter(self.data_loader)
        step_per_epoch = len(self.data_loader)
        model_save_step = int(self.model_save_step * step_per_epoch)

        # Fixed input for debugging (std-3 normal noise, kept constant so
        # sampled images are comparable across steps)
        fixed_z = tensor2var(
            torch.normal(0, torch.ones([self.batch_size, self.z_dim]) * 3))

        # Start with trained model
        if self.pretrained_model:
            start = self.pretrained_model + 1
        else:
            start = 0

        # Start time
        start_time = time.time()
        i = 0
        for step in range(start, self.total_step):

            # ================== Train D ================== #
            self.D.train()
            self.G.train()

            try:
                (real_images, _) = next(data_iter)
            except:
                # NOTE(review): bare except — intended to restart the iterator
                # on StopIteration, but it would also swallow real errors.
                data_iter = iter(self.data_loader)
                (real_images, _) = next(data_iter)

            # Compute loss with real images
            # dr1, dr2, df1, df2, gf1, gf2 are attention scores
            real_images = tensor2var(real_images)
            d_out_real = self.D(real_images)
            if self.adv_loss == 'wgan-gp':
                d_loss_real = -torch.mean(d_out_real)
            elif self.adv_loss == 'hinge':
                d_loss_real = torch.nn.ReLU()(1.0 - d_out_real).mean()

            # apply Gumbel Softmax
            z = tensor2var(
                torch.normal(0,
                             torch.ones([real_images.size(0), self.z_dim]) *
                             3))
            # (fake_images, gf1, gf2) = self.G(z)
            (fake_images, gf2) = self.G(z)
            if i < 1:
                # One-time shape printout on the very first iteration.
                # NOTE(review): `i` is re-bound by the attention-map loop in
                # the sampling branch below, which also keeps this guard off
                # afterwards (a dedicated flag would be clearer).
                print('***** Result Image size now *****')
                print(fake_images.size())
                # print(gf1.size())
                print(gf2.size())
                i = i + 1
            d_out_fake = self.D(fake_images)

            if self.adv_loss == 'wgan-gp':
                d_loss_fake = d_out_fake.mean()
            elif self.adv_loss == 'hinge':
                d_loss_fake = torch.nn.ReLU()(1.0 + d_out_fake).mean()

            # Backward + Optimize
            d_loss = d_loss_real + d_loss_fake
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()

            if self.adv_loss == 'wgan-gp':
                # Compute gradient penalty
                alpha = torch.rand(real_images.size(0), 1, 1,
                                   1).cuda().expand_as(real_images)
                interpolated = Variable(alpha * real_images.data +
                                        (1 - alpha) * fake_images.data,
                                        requires_grad=True)
                out = self.D(interpolated)

                grad = torch.autograd.grad(
                    outputs=out,
                    inputs=interpolated,
                    grad_outputs=torch.ones(out.size()).cuda(),
                    retain_graph=True,
                    create_graph=True,
                    only_inputs=True,
                )[0]

                grad = grad.view(grad.size(0), -1)
                grad_l2norm = torch.sqrt(torch.sum(grad**2, dim=1))
                d_loss_gp = torch.mean((grad_l2norm - 1)**2)

                # Backward + Optimize (penalty-only D step)
                d_loss = self.lambda_gp * d_loss_gp
                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()

            # ================== Train G and gumbel ================== #
            # Create random noise
            z = tensor2var(
                torch.normal(0,
                             torch.ones([real_images.size(0), self.z_dim]) *
                             3))
            # (fake_images, _, _) = self.G(z)
            (fake_images, _) = self.G(z)

            # Compute loss with fake images
            g_out_fake = self.D(fake_images)  # batch x n
            if self.adv_loss == 'wgan-gp':
                g_loss_fake = -g_out_fake.mean()
            elif self.adv_loss == 'hinge':
                g_loss_fake = -g_out_fake.mean()

            self.reset_grad()
            g_loss_fake.backward()
            self.g_optimizer.step()

            # Print out log info
            if (step + 1) % self.log_step == 0:
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))
                # NOTE(review): `.data[0]` is the pre-0.4 PyTorch scalar
                # idiom; on modern PyTorch this should be `.item()`. Also
                # `self.G.module` only exists when self.parallel wraps G in
                # DataParallel — confirm.
                print(
                    'Elapsed [{}], G_step [{}/{}], D_step[{}/{}], d_out_real: {:.4f}, ave_gamma_l4: {:.4f}'
                    .format(
                        elapsed,
                        step + 1,
                        self.total_step,
                        step + 1,
                        self.total_step,
                        d_loss_real.data[0],
                        self.G.module.attn2.gamma.mean().data[0],
                    ))

                # (1) Log values of the losses (scalars)
                info = {
                    'd_loss_real': d_loss_real.data[0],
                    'd_loss_fake': d_loss_fake.data[0],
                    'd_loss': d_loss.data[0],
                    'g_loss_fake': g_loss_fake.data[0],
                    # 'ave_gamma_l3': self.G.module.attn1.gamma.mean().data[0],
                    'ave_gamma_l4':
                    self.G.module.attn2.gamma.mean().data[0],
                }

                for (tag, value) in info.items():
                    self.logger.scalar_summary(tag, value, step + 1)

            # Sample images / Save and log
            if (step + 1) % self.sample_step == 0:

                # (2) Log values and gradients of the parameters (histogram)
                for (net, name) in zip([self.G, self.D], ['G_', 'D_']):
                    for (tag, value) in net.named_parameters():
                        tag = name + tag.replace('.', '/')
                        self.logger.histo_summary(tag,
                                                  value.data.cpu().numpy(),
                                                  step + 1)

                # (3) Log the tensorboard (first 16 images of each batch)
                info = {
                    'fake_images': (fake_images.view(
                        fake_images.size())[:16, :, :, :]).data.cpu().numpy(),
                    'real_images': (real_images.view(
                        real_images.size())[:16, :, :, :]).data.cpu().numpy()
                }

                # (fake_images, at1, at2) = self.G(fixed_z)
                (fake_images, at2) = self.G(fixed_z)
                if (step + 1) % (self.sample_step * 10) == 0:
                    save_image(
                        denorm(fake_images.data),
                        os.path.join(self.sample_path,
                                     '{}_fake.png'.format(step + 1)))

                # print('***** Fake Image size now *****')
                # print('fake_images ', fake_images.size())
                # print('at2 ', at2.size())  # B * N * N
                # Fold the flat attention map back into a square spatial map,
                # then average over the query dimension.
                at2_4d = at2.view(*(at2.size()[0], at2.size()[1],
                                    int(math.sqrt(at2.size()[2])),
                                    int(math.sqrt(at2.size()[2]))))
                # W * N * W * H
                # print('at2_4d ', at2_4d.size())
                at2_mean = at2_4d.mean(dim=1, keepdim=False)  # B * W * H
                # print('at2_mean ', at2_mean.size())
                print('***** start create activation map *****')
                attn_list = []
                for i in range(at2.size()[0]):
                    # print('fake_images size: ',fake_images[i].size())
                    # print('at2 mean size', at2_mean[i].size())
                    # NOTE(review): `f` is never used.
                    f = BytesIO()
                    img = np.uint8(
                        fake_images[i, :, :, :].mul(255).data.cpu().numpy())
                    a = np.uint8(at2_mean[i, :, :].mul(255).data.cpu().numpy())
                    # print('image: ', img.shape)
                    # print('a shape: ',a.shape)
                    # NOTE(review): reshape does NOT transpose CHW->HWC — it
                    # reinterprets memory and scrambles pixel/channel order;
                    # np.transpose(img, (1, 2, 0)) is presumably what was
                    # intended (same for the reverse reshape below) — confirm.
                    im_image = img.reshape(img.shape[1], img.shape[2],
                                           img.shape[0])
                    im_attn = cv2.applyColorMap(a, cv2.COLORMAP_JET)
                    img_with_heatmap = np.float32(im_attn) + np.float32(
                        im_image)
                    img_with_heatmap = img_with_heatmap / np.max(
                        img_with_heatmap)
                    attn_np = np.uint8((255 * img_with_heatmap).reshape(
                        img_with_heatmap.shape[2], img_with_heatmap.shape[0],
                        img_with_heatmap.shape[1]))
                    attn_torch = torch.from_numpy(attn_np)
                    # print('final attn image size: ', attn_torch.size())
                    attn_list.append(attn_torch.unsqueeze(0))
                attn_images = torch.cat(attn_list)
                print('attn images list: ', attn_images.size())
                info['attn_images'] = (attn_images.view(
                    attn_images.size())[:16, :, :, :]).numpy()

                for (tag, image) in info.items():
                    self.logger.image_summary(tag, image, step + 1)

            if (step + 1) % model_save_step == 0:
                # Periodic checkpoint of both networks.
                torch.save(
                    self.G.state_dict(),
                    os.path.join(self.model_save_path,
                                 '{}_G.pth'.format(step + 1)))
                torch.save(
                    self.D.state_dict(),
                    os.path.join(self.model_save_path,
                                 '{}_D.pth'.format(step + 1)))

    def build_model(self):
        """Instantiate G and D (optionally DataParallel) and their optimizers."""
        self.G = Generator(self.batch_size, self.imsize, self.z_dim,
                           self.g_conv_dim).cuda()
        self.D = Discriminator(self.batch_size, self.imsize,
                               self.d_conv_dim).cuda()
        if self.parallel:
            self.G = nn.DataParallel(self.G)
            self.D = nn.DataParallel(self.D)

        # Loss and optimizer — only parameters with requires_grad are optimized.
        # self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])
        self.g_optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.G.parameters()), self.g_lr,
            [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.D.parameters()), self.d_lr,
            [self.beta1, self.beta2])

        self.c_loss = torch.nn.CrossEntropyLoss()
        # print networks
        print(self.G)
        print(self.D)

    def build_tensorboard(self):
        """Attach the project Logger writing to self.log_path."""
        from logger import Logger

        # if os.path.exists(self.log_path):
        #     shutil.rmtree(self.log_path)
        # os.makedirs(self.log_path)
        self.logger = Logger(self.log_path)

    def load_pretrained_model(self):
        """Load G/D weights from the checkpoint step `self.pretrained_model`."""
        self.G.load_state_dict(
            torch.load(
                os.path.join(self.model_save_path,
                             '{}_G.pth'.format(self.pretrained_model))))
        self.D.load_state_dict(
            torch.load(
                os.path.join(self.model_save_path,
                             '{}_D.pth'.format(self.pretrained_model))))
        print('loaded trained models (step: {})..!'.format(
            self.pretrained_model))

    def reset_grad(self):
        """Zero the gradients held by both optimizers."""
        self.d_optimizer.zero_grad()
        self.g_optimizer.zero_grad()

    def save_sample(self, data_iter):
        """Save one denormalized batch of real images for visual reference."""
        (real_images, _) = next(data_iter)
        save_image(denorm(real_images),
                   os.path.join(self.sample_path, 'real.png'))

    def save_gradient_images(self, gradient, file_name):
        """
        Exports the original gradient image

        Args:
            gradient (np arr): Numpy array of the gradient with shape (3, 224, 224)
            file_name (str): File name to be exported
        """
        if not os.path.exists('attn2/results'):
            os.makedirs('attn2/results')
        # Normalize to [0, 1]
        gradient = gradient - gradient.min()
        gradient /= gradient.max()
        # Save image
        path = os.path.join('attn2/results', file_name + '.jpg')
        im = gradient
        if isinstance(im, np.ndarray):
            if len(im.shape) == 2:
                im = np.expand_dims(im, axis=0)
            if im.shape[0] == 1:
                # Converting an image with depth = 1 to depth = 3, repeating the same values
                # For some reason PIL complains when I want to save channel image as jpg without
                # additional format in the .save()
                im = np.repeat(im, 3, axis=0)
            # Convert to values to range 1-255 and W,H, D
            if im.shape[0] == 3:
                im = im.transpose(1, 2, 0) * 255
            im = Image.fromarray(im.astype(np.uint8))
        im.save(path)
class Tester(object):
    """Evaluation harness for a BiGAN-style anomaly model (E, G, D).

    Loads pretrained weights and, for each validation batch, compares real
    images against their reconstructions ``G(E(x))`` using per-pixel residual
    losses and discriminator feature-matching losses.

    Fix in this revision: removed a leftover ``import ipdb; ipdb.set_trace()``
    debugger breakpoint inside the evaluation loop, which halted every
    iteration in non-interactive runs.
    """

    def __init__(self, data_loader, config):
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.data_loader = data_loader

        # exact model and loss
        self.model = config.model
        self.adv_loss = config.adv_loss

        # Model hyper-parameters
        self.imsize = config.imsize
        self.g_num = config.g_num
        self.z_dim = config.z_dim
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.parallel = config.parallel
        self.lambda_gp = config.lambda_gp
        self.total_step = config.total_step
        self.d_iters = config.d_iters
        self.batch_size = config.batch_size
        self.num_workers = config.num_workers
        self.ge_lr = config.ge_lr
        self.d_lr = config.d_lr
        self.lr_decay = config.lr_decay
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.pretrained_model = config.pretrained_model
        self.dataset = config.dataset
        self.mura_class = config.mura_class
        self.mura_type = config.mura_type
        self.use_tensorboard = config.use_tensorboard
        self.image_path = config.image_path
        self.log_path = config.log_path
        self.model_save_path = config.model_save_path
        self.sample_path = config.sample_path
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step
        self.version = config.version

        # Path — version-suffixed output directories (overwrite the raw
        # config paths assigned above).
        self.log_path = os.path.join(config.log_path, self.version)
        self.sample_path = os.path.join(config.sample_path, self.version)
        self.model_save_path = os.path.join(config.model_save_path,
                                            self.version)

        # Build tensorboard for debugging
        self.build_tensorboard()

        # Build model
        self.build_model()

        # Load models
        self.load_pretrained_model()

    def test(self):
        """Score validation images by reconstruction / feature-matching error."""
        data_iter = iter(self.data_loader)
        self.D.eval()
        self.E.eval()
        self.G.eval()

        with torch.no_grad():
            for i, data in enumerate(data_iter):
                val_images, val_labels = data
                val_images = tensor2var(val_images)

                # Run val images through models X -> E(X) -> G(E(X))
                z, ef1, ef2 = self.E(val_images)
                re_images, gf1, gf2 = self.G(z)
                dv, dv5, dv4, dv3, dvz, dva2, dva1 = self.D(val_images, z)
                dr, dr5, dr4, dr3, drz, dra2, dra1 = self.D(re_images, z)

                # Compute residual loss
                l1 = (re_images - val_images).abs()
                # NOTE(review): pow(2).sqrt() is element-wise |x|, so l2 == l1
                # and is unused below — likely meant to be a (squared) L2 norm.
                l2 = (re_images - val_images).pow(2).sqrt()

                # Compute feature matching loss: per-sample mean absolute
                # difference of D features between real and reconstructed.
                ld = (dv - dr).abs().view((self.batch_size, -1)).mean(dim=1)
                ld5 = (dv5 - dr5).abs().view((self.batch_size, -1)).mean(dim=1)
                ld4 = (dv4 - dr4).abs().view((self.batch_size, -1)).mean(dim=1)
                ld3 = (dv3 - dr3).abs().view((self.batch_size, -1)).mean(dim=1)

                # NOTE(review): l1 is a 4-D tensor here, but plt.scatter
                # expects one scalar score per sample (e.g.
                # l1.view(self.batch_size, -1).mean(dim=1)) — confirm the
                # intended reduction before relying on this plot.
                plt.scatter(range(1, self.batch_size + 1), l1, c=val_labels)

    def build_tensorboard(self):
        '''Initialize tensorboard writer'''
        self.writer = SummaryWriter(self.log_path)

    def build_model(self):
        """Instantiate G, E and D (optionally DataParallel) and optimizers."""
        self.G = Generator(self.batch_size, self.imsize, self.z_dim,
                           self.g_conv_dim).to(self.device)
        self.E = Encoder(self.batch_size, self.imsize, self.z_dim,
                         self.d_conv_dim).to(self.device)
        self.D = Discriminator(self.batch_size, self.imsize, self.z_dim,
                               self.d_conv_dim).to(self.device)
        if self.parallel:
            self.G = nn.DataParallel(self.G)
            self.E = nn.DataParallel(self.E)
            self.D = nn.DataParallel(self.D)

        # Loss and optimizer — G and E share one optimizer.
        self.ge_optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad,
                   itertools.chain(self.G.parameters(), self.E.parameters())),
            self.ge_lr, [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.D.parameters()), self.d_lr,
            [self.beta1, self.beta2])
        self.c_loss = torch.nn.CrossEntropyLoss()

        # print networks
        print(self.G)
        print(self.E)
        print(self.D)

    def load_pretrained_model(self):
        """Load G/E/D weights from the checkpoint step `self.pretrained_model`."""
        self.G.load_state_dict(
            torch.load(
                os.path.join(self.model_save_path,
                             '{}_G.pth'.format(self.pretrained_model))))
        self.E.load_state_dict(
            torch.load(
                os.path.join(self.model_save_path,
                             '{}_E.pth'.format(self.pretrained_model))))
        self.D.load_state_dict(
            torch.load(
                os.path.join(self.model_save_path,
                             '{}_D.pth'.format(self.pretrained_model))))
        print('loaded trained models (step: {})..!'.format(
            self.pretrained_model))

    def reset_grad(self):
        """Zero the gradients held by both optimizers."""
        self.d_optimizer.zero_grad()
        self.ge_optimizer.zero_grad()

    def save_sample(self, data_iter):
        """Save one denormalized batch of real images for visual reference."""
        real_images, _ = next(data_iter)
        save_image(denorm(real_images),
                   os.path.join(self.sample_path, 'real.png'))