class Solver(object):
    """Trainer/tester for the attribute-guided face image synthesis GAN.

    Wraps an Encoder/Decoder generator pair and a Discriminator, their Adam
    optimizers, and all bookkeeping (paths, step sizes, tensorboard logger)
    read from ``config``.  Written for the pre-0.4 PyTorch API
    (``Variable``, ``volatile``, ``.data[0]``).
    """

    def __init__(self, face_data_loader, config):
        # Data loader yielding ((image, label), (landmark, ...)) pairs
        # (exact tuple layout inferred from train() — TODO confirm with the dataset class).
        self.face_data_loader = face_data_loader

        # Model parameters
        self.y_dim = config.y_dim                    # number of attribute/expression classes
        self.num_layers = config.num_layers
        self.im_size = config.im_size
        self.g_first_dim = config.g_first_dim        # generator base channel width
        self.d_first_dim = config.d_first_dim        # discriminator base channel width
        self.enc_repeat_num = config.enc_repeat_num
        self.d_repeat_num = config.d_repeat_num
        self.d_train_repeat = config.d_train_repeat  # D iterations per generator update

        # Hyper-parameters (loss weights and Adam settings)
        self.lambda_cls = config.lambda_cls          # weight of the classification loss
        self.lambda_id = config.lambda_id            # weight of the identity (L1) losses
        self.lambda_bi = config.lambda_bi            # weight of the bidirectional/reconstruction losses
        self.lambda_gp = config.lambda_gp            # weight of the WGAN gradient penalty
        self.enc_lr = config.enc_lr
        self.dec_lr = config.dec_lr
        self.d_lr = config.d_lr
        self.beta1 = config.beta1
        self.beta2 = config.beta2

        # Training settings
        self.num_epochs = config.num_epochs
        self.num_epochs_decay = config.num_epochs_decay
        self.num_iters = config.num_iters
        self.num_iters_decay = config.num_iters_decay
        self.batch_size = config.batch_size
        # Checkpoint tag of the form '<epoch>_<iter>'; falsy means train from scratch.
        self.trained_model = config.trained_model

        # Test settings
        self.test_model = config.test_model

        # Output paths
        self.log_path = config.log_path
        self.sample_path = config.sample_path
        self.model_path = config.model_path
        self.test_path = config.test_path

        # Step sizes (in iterations) for logging / sampling / checkpointing
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step

        # Build networks + optimizers, then attach the tensorboard logger
        self.build_model()
        self.use_tensorboard()

        # Resume from a saved checkpoint when one was supplied
        if self.trained_model:
            self.load_trained_model()

    def build_model(self):
        """Instantiate the networks and their optimizers, moving them to GPU if available."""
        # Define encoder-decoder (generator) and a discriminator
        self.Enc = Encoder(self.g_first_dim, self.enc_repeat_num)
        self.Dec = Decoder(self.g_first_dim)
        self.D = Discriminator(self.im_size, self.d_first_dim, self.d_repeat_num)
        # One Adam optimizer per network (positional args: lr, betas)
        self.enc_optimizer = torch.optim.Adam(self.Enc.parameters(), self.enc_lr,
                                              [self.beta1, self.beta2])
        self.dec_optimizer = torch.optim.Adam(self.Dec.parameters(), self.dec_lr,
                                              [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr,
                                            [self.beta1, self.beta2])
        if torch.cuda.is_available():
            self.Enc.cuda()
            self.Dec.cuda()
            self.D.cuda()

    def load_trained_model(self):
        """Load Enc/Dec/D weights for checkpoint tag ``self.trained_model``."""
        self.Enc.load_state_dict(
            torch.load(
                os.path.join(self.model_path,
                             '{}_Enc.pth'.format(self.trained_model))))
        self.Dec.load_state_dict(
            torch.load(
                os.path.join(self.model_path,
                             '{}_Dec.pth'.format(self.trained_model))))
        self.D.load_state_dict(
            torch.load(
                os.path.join(self.model_path,
                             '{}_D.pth'.format(self.trained_model))))
        print('loaded models (step: {})..!'.format(self.trained_model))

    def use_tensorboard(self):
        """Create the tensorboard_logger writer pointed at ``self.log_path``."""
        from tensorboard_logger import Logger
        self.logger = Logger(self.log_path)

    def update_lr(self, enc_lr, dec_lr, d_lr):
        """Overwrite the learning rate of every param group of all three optimizers."""
        for param_group in self.enc_optimizer.param_groups:
            param_group['lr'] = enc_lr
        for param_group in self.dec_optimizer.param_groups:
            param_group['lr'] = dec_lr
        for param_group in self.d_optimizer.param_groups:
            param_group['lr'] = d_lr

    def reset(self):
        """Zero the gradients of all three optimizers."""
        self.enc_optimizer.zero_grad()
        self.dec_optimizer.zero_grad()
        self.d_optimizer.zero_grad()

    def to_var(self, x, volatile=False):
        """Wrap a tensor in a Variable, moving it to GPU first when available."""
        if torch.cuda.is_available():
            x = x.cuda()
        return Variable(x, volatile=volatile)

    def calculate_accuracy(self, x, y):
        """Return classification accuracy (percent) of logits ``x`` against labels ``y``."""
        _, predicted = torch.max(x, dim=1)
        correct = (predicted == y).float()
        accuracy = torch.mean(correct) * 100.0
        return accuracy

    def denorm(self, x):
        """Map images from [-1, 1] back to [0, 1] (clamped, in place on the result)."""
        out = (x + 1) / 2
        return out.clamp_(0, 1)

    def one_hot(self, labels, dim):
        """Convert label indices to one-hot vectors of width ``dim``."""
        batch_size = labels.size(0)
        out = torch.zeros(batch_size, dim)
        out[np.arange(batch_size), labels.long()] = 1
        return out

    def train(self):
        """Train attribute-guided face image synthesis model."""
        self.data_loader = self.face_data_loader
        # The number of iterations for each epoch
        iters_per_epoch = len(self.data_loader)

        # Grab the first 3 batches as fixed samples for periodic visualization.
        sample_x = []
        sample_l = []
        real_y = []
        for i, (images, landmark) in enumerate(self.data_loader):
            labels = images[1]
            sample_x.append(images[0])
            sample_l.append(landmark[0])
            real_y.append(labels)
            if i == 2:
                break

        # Sample inputs and desired domain labels for testing
        sample_x = torch.cat(sample_x, dim=0)
        sample_x = self.to_var(sample_x, volatile=True)
        sample_l = torch.cat(sample_l, dim=0)
        sample_l = self.to_var(sample_l, volatile=True)
        # NOTE(review): real_y is rebuilt from scratch inside the training loop
        # below, so this concatenated value is never actually consumed.
        real_y = torch.cat(real_y, dim=0)

        # One fixed one-hot target per class, for translating the fixed samples.
        sample_y_list = []
        for i in range(self.y_dim):
            sample_y = self.one_hot(
                torch.ones(sample_x.size(0)) * i, self.y_dim)
            sample_y_list.append(self.to_var(sample_y, volatile=True))

        # Learning rate caches used for linear decay
        d_lr = self.d_lr
        enc_lr = self.enc_lr
        dec_lr = self.dec_lr

        # Resume epoch counter from checkpoint tag '<epoch>_<iter>' if present
        if self.trained_model:
            start = int(self.trained_model.split('_')[0])
        else:
            start = 0

        # Start training
        start_time = time.time()
        for e in range(start, self.num_epochs):
            for i, (real_image, real_landmark) in enumerate(self.data_loader):
                # real_x: real image and real_l: conditional side image (landmark heatmap)
                real_x = real_image[0]
                real_label = real_image[1]
                real_l = real_landmark[0]

                # Sample fake target labels by permuting the batch's real labels
                rand_idx = torch.randperm(real_label.size(0))
                fake_label = real_label[rand_idx]
                real_y = self.one_hot(real_label, self.y_dim)
                fake_y = self.one_hot(fake_label, self.y_dim)

                # Convert tensors to Variables (moves to GPU when available)
                real_x = self.to_var(real_x)
                real_l = self.to_var(real_l)
                real_y = self.to_var(real_y)
                fake_y = self.to_var(fake_y)
                real_label = self.to_var(real_label)
                fake_label = self.to_var(fake_label)

                # ================== Train Discriminator ================== #
                # Input images (original image + side images) are concatenated
                # along the channel dimension before being fed to D.
                src_output, cls_output = self.D(torch.cat([real_x, real_l], 1))
                # WGAN critic loss on real samples (maximize score -> minimize -mean)
                d_loss_real = -torch.mean(src_output)
                d_loss_cls = F.cross_entropy(cls_output, real_label)

                # Periodically report D's classification accuracy on real images
                if (i + 1) % self.log_step == 0:
                    accuracies = self.calculate_accuracy(
                        cls_output, real_label)
                    log = [
                        "{:.2f}".format(acc)
                        for acc in accuracies.data.cpu().numpy()
                    ]
                    print('Recognition Acc: ')
                    print(log)

                # Generate outputs and compute loss with fake generated images.
                enc_feat = self.Enc(torch.cat([real_x, real_l], 1))
                fake_x, fake_l = self.Dec(enc_feat, fake_y)
                # Re-wrapping .data detaches the fakes so G gets no gradient here.
                fake_x = Variable(fake_x.data)
                fake_l = Variable(fake_l.data)
                src_output, cls_output = self.D(torch.cat([fake_x, fake_l], 1))
                d_loss_fake = torch.mean(src_output)

                # Discriminator losses (adversarial + weighted classification)
                d_loss = self.lambda_cls * d_loss_cls + d_loss_real + d_loss_fake
                self.reset()
                d_loss.backward()
                self.d_optimizer.step()

                # Compute gradient penalty loss (WGAN-GP): penalize the critic's
                # gradient norm at random interpolations of real/fake inputs.
                real = torch.cat([real_x, real_l], 1)
                fake = torch.cat([fake_x, fake_l], 1)
                alpha = torch.rand(real_x.size(0), 1, 1,
                                   1).cuda().expand_as(real)
                interpolated = Variable(alpha * real.data +
                                        (1 - alpha) * fake.data,
                                        requires_grad=True)
                output, cls_output = self.D(interpolated)

                grad = torch.autograd.grad(outputs=output,
                                           inputs=interpolated,
                                           grad_outputs=torch.ones(
                                               output.size()).cuda(),
                                           retain_graph=True,
                                           create_graph=True,
                                           only_inputs=True)[0]
                grad = grad.view(grad.size(0), -1)
                grad_l2norm = torch.sqrt(torch.sum(grad**2, dim=1))
                d_loss_gp = torch.mean((grad_l2norm - 1)**2)

                # Gradient penalty is applied as a separate optimizer step.
                d_loss = self.lambda_gp * d_loss_gp
                self.reset()
                d_loss.backward()
                self.d_optimizer.step()

                # Logging (.data[0] is the pre-0.4 scalar accessor)
                loss = {}
                loss['D/loss_real'] = d_loss_real.data[0]
                loss['D/loss_fake'] = d_loss_fake.data[0]
                loss['D/loss_cls'] = d_loss_cls.data[0]
                loss['D/loss_gp'] = d_loss_gp.data[0]

                # ================== Train Encoder-Decoder networks ================== #
                # Generator is updated once every d_train_repeat D iterations.
                if (i + 1) % self.d_train_repeat == 0:
                    # Original-to-target and target-to-original domain
                    enc_feat = self.Enc(torch.cat([real_x, real_l], 1))
                    fake_x, fake_l = self.Dec(enc_feat, fake_y)
                    src_output, cls_output = self.D(
                        torch.cat([fake_x, fake_l], 1))
                    g_loss_fake = -torch.mean(src_output)
                    #rec_feat = self.Enc(fake_x)
                    rec_feat = self.Enc(torch.cat([fake_x, fake_l], 1))
                    rec_x, rec_l = self.Dec(rec_feat, real_y)

                    # Bidirectional (cycle) loss of the images
                    g_loss_rec_x = torch.mean(torch.abs(real_x - rec_x))
                    g_loss_rec_l = torch.mean(torch.abs(real_l - rec_l))
                    # Bidirectional loss of the latent feature
                    g_loss_feature = torch.mean(torch.abs(enc_feat - rec_feat))
                    # Identity loss of the images (fake should stay close to real)
                    g_loss_identity_x = torch.mean(torch.abs(real_x - fake_x))
                    g_loss_identity_l = torch.mean(torch.abs(real_l - fake_l))
                    # Attribute classification loss for the fake generated images
                    g_loss_cls = F.cross_entropy(cls_output, fake_label)

                    # Backward + Optimize (generator (encoder-decoder) losses);
                    # the decoder is deliberately stepped twice per encoder step.
                    g_loss = g_loss_fake + self.lambda_bi * g_loss_rec_x + self.lambda_bi * g_loss_rec_l + self.lambda_bi * g_loss_feature + self.lambda_id * g_loss_identity_x + self.lambda_id * g_loss_identity_l + self.lambda_cls * g_loss_cls
                    self.reset()
                    g_loss.backward()
                    self.enc_optimizer.step()
                    self.dec_optimizer.step()
                    self.dec_optimizer.step()

                    # Logging Generator losses
                    loss['G/loss_feature'] = g_loss_feature.data[0]
                    loss['G/loss_identity_x'] = g_loss_identity_x.data[0]
                    loss['G/loss_identity_l'] = g_loss_identity_l.data[0]
                    loss['G/loss_rec_x'] = g_loss_rec_x.data[0]
                    loss['G/loss_rec_l'] = g_loss_rec_l.data[0]
                    loss['G/loss_fake'] = g_loss_fake.data[0]
                    loss['G/loss_cls'] = g_loss_cls.data[0]

                # Print out log and push scalars to tensorboard
                if (i + 1) % self.log_step == 0:
                    elapsed = time.time() - start_time
                    elapsed = str(datetime.timedelta(seconds=elapsed))
                    log = "Elapsed [{}], Epoch [{}/{}], Iter [{}/{}]".format(
                        elapsed, e + 1, self.num_epochs, i + 1,
                        iters_per_epoch)
                    for tag, value in loss.items():
                        log += ", {}: {:.4f}".format(tag, value)
                    print(log)
                    for tag, value in loss.items():
                        self.logger.scalar_summary(tag, value,
                                                   e * iters_per_epoch + i + 1)

                # Synthesize the fixed samples for every target class and save a grid
                if (i + 1) % self.sample_step == 0:
                    fake_image_list = [sample_x]
                    for sample_y in sample_y_list:
                        enc_feat = self.Enc(torch.cat([sample_x, sample_l], 1))
                        sample_result, sample_landmark = self.Dec(
                            enc_feat, sample_y)
                        fake_image_list.append(sample_result)
                    # Concatenate along width so each row shows all translations
                    fake_images = torch.cat(fake_image_list, dim=3)
                    save_image(self.denorm(fake_images.data),
                               os.path.join(
                                   self.sample_path,
                                   '{}_{}_fake.png'.format(e + 1, i + 1)),
                               nrow=1,
                               padding=0)
                    print('Generated images and saved into {}..!'.format(
                        self.sample_path))

                # Save checkpoints tagged '<epoch>_<iter>'
                if (i + 1) % self.model_save_step == 0:
                    torch.save(
                        self.Enc.state_dict(),
                        os.path.join(self.model_path,
                                     '{}_{}_Enc.pth'.format(e + 1, i + 1)))
                    torch.save(
                        self.Dec.state_dict(),
                        os.path.join(self.model_path,
                                     '{}_{}_Dec.pth'.format(e + 1, i + 1)))
                    torch.save(
                        self.D.state_dict(),
                        os.path.join(self.model_path,
                                     '{}_{}_D.pth'.format(e + 1, i + 1)))

            # Linearly decay learning rates over the last num_epochs_decay epochs.
            # NOTE(review): the message below reports enc_lr and d_lr only,
            # although dec_lr is decayed as well.
            if (e + 1) > (self.num_epochs - self.num_epochs_decay):
                d_lr -= (self.d_lr / float(self.num_epochs_decay))
                enc_lr -= (self.enc_lr / float(self.num_epochs_decay))
                dec_lr -= (self.dec_lr / float(self.num_epochs_decay))
                self.update_lr(enc_lr, dec_lr, d_lr)
                print('Decay learning rate to enc_lr: {}, d_lr: {}.'.format(
                    enc_lr, d_lr))

    def test(self):
        """Generating face images owning target attributes (desired expressions)."""
        # Load trained generator (Enc/Dec) for checkpoint tag ``self.test_model``
        Enc_path = os.path.join(self.model_path,
                                '{}_Enc.pth'.format(self.test_model))
        Dec_path = os.path.join(self.model_path,
                                '{}_Dec.pth'.format(self.test_model))
        self.Enc.load_state_dict(torch.load(Enc_path))
        self.Dec.load_state_dict(torch.load(Dec_path))
        self.Enc.eval()
        self.Dec.eval()

        data_loader = self.face_data_loader
        for i, (real_image, real_landmark) in enumerate(data_loader):
            org_c = real_image[1]
            real_x = real_image[0]
            real_l = real_landmark[0]
            real_x = self.to_var(real_x, volatile=True)
            real_l = self.to_var(real_l, volatile=True)

            # One one-hot target vector per class
            target_y_list = []
            for j in range(self.y_dim):
                target_y = self.one_hot(
                    torch.ones(real_x.size(0)) * j, self.y_dim)
                target_y_list.append(self.to_var(target_y, volatile=True))

            # Target image generation: original + one translation per class,
            # concatenated along the width for side-by-side comparison.
            fake_image_list = [real_x]
            for target_y in target_y_list:
                enc_feat = self.Enc(torch.cat([real_x, real_l], 1))
                sample_result, sample_landmark = self.Dec(enc_feat, target_y)
                fake_image_list.append(sample_result)
            fake_images = torch.cat(fake_image_list, dim=3)
            save_path = os.path.join(self.test_path,
                                     '{}_fake.png'.format(i + 1))
            save_image(self.denorm(fake_images.data),
                       save_path,
                       nrow=1,
                       padding=0)
            print('Generated images and saved into "{}"..!'.format(save_path))
# One training epoch: step the LR schedule, run all batches, log parameter
# histograms, and checkpoint every 10 epochs.  All names (lr_scheduler, loader,
# net, optimizer, loss_fn, logger, tb_logger, epoch, niter_per_epoch, opj,
# mkdir_r, dirname, CHECKPOINTS_PATH, cur_time) are defined elsewhere in the
# surrounding module/function — TODO confirm against the full file.
lr_scheduler.step()
for i_batch, sampled_batch in enumerate(loader):
    data, target = sampled_batch
    # Move inputs to GPU when available (pre-0.4 Variable API)
    if torch.cuda.is_available():
        data, target = Variable(data).cuda(), Variable(target).cuda()
    else:
        data, target = Variable(data), Variable(target)
    # Standard optimization step: forward, loss, backward, update
    optimizer.zero_grad()
    pred = net(data)
    # Target is cast to float — presumably loss_fn is a regression/BCE-style
    # criterion; verify against its definition.
    loss = loss_fn(pred, target.float())
    loss.backward()
    optimizer.step()
    # Per-batch console + tensorboard scalar logging (.data[0] = old scalar accessor)
    logger.info('[epoch: {}, batch: {}] Training loss: {}'.format(
        epoch, i_batch, loss.data[0]))
    tb_logger.scalar_summary('loss', loss.data[0],
                             epoch * niter_per_epoch + i_batch + 1)
# (2) Log values and gradients of the parameters (histogram)
for tag, value in net.named_parameters():
    tag = tag.replace('.', '/')
    tb_logger.histo_summary(tag, value.data.cpu().numpy(), epoch + 1)
    tb_logger.histo_summary(tag + '/grad',
                            value.grad.data.cpu().numpy(), epoch + 1)
# Save a checkpoint every 10th epoch, creating the directory if needed
if (epoch + 1) % 10 == 0:
    cp_path = opj(CHECKPOINTS_PATH, cur_time, 'model_%s' % epoch)
    mkdir_r(dirname(cp_path))
    torch.save(net.state_dict(), cp_path)
# Define loss function criterion = nn.NLLLoss() # Keep track of time elapsed and running averages start = time.time() # Set configuration for using Tensorboard logger = Logger('graphs') for step in range(step, final_steps + 1): # Get training data for this cycle inputs, targets, len_inputs, len_targets = train_corpus.next_batch() input_variable = Variable(torch.LongTensor(inputs), requires_grad=False) target_variable = Variable(torch.LongTensor(targets), requires_grad=False) if Config.use_cuda: input_variable = input_variable.cuda() target_variable = target_variable.cuda() # Run the train function loss = train(input_variable, len_inputs, target_variable, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion) # Keep track of loss logger.scalar_summary('loss', loss, step) if step % print_every == 0: print('%s: %s (%d %d%%)' % (step, time_since(start, 1. * step / final_steps), step, step / final_steps * 100)) if step % save_every == 0: save_state(encoder, decoder, encoder_optimizer, decoder_optimizer, step)
# Compute accuracy _, argmax = torch.max(outputs, 1) accuracy = (labels == argmax.squeeze()).float().mean() if (step + 1) % 100 == 0: print('Step [{}/{}], Loss: {:.4f}, Acc: {:.2f}'.format( step + 1, total_step, loss.item(), accuracy.item())) # ================================================================== # # Tensorboard Logging # # ================================================================== # # 1. Log scalar values (scalar summary) info = {'loss': loss.item(), 'accuracy': accuracy.item()} for tag, value in info.items(): logger.scalar_summary(tag, value, step + 1) # 2. Log values and gradients of the parameters (histogram summary) for tag, value in model.named_parameters(): tag = tag.replace('.', '/') logger.histo_summary(tag, value.data.cpu().numpy(), step + 1) logger.histo_summary(tag + '/grad', value.grad.data.cpu().numpy(), step + 1) # 3. Log training images (image summary) info = {'images': images.view(-1, 28, 28)[:10].cpu().numpy()} # [:10]:取前 9 张? for tag, images in info.items(): logger.image_summary(tag, images, step + 1)
class Solver(object):
    """StarGAN-style trainer/tester for a single multi-attribute dataset.

    Owns the Generator/Discriminator pair, their Adam optimizers, and all
    logging/checkpointing configuration.  Expects ``data_loaders`` to be a
    dict with 'train', 'val' and 'test' loaders whose dataset exposes
    ``class_names``.
    """

    def __init__(self, data_loaders, config):
        # Data loaders keyed by split name ('train'/'val'/'test')
        self.data_loaders = data_loaders
        self.attrs = config.attrs

        # Model hyper-parameters
        # Number of attribute classes, taken from the training dataset itself
        self.c_dim = len(data_loaders['train'].dataset.class_names)
        self.c2_dim = config.c2_dim
        self.image_size = config.image_size
        self.g_conv_dim = config.g_conv_dim       # generator base channel width
        self.d_conv_dim = config.d_conv_dim       # discriminator base channel width
        self.g_repeat_num = config.g_repeat_num
        self.d_repeat_num = config.d_repeat_num
        self.d_train_repeat = config.d_train_repeat  # D iterations per G update

        # Hyper-parameters (loss weights and Adam settings)
        self.lambda_cls = config.lambda_cls
        self.lambda_rec = config.lambda_rec
        self.lambda_gp = config.lambda_gp
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.beta1 = config.beta1
        self.beta2 = config.beta2

        # Training settings
        self.num_epochs = config.num_epochs
        self.num_epochs_decay = config.num_epochs_decay
        self.num_iters = config.num_iters
        self.num_iters_decay = config.num_iters_decay
        self.batch_size = config.batch_size
        self.use_tensorboard = config.use_tensorboard
        # Checkpoint tag '<epoch>_<iter>'; falsy means train from scratch
        self.pretrained_model = config.pretrained_model
        self.pretrained_model_path = config.pretrained_model_path

        # Test settings
        self.test_model = config.test_model

        # Output paths, grouped under output_path/<kind>/<run-name>
        self.log_path = os.path.join(config.output_path, 'logs',
                                     config.output_name)
        self.sample_path = os.path.join(config.output_path, 'samples',
                                        config.output_name)
        self.model_save_path = os.path.join(config.output_path, 'models',
                                            config.output_name)
        self.result_path = os.path.join(config.output_path, 'results',
                                        config.output_name)

        # Step sizes (in iterations) for logging / sampling / checkpointing
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step
        self.num_val_imgs = config.num_val_imgs   # fixed validation images for previews

        # Build networks; attach tensorboard only when enabled
        self.build_model()
        if self.use_tensorboard:
            self.build_tensorboard()

        # Resume from a saved checkpoint when one was supplied
        if self.pretrained_model:
            self.load_pretrained_model()

    def build_model(self):
        """Instantiate G and D with their optimizers, moving them to GPU if available."""
        # self.G = UnetGenerator(3+self.c_dim)
        self.G = Generator(self.g_conv_dim, self.c_dim, self.g_repeat_num,
                           self.image_size)
        self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim,
                               self.d_repeat_num)

        # Optimizers (positional args: lr, betas)
        self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr,
                                            [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr,
                                            [self.beta1, self.beta2])

        # Print networks
        self.print_network(self.G, 'G')
        self.print_network(self.D, 'D')

        if torch.cuda.is_available():
            self.G.cuda()
            self.D.cuda()

    def print_network(self, model, name):
        """Log a network's structure and its total parameter count."""
        num_params = 0
        for p in model.parameters():
            num_params += p.numel()
        print_logger.info('{} - {} - Number of parameters: {}'.format(
            name, model, num_params))

    def load_pretrained_model(self):
        """Load G and D weights for checkpoint tag ``self.pretrained_model``."""
        self.G.load_state_dict(
            torch.load(
                os.path.join(self.pretrained_model_path,
                             '{}_G.pth'.format(self.pretrained_model))))
        self.D.load_state_dict(
            torch.load(
                os.path.join(self.pretrained_model_path,
                             '{}_D.pth'.format(self.pretrained_model))))
        print_logger.info('loaded trained models (step: {})..!'.format(
            self.pretrained_model))

    def build_tensorboard(self):
        """Create the tensorboard_logger writer pointed at ``self.log_path``."""
        from tensorboard_logger import Logger
        self.logger = Logger(self.log_path)

    def update_lr(self, g_lr, d_lr):
        """Overwrite the learning rate of every param group of both optimizers."""
        for param_group in self.g_optimizer.param_groups:
            param_group['lr'] = g_lr
        for param_group in self.d_optimizer.param_groups:
            param_group['lr'] = d_lr

    def reset_grad(self):
        """Zero the gradients of both optimizers."""
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()

    def to_var(self, x, grad=True):
        """Wrap a tensor in a Variable, moving it to GPU first when available."""
        if torch.cuda.is_available():
            x = x.cuda()
        return Variable(x, requires_grad=grad)

    def denorm(self, x):
        """Map images from [-1, 1] back to [0, 1] (clamped, in place on the result)."""
        out = (x + 1) / 2
        return out.clamp_(0, 1)

    def threshold(self, x):
        """Binarize probabilities at 0.5 (returns a new tensor; input untouched)."""
        x = x.clone()
        x = (x >= 0.5).float()
        return x

    def compute_accuracy(self, x, y):
        """Per-attribute accuracy (percent) of multi-label logits ``x`` vs targets ``y``."""
        x = F.sigmoid(x)
        predicted = self.threshold(x)
        correct = (predicted == y).float()
        # Mean over the batch dimension -> one accuracy per attribute
        accuracy = torch.mean(correct, dim=0) * 100.0
        return accuracy

    def one_hot(self, labels, dim):
        """Convert label indices to one-hot vector."""
        batch_size = labels.size(0)
        out = torch.zeros(batch_size, dim)
        out[np.arange(batch_size), labels.long()] = 1
        return out

    def make_data_labels(self, real_c):
        """Generate domain labels for dataset for debugging/testing.

        Returns a list of c_dim Variables; the i-th entry is ``real_c`` with
        its first c_dim entries replaced by the one-hot vector for class i.
        """
        # One one-hot template per class
        y = []
        for dim in range(self.c_dim):
            t = [0] * self.c_dim
            t[dim] = 1
            y.append(torch.FloatTensor(t))

        fixed_c_list = []
        for i in range(self.c_dim):
            fixed_c = real_c.clone()
            for c in fixed_c:
                c[:self.c_dim] = y[i]
            fixed_c_list.append(self.to_var(fixed_c, grad=False))
        return fixed_c_list

    def train(self):
        """Train StarGAN within a single dataset."""
        # The number of iterations per epoch
        data_loader = self.data_loaders['train']
        iters_per_epoch = len(data_loader)

        # Collect fixed validation images (and their labels) for periodic previews
        fixed_x = []
        real_c = []
        num_fixed_imgs = self.num_val_imgs
        for i in range(num_fixed_imgs):
            images, labels = self.data_loaders['val'].dataset.__getitem__(i)
            fixed_x.append(images.unsqueeze(0))
            real_c.append(labels.unsqueeze(0))

        # Fixed inputs and target domain labels for debugging
        fixed_x = torch.cat(fixed_x, dim=0)
        fixed_x = self.to_var(fixed_x, grad=False)
        real_c = torch.cat(real_c, dim=0)
        fixed_c_list = self.make_data_labels(real_c)

        # lr cache for decaying
        g_lr = self.g_lr
        d_lr = self.d_lr

        # Resume epoch counter from checkpoint tag '<epoch>_<iter>' if present
        if self.pretrained_model:
            start = int(self.pretrained_model.split('_')[0])
        else:
            start = 0

        # Start training
        start_time = time.time()
        for e in range(start, self.num_epochs):
            for i, (real_x, real_label) in enumerate(data_loader):
                # Generate fake labels randomly (target domain labels) by
                # permuting the batch's real labels
                rand_idx = torch.randperm(real_label.size(0))
                fake_label = real_label[rand_idx]

                real_c = real_label.clone()
                fake_c = fake_label.clone()

                # Convert tensor to variable
                real_x = self.to_var(real_x)
                real_c = self.to_var(real_c)  # input for the generator
                fake_c = self.to_var(fake_c)
                real_label = self.to_var(
                    real_label
                )  # this is same as real_c if dataset == 'CelebA'
                fake_label = self.to_var(fake_label)

                # ================== Train D ================== #
                # Compute loss with real images (WGAN critic + multi-label BCE)
                out_src, out_cls = self.D(real_x)
                d_loss_real = -torch.mean(out_src)
                # Sum-then-divide = per-sample mean classification loss
                d_loss_cls = F.binary_cross_entropy_with_logits(
                    out_cls, real_label, size_average=False) / real_x.size(0)

                # Compute classification accuracy of the discriminator
                if (i + 1) % (self.log_step * 10) == 0:
                    accuracies = self.compute_accuracy(out_cls, real_label)
                    log = [
                        "{}: {:.2f}".format(attr, acc)
                        for (attr, acc) in zip(
                            data_loader.dataset.class_names,
                            accuracies.data.cpu().numpy())
                    ]
                    print_logger.info(
                        'Discriminator Accuracy: {}'.format(log))

                # Compute loss with fake images; re-wrapping .data detaches
                # the fake so G receives no gradient from D's update.
                fake_x = self.G(real_x, fake_c)
                fake_x = Variable(fake_x.data)
                out_src, out_cls = self.D(fake_x)
                d_loss_fake = torch.mean(out_src)

                # Backward + Optimize
                d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls
                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()

                loss = {}
                loss['D/loss_real'] = d_loss_real.data.item()
                loss['D/loss_fake'] = d_loss_fake.data.item()
                loss['D/loss_cls'] = d_loss_cls.data.item()
                loss['D/loss'] = d_loss.data.item()

                # Compute gradient penalty (WGAN-GP) on random real/fake
                # interpolations; applied as a separate optimizer step.
                alpha = torch.rand(real_x.size(0), 1, 1,
                                   1).cuda().expand_as(real_x)
                interpolated = Variable(alpha * real_x.data +
                                        (1 - alpha) * fake_x.data,
                                        requires_grad=True)
                out, out_cls = self.D(interpolated)

                grad = torch.autograd.grad(outputs=out,
                                           inputs=interpolated,
                                           grad_outputs=torch.ones(
                                               out.size()).cuda(),
                                           retain_graph=True,
                                           create_graph=True,
                                           only_inputs=True)[0]

                grad = grad.view(grad.size(0), -1)
                grad_l2norm = torch.sqrt(torch.sum(grad**2, dim=1))
                d_loss_gp = torch.mean((grad_l2norm - 1)**2)

                # Backward + Optimize
                d_loss = self.lambda_gp * d_loss_gp
                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()

                # Logging
                loss['D/loss_gp'] = d_loss_gp.data.item()

                # ================== Train G ================== #
                # G is updated once every d_train_repeat D iterations.
                if (i + 1) % self.d_train_repeat == 0:
                    # Original-to-target and target-to-original domain
                    fake_x = self.G(real_x, fake_c)
                    rec_x = self.G(fake_x, real_c)

                    # Compute losses: adversarial, cycle-reconstruction,
                    # (unused) identity L1, and target-attribute classification
                    out_src, out_cls = self.D(fake_x)
                    g_loss_fake = -torch.mean(out_src)
                    g_loss_rec = torch.mean(torch.abs(real_x - rec_x))
                    g_loss_l1 = torch.mean(torch.abs(real_x - fake_x))
                    g_loss_cls = F.binary_cross_entropy_with_logits(
                        out_cls, fake_label,
                        size_average=False) / fake_x.size(0)

                    # Backward + Optimize (g_loss_l1 deliberately left out)
                    g_loss = g_loss_fake + self.lambda_rec * g_loss_rec + self.lambda_cls * g_loss_cls  #+ g_loss_l1 * self.lambda_rec
                    self.reset_grad()
                    g_loss.backward()
                    self.g_optimizer.step()

                    # Logging
                    loss['G/loss_fake'] = g_loss_fake.data.item()
                    loss['G/loss_rec'] = g_loss_rec.data.item()
                    loss['G/loss_cls'] = g_loss_cls.data.item()
                    # loss['G/loss_l1'] = g_loss_l1.data.item()
                    loss['G/loss'] = g_loss.data.item()

                # Print out log info
                if (i + 1) % self.log_step == 0:
                    elapsed = time.time() - start_time
                    elapsed = str(datetime.timedelta(seconds=elapsed))
                    log = "Elapsed [{}], Epoch [{}/{}], Iter [{}/{}]".format(
                        elapsed, e + 1, self.num_epochs, i + 1,
                        iters_per_epoch)
                    for tag, value in loss.items():
                        log += ", {}: {:.4f}".format(tag, value)
                    print_logger.info(log)

                    if self.use_tensorboard:
                        for tag, value in loss.items():
                            self.logger.scalar_summary(
                                tag, value, e * iters_per_epoch + i + 1)

                # Translate fixed images for debugging
                if (i + 1) % self.sample_step == 0:
                    fake_image_list = [fixed_x]
                    for fixed_c in fixed_c_list:
                        gen_imgs = self.G(fixed_x, fixed_c)
                        fake_image_list.append(gen_imgs)
                    # fake_images = torch.cat(fake_image_list, dim=3)
                    # save_image(self.denorm(fake_images.data.cpu()),
                    # os.path.join(self.sample_path, '{}_{}_fake.png'.format(e+1, i+1)),nrow=1, padding=0)
                    # print_logger.info('Translated images and saved into {}..!'.format(self.sample_path))
                    if self.use_tensorboard:
                        # Rearrange to one wide strip per fixed image:
                        # stack translations, then concatenate along width.
                        tb_imgs = [t.unsqueeze(0) for t in fake_image_list]
                        tb_imgs = torch.cat(tb_imgs)
                        tb_imgs = tb_imgs.permute(1, 0, 2, 3, 4)
                        tb_imgs_list = torch.unbind(tb_imgs, dim=0)
                        tb_imgs_list = [
                            torch.cat(torch.unbind(t, dim=0), dim=2)
                            for t in tb_imgs_list
                        ]
                        self.logger.image_summary(
                            'fixed_imgs', tb_imgs_list,
                            e * iters_per_epoch + i + 1)

                # Save model checkpoints tagged '<epoch>_<iter>'
                if (i + 1) % self.model_save_step == 0:
                    torch.save(
                        self.G.state_dict(),
                        os.path.join(self.model_save_path,
                                     '{}_{}_G.pth'.format(e + 1, i + 1)))
                    torch.save(
                        self.D.state_dict(),
                        os.path.join(self.model_save_path,
                                     '{}_{}_D.pth'.format(e + 1, i + 1)))

            # Linearly decay learning rates over the last num_epochs_decay epochs
            if (e + 1) > (self.num_epochs - self.num_epochs_decay):
                g_lr -= (self.g_lr / float(self.num_epochs_decay))
                d_lr -= (self.d_lr / float(self.num_epochs_decay))
                self.update_lr(g_lr, d_lr)
                print_logger.info(
                    'Decay learning rate to g_lr: {}, d_lr: {}.'.format(
                        g_lr, d_lr))

    def test(self):
        """Facial attribute transfer on CelebA or facial expression synthesis on RaFD."""
        # Load trained parameters for checkpoint tag ``self.test_model``
        G_path = os.path.join(self.model_save_path,
                              '{}_G.pth'.format(self.test_model))
        self.G.load_state_dict(torch.load(G_path))
        self.G.eval()

        data_loader = self.data_loaders['test']
        for i, (real_x, org_c) in enumerate(data_loader):
            real_x = self.to_var(real_x, grad=False)
            target_c_list = self.make_data_labels(org_c)

            # Start translations: original + one translation per class,
            # concatenated along the width for side-by-side comparison.
            fake_image_list = [real_x]
            for target_c in target_c_list:
                fake_image_list.append(self.G(real_x, target_c))
            fake_images = torch.cat(fake_image_list, dim=3)
            save_path = os.path.join(self.result_path,
                                     '{}_fake.png'.format(i + 1))
            save_image(self.denorm(fake_images.data),
                       save_path,
                       nrow=1,
                       padding=0)
            print_logger.info(
                'Translated test images and saved into "{}"..!'.format(
                    save_path))
def main():
    """Entry point for adversarial teacher-student distillation training.

    Parses CLI options, builds the teacher/student/discriminator models and
    their criteria/optimizers, optionally resumes from a checkpoint, then
    alternates training and validation for opt.epochs epochs, checkpointing
    after every epoch and tracking the best student top-1 accuracy.
    """
    global opt, best_studentprec1
    cudnn.benchmark = True
    opt = parser.parse_args()
    opt.logdir = opt.logdir + '/' + opt.name
    logger = Logger(opt.logdir)
    print(opt)

    best_studentprec1 = 0.0

    print('Loading models...')
    teacher = init.load_model(opt, 'teacher')
    student = init.load_model(opt, 'student')
    discriminator = init.load_model(opt, 'discriminator')
    teacher = init.setup(teacher, opt, 'teacher')
    student = init.setup(student, opt, 'student')
    discriminator = init.setup(discriminator, opt, 'discriminator')
    #Write the code to classify it in the 11th class
    print(teacher)
    print(student)
    print(discriminator)

    # Loss terms: adversarial (BCE), feature similarity (L1), derivative
    # matching (smooth L1), and the discriminator's real/fake classifier.
    advCriterion = nn.BCELoss().cuda()
    similarityCriterion = nn.L1Loss().cuda()
    derivativeCriterion = nn.SmoothL1Loss().cuda()
    discclassifyCriterion = nn.CrossEntropyLoss(size_average=True).cuda()

    studOptim = getOptim(opt, student, 'student')
    discrecOptim = getOptim(opt, discriminator, 'discriminator')

    trainer = train.Trainer(student, teacher, discriminator,
                            discclassifyCriterion, advCriterion,
                            similarityCriterion, derivativeCriterion,
                            studOptim, discrecOptim, opt, logger)
    validator = train.Validator(student, teacher, discriminator, opt, logger)

    if opt.resume:
        if os.path.isfile(opt.resume):
            # BUGFIX: this branch previously referenced undefined names
            # `model` and `optimizer`, raising NameError whenever a resume
            # checkpoint existed.  Resume the objects that actually exist
            # here: the student network and its optimizer.
            # NOTE(review): init.resumer's exact signature/return order is
            # assumed to be (model, optimizer, opt, best_prec) — confirm.
            student, studOptim, opt, best_studentprec1 = init.resumer(
                opt, student, studOptim)
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))

    dataloader = init_data.load_data(opt)
    train_loader = dataloader.train_loader
    val_loader = dataloader.val_loader

    for epoch in range(opt.start_epoch, opt.epochs):
        # Step both learning-rate schedules before the epoch starts
        utils.adjust_learning_rate(opt, studOptim, epoch)
        utils.adjust_learning_rate(opt, discrecOptim, epoch)
        print("Starting epoch number:", epoch + 1, "Learning rate:",
              studOptim.param_groups[0]["lr"])

        if not opt.testOnly:
            trainer.train(train_loader, epoch, opt)
            if opt.tensorboard:
                logger.scalar_summary('learning_rate', opt.lr, epoch)

        # Validate every epoch and keep the best student accuracy seen so far
        student_prec1 = validator.validate(val_loader, epoch, opt)
        best_studentprec1 = max(student_prec1, best_studentprec1)
        init.save_checkpoint(opt, teacher, student, discriminator, studOptim,
                             discrecOptim, student_prec1, epoch)
        print('Best accuracy: [{0:.3f}]\t'.format(best_studentprec1))