def training_procedure(FLAGS):
    """Train a cycle-consistent VAE with separate style and class latent codes.

    Each iteration runs (A) a forward auto-encoding cycle on a paired batch
    (style codes swapped against class codes), optionally followed by a GAN
    step on the reconstructions, then (B) a reverse cycle that ties the style
    code recovered from two decodings of the same sampled style, optionally
    followed by its own GAN step.  Progress is written to a TSV log file,
    TensorBoard, checkpoints, and sample image grids.

    Args:
        FLAGS: hyper-parameter namespace (latent dims, batch size, loss
            coefficients, file names, cuda/forward_gan/reverse_gan switches,
            start/end epoch, learning rate and Adam betas).
    """
    # ---- model definition ----
    encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    encoder.apply(weights_init)

    decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    decoder.apply(weights_init)

    discriminator = Discriminator()
    discriminator.apply(weights_init)

    # load saved models if load_saved flag is true
    if FLAGS.load_saved:
        raise Exception('This is not implemented')
        # NOTE(review): unreachable -- the raise above always fires.  Kept so
        # the intended resume logic is visible for whoever implements it.
        encoder.load_state_dict(torch.load(os.path.join('checkpoints', FLAGS.encoder_save)))
        decoder.load_state_dict(torch.load(os.path.join('checkpoints', FLAGS.decoder_save)))

    # ---- variable definition ----
    # Pre-allocated input buffers, refilled in-place (copy_) every iteration.
    X_1 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels, FLAGS.image_size, FLAGS.image_size)
    X_2 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels, FLAGS.image_size, FLAGS.image_size)
    X_3 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels, FLAGS.image_size, FLAGS.image_size)

    style_latent_space = torch.FloatTensor(FLAGS.batch_size, FLAGS.style_dim)

    # ---- loss definitions ----
    cross_entropy_loss = nn.CrossEntropyLoss()
    adversarial_loss = nn.BCELoss()

    # add option to run on GPU
    if FLAGS.cuda:
        encoder.cuda()
        decoder.cuda()
        discriminator.cuda()

        cross_entropy_loss.cuda()
        adversarial_loss.cuda()

        X_1 = X_1.cuda()
        X_2 = X_2.cuda()
        X_3 = X_3.cuda()
        style_latent_space = style_latent_space.cuda()

    # ---- optimizer and scheduler definition ----
    auto_encoder_optimizer = optim.Adam(
        list(encoder.parameters()) + list(decoder.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2)
    )

    reverse_cycle_optimizer = optim.Adam(
        list(encoder.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2)
    )

    generator_optimizer = optim.Adam(
        list(decoder.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2)
    )

    discriminator_optimizer = optim.Adam(
        list(discriminator.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2)
    )

    # divide the learning rate by a factor of 10 after 80 epochs
    auto_encoder_scheduler = optim.lr_scheduler.StepLR(auto_encoder_optimizer, step_size=80, gamma=0.1)
    reverse_cycle_scheduler = optim.lr_scheduler.StepLR(reverse_cycle_optimizer, step_size=80, gamma=0.1)
    generator_scheduler = optim.lr_scheduler.StepLR(generator_optimizer, step_size=80, gamma=0.1)
    discriminator_scheduler = optim.lr_scheduler.StepLR(discriminator_optimizer, step_size=80, gamma=0.1)

    # Used later to define discriminator ground truths
    Tensor = torch.cuda.FloatTensor if FLAGS.cuda else torch.FloatTensor

    # ---- training ----
    if torch.cuda.is_available() and not FLAGS.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")

    if not os.path.exists('checkpoints'):
        os.makedirs('checkpoints')

    if not os.path.exists('reconstructed_images'):
        os.makedirs('reconstructed_images')

    # load_saved is false when training is started from 0th iteration,
    # so a fresh log file with headers is written only then.
    if not FLAGS.load_saved:
        with open(FLAGS.log_file, 'w') as log:
            headers = ['Epoch', 'Iteration', 'Reconstruction_loss', 'KL_divergence_loss', 'Reverse_cycle_loss']
            if FLAGS.forward_gan:
                headers.extend(['Generator_forward_loss', 'Discriminator_forward_loss'])
            if FLAGS.reverse_gan:
                headers.extend(['Generator_reverse_loss', 'Discriminator_reverse_loss'])
            log.write('\t'.join(headers) + '\n')

    # load data set and create data loader instance
    print('Loading CIFAR paired dataset...')
    paired_cifar = CIFAR_Paired(root='cifar', download=True, train=True, transform=transform_config)
    # cycle() makes the loader endless so next() never raises StopIteration.
    loader = cycle(DataLoader(paired_cifar, batch_size=FLAGS.batch_size, shuffle=True, num_workers=0, drop_last=True))

    # Save a batch of images to use for visualization
    image_sample_1, image_sample_2, _ = next(loader)
    image_sample_3, _, _ = next(loader)

    # initialize summary writer
    writer = SummaryWriter()

    for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch):
        print('')
        print('Epoch #' + str(epoch) + '..........................................................................')

        for iteration in range(int(len(paired_cifar) / FLAGS.batch_size)):
            # Adversarial ground truths
            valid = Variable(Tensor(FLAGS.batch_size, 1).fill_(1.0), requires_grad=False)
            fake = Variable(Tensor(FLAGS.batch_size, 1).fill_(0.0), requires_grad=False)

            # A. run the auto-encoder reconstruction
            image_batch_1, image_batch_2, _ = next(loader)

            auto_encoder_optimizer.zero_grad()

            X_1.copy_(image_batch_1)
            X_2.copy_(image_batch_2)

            style_mu_1, style_logvar_1, class_latent_space_1 = encoder(Variable(X_1))
            style_latent_space_1 = reparameterize(training=True, mu=style_mu_1, logvar=style_logvar_1)

            # KL(q(style|x) || N(0, I)), normalised per pixel.
            kl_divergence_loss_1 = FLAGS.kl_divergence_coef * (
                - 0.5 * torch.sum(1 + style_logvar_1 - style_mu_1.pow(2) - style_logvar_1.exp())
            )
            kl_divergence_loss_1 /= (FLAGS.batch_size * FLAGS.num_channels * FLAGS.image_size * FLAGS.image_size)
            kl_divergence_loss_1.backward(retain_graph=True)

            style_mu_2, style_logvar_2, class_latent_space_2 = encoder(Variable(X_2))
            style_latent_space_2 = reparameterize(training=True, mu=style_mu_2, logvar=style_logvar_2)

            kl_divergence_loss_2 = FLAGS.kl_divergence_coef * (
                - 0.5 * torch.sum(1 + style_logvar_2 - style_mu_2.pow(2) - style_logvar_2.exp())
            )
            kl_divergence_loss_2 /= (FLAGS.batch_size * FLAGS.num_channels * FLAGS.image_size * FLAGS.image_size)
            kl_divergence_loss_2.backward(retain_graph=True)

            # Swap the class codes between the pair: each image is rebuilt from
            # its own style code and the partner's class code.
            reconstructed_X_1 = decoder(style_latent_space_1, class_latent_space_2)
            reconstructed_X_2 = decoder(style_latent_space_2, class_latent_space_1)

            reconstruction_error_1 = FLAGS.reconstruction_coef * mse_loss(reconstructed_X_1, Variable(X_1))
            reconstruction_error_1.backward(retain_graph=True)

            reconstruction_error_2 = FLAGS.reconstruction_coef * mse_loss(reconstructed_X_2, Variable(X_2))
            reconstruction_error_2.backward()

            # Un-scaled totals, kept only for logging.
            reconstruction_error = (reconstruction_error_1 + reconstruction_error_2) / FLAGS.reconstruction_coef
            kl_divergence_error = (kl_divergence_loss_1 + kl_divergence_loss_2) / FLAGS.kl_divergence_coef

            auto_encoder_optimizer.step()

            # A-1. Discriminator training during forward cycle
            if FLAGS.forward_gan:
                # Training generator
                generator_optimizer.zero_grad()

                g_loss_1 = adversarial_loss(discriminator(Variable(reconstructed_X_1)), valid)
                g_loss_2 = adversarial_loss(discriminator(Variable(reconstructed_X_2)), valid)

                gen_f_loss = (g_loss_1 + g_loss_2) / 2.0
                gen_f_loss.backward()

                generator_optimizer.step()

                # Training discriminator
                discriminator_optimizer.zero_grad()

                real_loss_1 = adversarial_loss(discriminator(Variable(X_1)), valid)
                real_loss_2 = adversarial_loss(discriminator(Variable(X_2)), valid)
                fake_loss_1 = adversarial_loss(discriminator(Variable(reconstructed_X_1)), fake)
                fake_loss_2 = adversarial_loss(discriminator(Variable(reconstructed_X_2)), fake)

                dis_f_loss = (real_loss_1 + real_loss_2 + fake_loss_1 + fake_loss_2) / 4.0
                dis_f_loss.backward()

                discriminator_optimizer.step()

            # B. reverse cycle
            image_batch_1, _, __ = next(loader)
            image_batch_2, _, __ = next(loader)

            reverse_cycle_optimizer.zero_grad()

            X_1.copy_(image_batch_1)
            X_2.copy_(image_batch_2)

            # Sample one style code and decode it against two class codes;
            # the encoder should recover the same style from both decodings.
            style_latent_space.normal_(0., 1.)

            _, __, class_latent_space_1 = encoder(Variable(X_1))
            _, __, class_latent_space_2 = encoder(Variable(X_2))

            # detach(): only the encoder is updated by this objective.
            reconstructed_X_1 = decoder(Variable(style_latent_space), class_latent_space_1.detach())
            reconstructed_X_2 = decoder(Variable(style_latent_space), class_latent_space_2.detach())

            style_mu_1, style_logvar_1, _ = encoder(reconstructed_X_1)
            style_latent_space_1 = reparameterize(training=False, mu=style_mu_1, logvar=style_logvar_1)

            style_mu_2, style_logvar_2, _ = encoder(reconstructed_X_2)
            style_latent_space_2 = reparameterize(training=False, mu=style_mu_2, logvar=style_logvar_2)

            reverse_cycle_loss = FLAGS.reverse_cycle_coef * l1_loss(style_latent_space_1, style_latent_space_2)
            reverse_cycle_loss.backward()
            reverse_cycle_loss /= FLAGS.reverse_cycle_coef  # un-scale for logging

            reverse_cycle_optimizer.step()

            # B-1. Discriminator training during reverse cycle
            if FLAGS.reverse_gan:
                # Training generator
                generator_optimizer.zero_grad()

                g_loss_1 = adversarial_loss(discriminator(Variable(reconstructed_X_1)), valid)
                g_loss_2 = adversarial_loss(discriminator(Variable(reconstructed_X_2)), valid)

                gen_r_loss = (g_loss_1 + g_loss_2) / 2.0
                gen_r_loss.backward()

                generator_optimizer.step()

                # Training discriminator
                discriminator_optimizer.zero_grad()

                real_loss_1 = adversarial_loss(discriminator(Variable(X_1)), valid)
                real_loss_2 = adversarial_loss(discriminator(Variable(X_2)), valid)
                fake_loss_1 = adversarial_loss(discriminator(Variable(reconstructed_X_1)), fake)
                fake_loss_2 = adversarial_loss(discriminator(Variable(reconstructed_X_2)), fake)

                dis_r_loss = (real_loss_1 + real_loss_2 + fake_loss_1 + fake_loss_2) / 4.0
                dis_r_loss.backward()

                discriminator_optimizer.step()

            # console progress report every 10 iterations
            if (iteration + 1) % 10 == 0:
                print('')
                print('Epoch #' + str(epoch))
                print('Iteration #' + str(iteration))
                print('')
                print('Reconstruction loss: ' + str(reconstruction_error.data.storage().tolist()[0]))
                print('KL-Divergence loss: ' + str(kl_divergence_error.data.storage().tolist()[0]))
                print('Reverse cycle loss: ' + str(reverse_cycle_loss.data.storage().tolist()[0]))
                if FLAGS.forward_gan:
                    print('Generator F loss: ' + str(gen_f_loss.data.storage().tolist()[0]))
                    print('Discriminator F loss: ' + str(dis_f_loss.data.storage().tolist()[0]))
                if FLAGS.reverse_gan:
                    print('Generator R loss: ' + str(gen_r_loss.data.storage().tolist()[0]))
                    print('Discriminator R loss: ' + str(dis_r_loss.data.storage().tolist()[0]))

            # write to log
            with open(FLAGS.log_file, 'a') as log:
                row = []
                row.append(epoch)
                row.append(iteration)
                row.append(reconstruction_error.data.storage().tolist()[0])
                row.append(kl_divergence_error.data.storage().tolist()[0])
                row.append(reverse_cycle_loss.data.storage().tolist()[0])
                if FLAGS.forward_gan:
                    row.append(gen_f_loss.data.storage().tolist()[0])
                    row.append(dis_f_loss.data.storage().tolist()[0])
                if FLAGS.reverse_gan:
                    row.append(gen_r_loss.data.storage().tolist()[0])
                    row.append(dis_r_loss.data.storage().tolist()[0])
                row = [str(x) for x in row]
                log.write('\t'.join(row) + '\n')

            # write to tensorboard (x-axis: global step across epochs)
            writer.add_scalar('Reconstruction loss', reconstruction_error.data.storage().tolist()[0],
                              epoch * (int(len(paired_cifar) / FLAGS.batch_size) + 1) + iteration)
            writer.add_scalar('KL-Divergence loss', kl_divergence_error.data.storage().tolist()[0],
                              epoch * (int(len(paired_cifar) / FLAGS.batch_size) + 1) + iteration)
            writer.add_scalar('Reverse cycle loss', reverse_cycle_loss.data.storage().tolist()[0],
                              epoch * (int(len(paired_cifar) / FLAGS.batch_size) + 1) + iteration)
            if FLAGS.forward_gan:
                writer.add_scalar('Generator F loss', gen_f_loss.data.storage().tolist()[0],
                                  epoch * (int(len(paired_cifar) / FLAGS.batch_size) + 1) + iteration)
                writer.add_scalar('Discriminator F loss', dis_f_loss.data.storage().tolist()[0],
                                  epoch * (int(len(paired_cifar) / FLAGS.batch_size) + 1) + iteration)
            if FLAGS.reverse_gan:
                writer.add_scalar('Generator R loss', gen_r_loss.data.storage().tolist()[0],
                                  epoch * (int(len(paired_cifar) / FLAGS.batch_size) + 1) + iteration)
                writer.add_scalar('Discriminator R loss', dis_r_loss.data.storage().tolist()[0],
                                  epoch * (int(len(paired_cifar) / FLAGS.batch_size) + 1) + iteration)

        # FIX: step the LR schedulers once per epoch *after* the optimizer
        # updates.  The original stepped them at the start of each epoch,
        # which is the deprecated pre-1.1 PyTorch order and shifts the decay
        # schedule by one epoch.
        auto_encoder_scheduler.step()
        reverse_cycle_scheduler.step()
        generator_scheduler.step()
        discriminator_scheduler.step()

        # save model after every 5 epochs
        if (epoch + 1) % 5 == 0 or (epoch + 1) == FLAGS.end_epoch:
            torch.save(encoder.state_dict(), os.path.join('checkpoints', FLAGS.encoder_save))
            torch.save(decoder.state_dict(), os.path.join('checkpoints', FLAGS.decoder_save))

            # save reconstructed images and style swapped image generations
            # to check progress
            X_1.copy_(image_sample_1)
            X_2.copy_(image_sample_2)
            X_3.copy_(image_sample_3)

            style_mu_1, style_logvar_1, _ = encoder(Variable(X_1))
            _, __, class_latent_space_2 = encoder(Variable(X_2))
            style_mu_3, style_logvar_3, _ = encoder(Variable(X_3))

            style_latent_space_1 = reparameterize(training=False, mu=style_mu_1, logvar=style_logvar_1)
            style_latent_space_3 = reparameterize(training=False, mu=style_mu_3, logvar=style_logvar_3)

            reconstructed_X_1_2 = decoder(style_latent_space_1, class_latent_space_2)
            reconstructed_X_3_2 = decoder(style_latent_space_3, class_latent_space_2)

            # save input image batch
            image_batch = np.transpose(X_1.cpu().numpy(), (0, 2, 3, 1))
            if FLAGS.num_channels == 1:
                # replicate the single channel so the grid renders as RGB
                image_batch = np.concatenate((image_batch, image_batch, image_batch), axis=3)
            imshow_grid(image_batch, name=str(epoch) + '_original', save=True)

            # save reconstructed batch
            reconstructed_x = np.transpose(reconstructed_X_1_2.cpu().data.numpy(), (0, 2, 3, 1))
            if FLAGS.num_channels == 1:
                reconstructed_x = np.concatenate((reconstructed_x, reconstructed_x, reconstructed_x), axis=3)
            imshow_grid(reconstructed_x, name=str(epoch) + '_target', save=True)

            style_batch = np.transpose(X_3.cpu().numpy(), (0, 2, 3, 1))
            if FLAGS.num_channels == 1:
                style_batch = np.concatenate((style_batch, style_batch, style_batch), axis=3)
            imshow_grid(style_batch, name=str(epoch) + '_style', save=True)

            # save style swapped reconstructed batch
            reconstructed_style = np.transpose(reconstructed_X_3_2.cpu().data.numpy(), (0, 2, 3, 1))
            if FLAGS.num_channels == 1:
                reconstructed_style = np.concatenate((reconstructed_style, reconstructed_style, reconstructed_style), axis=3)
            imshow_grid(reconstructed_style, name=str(epoch) + '_style_target', save=True)
def main():
    """Train the COCO foreground generator (G_fg) against a global and an
    instance-level discriminator.

    Parses CLI options, builds the paired image/mask dataset, then alternates
    `critic_iter` discriminator updates with one generator update per outer
    step.  Periodically saves sample grids and model checkpoints under
    result_fg/<categories>/.
    """

    def _str2bool(v):
        # FIX: argparse's type=bool converts any non-empty string (including
        # "False") to True; parse the usual truthy spellings explicitly.
        return str(v).lower() in ('true', '1', 'yes')

    parser = argparse.ArgumentParser()
    parser.add_argument('--train_imgs', type=str, help='dataset path')
    parser.add_argument('--mask_imgs', type=str, help='dataset path')
    parser.add_argument('--log_dir', type=str, default='log', help='Name of the log folder')
    parser.add_argument('--save_models', type=_str2bool, default=True,
                        help='Set True if you want to save trained models')
    parser.add_argument('--pre_trained_model_path', type=str, default=None, help='Pre-trained model path')
    parser.add_argument('--pre_trained_model_epoch', type=str, default=None, help='Pre-trained model epoch e.g 200')
    parser.add_argument('--train_imgs_path', type=str, default='C:/Users/motur/coco/images/train2017',
                        help='Path to training images')
    parser.add_argument('--train_annotation_path', type=str,
                        default='C:/Users/motur/coco/annotations/instances_train2017.json',
                        help='Path to annotation file, .json file')
    parser.add_argument('--category_names', type=str, default='giraffe,elephant,zebra,sheep,cow,bear',
                        help='List of categories in MS-COCO dataset')
    parser.add_argument('--num_test_img', type=int, default=16, help='Number of images saved during training')
    parser.add_argument('--img_size', type=int, default=256, help='Generated image size')
    parser.add_argument('--local_patch_size', type=int, default=256,
                        help='Image size of instance images after interpolation')
    parser.add_argument('--batch_size', type=int, default=16, help='Mini-batch size')
    parser.add_argument('--train_epoch', type=int, default=20, help='Maximum training epoch')
    parser.add_argument('--lr', type=float, default=0.0002, help='Initial learning rate')
    parser.add_argument('--optim_step_size', type=int, default=80, help='Learning rate decay step size')
    parser.add_argument('--optim_gamma', type=float, default=0.5, help='Learning rate decay ratio')
    parser.add_argument('--critic_iter', type=int, default=5,
                        help='Number of discriminator update against each generator update')
    parser.add_argument('--noise_size', type=int, default=128, help='Noise vector size')
    parser.add_argument('--lambda_FM', type=float, default=1, help='Trade-off param for feature matching loss')
    parser.add_argument('--lambda_recon', type=float, default=0.00001, help='Trade-off param for reconstruction loss')
    parser.add_argument('--num_res_blocks', type=int, default=5, help='Number of residual block in generator network')
    parser.add_argument('--trade_off_G', type=float, default=0.1,
                        help='Trade-off parameter which controls gradient flow to generator from D_local and D_glob')

    opt = parser.parse_args()
    print(opt)

    # Create log folder
    root = 'result_fg/' + opt.category_names + '/'
    model = 'coco_model_'
    result_folder_name = 'images_' + opt.log_dir
    model_folder_name = 'models_' + opt.log_dir
    if not os.path.isdir(root):
        os.makedirs(root)
    if not os.path.isdir(root + result_folder_name):
        os.makedirs(root + result_folder_name)
    if not os.path.isdir(root + model_folder_name):
        os.makedirs(root + model_folder_name)

    # Save the script alongside the results for reproducibility
    copyfile(os.path.basename(__file__), root + result_folder_name + '/' + os.path.basename(__file__))

    # Define transformation for dataset images - e.g scaling
    # (transforms.Scale is the pre-0.2 torchvision name of Resize; kept for
    # compatibility with the environment this script targets.)
    transform = transforms.Compose([
        transforms.Scale((opt.img_size, opt.img_size)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    # Load dataset
    category_names = opt.category_names.split(',')
    allmasks = sorted(glob.glob(os.path.join(opt.mask_imgs, '**', '*.png'), recursive=True))
    print('Number of masks: %d' % len(allmasks))
    dataset = chairs(imfile=opt.train_imgs,
                     mfiles=allmasks,
                     category_names=category_names,
                     transform=transform,
                     final_img_size=opt.img_size)
    # Discard images contain very small instances
    # dataset.discard_small(min_area=0.03, max_area=1)

    # Define data loader
    train_loader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=True)

    # For evaluation define fixed masks and noises
    data_iter = iter(train_loader)
    # FIX: iterator.next() is Python-2 only; next() works on both.
    sample_batched = next(data_iter)
    x_fixed = sample_batched['image'][0:opt.num_test_img]
    x_fixed = Variable(x_fixed.cuda())
    y_fixed = sample_batched['single_fg_mask'][0:opt.num_test_img]
    y_fixed = Variable(y_fixed.cuda())
    z_fixed = torch.randn((opt.num_test_img, opt.noise_size))
    z_fixed = Variable(z_fixed.cuda())

    # Define networks
    G_fg = Generator_FG(z_dim=opt.noise_size, label_channel=len(category_names),
                        num_res_blocks=opt.num_res_blocks)
    D_glob = Discriminator(channels=3 + len(category_names))
    D_instance = Discriminator(channels=3 + len(category_names), input_size=opt.local_patch_size)
    G_fg.cuda()
    D_glob.cuda()
    D_instance.cuda()

    # Load parameters from pre-trained models
    if opt.pre_trained_model_path is not None and opt.pre_trained_model_epoch is not None:
        try:
            G_fg.load_state_dict(
                torch.load(opt.pre_trained_model_path + 'G_fg_epoch_' + opt.pre_trained_model_epoch))
            D_glob.load_state_dict(
                torch.load(opt.pre_trained_model_path + 'D_glob_epoch_' + opt.pre_trained_model_epoch))
            D_instance.load_state_dict(
                torch.load(opt.pre_trained_model_path + 'D_local_epoch_' + opt.pre_trained_model_epoch))
            print('Parameters are loaded!')
        except Exception:
            # best-effort: continue from scratch if checkpoints are unreadable
            print('Error: Pre-trained parameters are not loaded!')

    # Define interpolation operation (crop -> fixed-size instance patch)
    up_instance = nn.Upsample(size=(opt.local_patch_size, opt.local_patch_size), mode='bilinear')

    # Define pooling operation for the case that image size and local patch
    # size are mismatched
    pooling_instance = nn.Sequential()
    if opt.local_patch_size != opt.img_size:
        pooling_instance.add_module('0', nn.AvgPool2d(int(opt.img_size / opt.local_patch_size)))

    # Define training loss function - binary cross entropy
    BCE_loss = nn.BCELoss()

    # Define feature matching loss
    criterionVGG = VGGLoss()
    criterionVGG = criterionVGG.cuda()

    # Define optimizer
    G_local_optimizer = optim.Adam(G_fg.parameters(), lr=opt.lr, betas=(0.0, 0.9))
    D_local_optimizer = optim.Adam(
        list(filter(lambda p: p.requires_grad, D_glob.parameters())) +
        list(filter(lambda p: p.requires_grad, D_instance.parameters())),
        lr=opt.lr, betas=(0.0, 0.9))

    # Define learning rate scheduler
    scheduler_G = lr_scheduler.StepLR(G_local_optimizer, step_size=opt.optim_step_size, gamma=opt.optim_gamma)
    scheduler_D = lr_scheduler.StepLR(D_local_optimizer, step_size=opt.optim_step_size, gamma=opt.optim_gamma)

    # ----------------------------TRAIN-----------------------------------------
    print('training start!')
    start_time = time.time()

    for epoch in range(opt.train_epoch):
        epoch_start_time = time.time()
        D_local_losses = []
        G_local_losses = []

        y_real_ = torch.ones(opt.batch_size)
        y_fake_ = torch.zeros(opt.batch_size)
        y_real_, y_fake_ = Variable(y_real_.cuda()), Variable(y_fake_.cuda())

        data_iter = iter(train_loader)
        num_iter = 0

        while num_iter < len(train_loader):
            # --- critic_iter discriminator updates per generator update ---
            j = 0
            while j < opt.critic_iter and num_iter < len(train_loader):
                j += 1
                sample_batched = next(data_iter)
                num_iter += 1
                x_ = sample_batched['image']
                y_ = sample_batched['single_fg_mask']
                fg_mask = sample_batched['seg_mask']
                y_instances = sample_batched['mask_instance']
                bbox = sample_batched['bbox']
                mini_batch = x_.size()[0]
                if mini_batch != opt.batch_size:
                    # last partial batch: skip (targets are sized batch_size)
                    break

                # Update discriminators - D
                # Real examples
                D_glob.zero_grad()
                D_instance.zero_grad()

                x_, y_ = Variable(x_.cuda()), Variable(y_.cuda())
                fg_mask = Variable(fg_mask.cuda())
                # union of all instance masks, clamped to a binary-ish map
                y_reduced = torch.sum(y_, 1).clamp(0, 1).view(y_.size(0), 1, opt.img_size, opt.img_size)

                x_d = torch.cat([x_, fg_mask], 1)

                x_instances = torch.zeros((opt.batch_size, 3, opt.local_patch_size, opt.local_patch_size))
                x_instances = Variable(x_instances.cuda())
                y_instances = Variable(y_instances.cuda())
                y_instances = pooling_instance(y_instances)
                G_instances = torch.zeros((opt.batch_size, 3, opt.local_patch_size, opt.local_patch_size))
                G_instances = Variable(G_instances.cuda())

                # Obtain instances (crop each bbox, resize to the patch size)
                for t in range(x_d.size()[0]):
                    x_instance = x_[t, 0:3, bbox[0][t]:bbox[1][t], bbox[2][t]:bbox[3][t]]
                    x_instance = x_instance.contiguous().view(
                        1, x_instance.size()[0], x_instance.size()[1], x_instance.size()[2])
                    x_instances[t] = up_instance(x_instance)

                D_result_instance = D_instance(torch.cat([x_instances, y_instances], 1)).squeeze()
                D_result = D_glob(x_d).squeeze()
                D_real_loss = BCE_loss(D_result, y_real_) + BCE_loss(D_result_instance, y_real_)
                D_real_loss.backward()

                # Fake examples
                z_ = torch.randn((mini_batch, opt.noise_size))
                z_ = Variable(z_.cuda())

                # Generate fake images (background pixels come from x_)
                G_fg_result = G_fg(z_, y_, torch.mul(x_, (1 - y_reduced)))
                G_result_d = torch.cat([G_fg_result, fg_mask], 1)

                # Obtain fake instances
                for t in range(x_d.size()[0]):
                    G_instance = G_result_d[t, 0:3, bbox[0][t]:bbox[1][t], bbox[2][t]:bbox[3][t]]
                    G_instance = G_instance.contiguous().view(
                        1, G_instance.size()[0], G_instance.size()[1], G_instance.size()[2])
                    G_instances[t] = up_instance(G_instance)

                # detach(): discriminator step must not backprop into G_fg
                D_result_instance = D_instance(torch.cat([G_instances, y_instances], 1).detach()).squeeze()
                D_result = D_glob(G_result_d.detach()).squeeze()
                D_fake_loss = BCE_loss(D_result, y_fake_) + BCE_loss(D_result_instance, y_fake_)
                D_fake_loss.backward()
                D_local_optimizer.step()

                D_train_loss = D_real_loss + D_fake_loss
                D_local_losses.append(D_train_loss.data)

            if mini_batch != opt.batch_size:
                break

            # Update generator G (reuses the last critic iteration's batch)
            G_fg.zero_grad()

            D_result = D_glob(G_result_d).squeeze()
            D_result_instance = D_instance(torch.cat([G_instances, y_instances], 1)).squeeze()
            G_train_loss = (1 - opt.trade_off_G) * BCE_loss(D_result, y_real_) + \
                opt.trade_off_G * BCE_loss(D_result_instance, y_real_)

            # Feature matching loss between generated image and corresponding
            # ground truth
            FM_loss = criterionVGG(G_fg_result, x_)

            # Reconstruction loss (background must stay untouched)
            Recon_loss = mse_loss(torch.mul(x_, (1 - y_reduced)), torch.mul(G_fg_result, (1 - y_reduced)))

            total_loss = G_train_loss + opt.lambda_FM * FM_loss + opt.lambda_recon * Recon_loss
            total_loss.backward()
            G_local_optimizer.step()
            G_local_losses.append(G_train_loss.data)

            print('loss_d: %.3f, loss_g: %.3f' % (D_train_loss.data, G_train_loss.data))
            if (num_iter % 100) == 0:
                print('%d - %d complete!' % ((epoch + 1), num_iter))
                print(result_folder_name)

        # FIX: step LR schedulers after the epoch's optimizer updates (the
        # original stepped them before training each epoch -- deprecated
        # order since PyTorch 1.1).
        scheduler_G.step()
        scheduler_D.step()

        epoch_end_time = time.time()
        per_epoch_ptime = epoch_end_time - epoch_start_time
        print('[%d/%d] - ptime: %.2f, loss_d: %.3f, loss_g: %.3f' %
              ((epoch + 1), opt.train_epoch, per_epoch_ptime,
               torch.mean(torch.FloatTensor(D_local_losses)),
               torch.mean(torch.FloatTensor(G_local_losses))))

        # Save images
        G_fg.eval()
        if epoch == 0:
            # ground truth and (fixed) masks only need saving once
            show_result_rgb((epoch + 1), x_fixed, save=True,
                            path=root + result_folder_name + '/' + model + str(epoch + 1) + '_gt.png')
            for t in range(y_fixed.size()[1]):
                show_result_rgb((epoch + 1), y_fixed[:, t:t + 1, :, :], save=True,
                                path=root + result_folder_name + '/' + model + str(epoch + 1) +
                                '_' + str(t) + '_masked.png')
        show_result_rgb(
            (epoch + 1),
            G_fg(z_fixed, y_fixed,
                 torch.mul(x_fixed, (1 - torch.sum(y_fixed, 1).view(
                     y_fixed.size(0), 1, opt.img_size, opt.img_size)))),
            save=True,
            path=root + result_folder_name + '/' + model + str(epoch + 1) + '_fg.png')
        G_fg.train()

        # Save model params
        # FIX: the original wrote each of the three checkpoints twice per
        # save; one save per network is enough.
        if opt.save_models and (epoch > 11 and epoch % 10 == 0):
            torch.save(G_fg.state_dict(),
                       root + model_folder_name + '/' + model + 'G_fg_epoch_' + str(epoch) + '.pth')
            torch.save(D_glob.state_dict(),
                       root + model_folder_name + '/' + model + 'D_glob_epoch_' + str(epoch) + '.pth')
            torch.save(D_instance.state_dict(),
                       root + model_folder_name + '/' + model + 'D_local_epoch_' + str(epoch) + '.pth')

    end_time = time.time()
    total_ptime = end_time - start_time
    print("Training finish!... save training results")
    print('Training time: ' + str(total_ptime))
# NOTE(review): orphaned fragment -- the statement these first lines belong to
# (apparently a DataLoader(dataset(... transforms.Compose([ ...) call) opens
# outside this chunk; the lines below only close it.  Kept byte-for-byte.
transforms.CenterCrop(opt.img_size),
transforms.ToTensor(),
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])),
batch_size=opt.batch_size,
shuffle=True)

# Build generator/discriminator pair; the BCE + CrossEntropy loss combination
# suggests an AC-GAN-style conditional setup -- confirm against the full file.
generator = Generator(opt.latent_dim, opt.channels)
discriminator = Discriminator(opt.n_classes, opt.channels)

adversarial_loss = nn.BCELoss()
classifier_loss = nn.CrossEntropyLoss()

# move everything to the GPU when requested
if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()
    classifier_loss.cuda()
    # presumably opt.gpu_ids is a single device index here -- verify; a list
    # would raise in torch.cuda.set_device.
    torch.cuda.set_device(opt.gpu_ids)

print(generator, discriminator)

# separate Adam optimizers for G and D with shared hyper-parameters
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

# tensor constructors matching the selected device (used for labels/noise)
FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor
def main():
    """Train the COCO background generator (G_bg) against a global
    discriminator.

    Parses CLI options, builds the COCO dataset filtered to the requested
    categories, then alternates `critic_iter` discriminator updates with one
    generator update per outer step.  Saves sample grids and checkpoints
    under result_bg/.
    """

    def _str2bool(v):
        # FIX: argparse's type=bool converts any non-empty string (including
        # "False") to True; parse the usual truthy spellings explicitly.
        return str(v).lower() in ('true', '1', 'yes')

    parser = argparse.ArgumentParser()
    parser.add_argument('--log_dir', type=str, default='log', help='Name of the log folder')
    parser.add_argument('--save_models', type=_str2bool, default=True,
                        help='Set True if you want to save trained models')
    parser.add_argument('--pre_trained_model_path', type=str, default=None, help='Pre-trained model path')
    parser.add_argument('--pre_trained_model_epoch', type=str, default=None, help='Pre-trained model epoch e.g 200')
    parser.add_argument('--train_imgs_path', type=str, default='/mnt/sdb/data/COCO/train2017',
                        help='Path to training images')
    parser.add_argument('--train_annotation_path', type=str,
                        default='/mnt/sdb/data/COCO/annotations/instances_train2017.json',
                        help='Path to annotation file, .json file')
    parser.add_argument('--category_names', type=str, default='giraffe,elephant,zebra,sheep,cow,bear',
                        help='List of categories in MS-COCO dataset')
    parser.add_argument('--num_test_img', type=int, default=4, help='Number of images saved during training')
    parser.add_argument('--img_size', type=int, default=256, help='Generated image size')
    parser.add_argument('--local_patch_size', type=int, default=256,
                        help='Image size of instance images after interpolation')
    parser.add_argument('--batch_size', type=int, default=4, help='Mini-batch size')
    parser.add_argument('--train_epoch', type=int, default=400, help='Maximum training epoch')
    parser.add_argument('--lr', type=float, default=0.0002, help='Initial learning rate')
    parser.add_argument('--optim_step_size', type=int, default=80, help='Learning rate decay step size')
    parser.add_argument('--optim_gamma', type=float, default=0.5, help='Learning rate decay ratio')
    parser.add_argument('--critic_iter', type=int, default=5,
                        help='Number of discriminator update against each generator update')
    parser.add_argument('--noise_size', type=int, default=256, help='Noise vector size')
    parser.add_argument('--lambda_FM', type=float, default=1, help='Trade-off param for feature matching loss')
    parser.add_argument('--lambda_branch', type=float, default=100, help='Trade-off param for reconstruction loss')
    parser.add_argument('--num_res_blocks', type=int, default=2,
                        help='Number of residual block in generator shared part')
    parser.add_argument('--num_res_blocks_fg', type=int, default=2,
                        help='Number of residual block in non-bg branch')
    parser.add_argument('--num_res_blocks_bg', type=int, default=0,
                        help='Number of residual block in generator bg branch')

    opt = parser.parse_args()
    print(opt)

    # Create log folder
    root = 'result_bg/'
    model = 'coco_model_'
    result_folder_name = 'images_' + opt.log_dir
    model_folder_name = 'models_' + opt.log_dir
    if not os.path.isdir(root):
        os.mkdir(root)
    if not os.path.isdir(root + result_folder_name):
        os.mkdir(root + result_folder_name)
    if not os.path.isdir(root + model_folder_name):
        os.mkdir(root + model_folder_name)

    # Save the script alongside the results for reproducibility
    copyfile(os.path.basename(__file__), root + result_folder_name + '/' + os.path.basename(__file__))

    # Define transformation for dataset images - e.g scaling
    transform = transforms.Compose([
        transforms.Scale((opt.img_size, opt.img_size)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    # Load dataset
    category_names = opt.category_names.split(',')
    dataset = CocoData(root=opt.train_imgs_path,
                       annFile=opt.train_annotation_path,
                       category_names=category_names,
                       transform=transform,
                       final_img_size=opt.img_size)

    # Discard images contain very small instances
    dataset.discard_small(min_area=0.0, max_area=1)
    # dataset.discard_bad_examples('bad_examples_list.txt')

    # Define data loader
    train_loader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=True)

    # For evaluation define fixed masks and noises
    data_iter = iter(train_loader)
    # FIX: iterator.next() is Python-2 only; next() works on both.
    sample_batched = next(data_iter)
    y_fixed = sample_batched['seg_mask'][0:opt.num_test_img]
    y_fixed = Variable(y_fixed.cuda())
    z_fixed = torch.randn((opt.num_test_img, opt.noise_size))
    z_fixed = Variable(z_fixed.cuda())

    # Define networks
    G_bg = Generator_BG(z_dim=opt.noise_size,
                        label_channel=len(category_names),
                        num_res_blocks=opt.num_res_blocks,
                        num_res_blocks_fg=opt.num_res_blocks_fg,
                        num_res_blocks_bg=opt.num_res_blocks_bg)
    D_glob = Discriminator(channels=3 + len(category_names), input_size=opt.img_size)
    G_bg.cuda()
    D_glob.cuda()

    # Load parameters from pre-trained models
    if opt.pre_trained_model_path is not None and opt.pre_trained_model_epoch is not None:
        try:
            G_bg.load_state_dict(
                torch.load(opt.pre_trained_model_path + 'G_bg_epoch_' + opt.pre_trained_model_epoch))
            D_glob.load_state_dict(
                torch.load(opt.pre_trained_model_path + 'D_glob_epoch_' + opt.pre_trained_model_epoch))
            print('Parameters are loaded!')
        except Exception:
            # best-effort: continue from scratch if checkpoints are unreadable
            print('Error: Pre-trained parameters are not loaded!')

    # Define training loss function - binary cross entropy
    BCE_loss = nn.BCELoss()

    # Define feature matching loss
    criterionVGG = VGGLoss()
    criterionVGG = criterionVGG.cuda()

    # Define optimizer
    G_local_optimizer = optim.Adam(G_bg.parameters(), lr=opt.lr, betas=(0.0, 0.9))
    D_local_optimizer = optim.Adam(filter(lambda p: p.requires_grad, D_glob.parameters()),
                                   lr=opt.lr, betas=(0.0, 0.9))

    # Define learning rate scheduler
    scheduler_G = lr_scheduler.StepLR(G_local_optimizer, step_size=opt.optim_step_size, gamma=opt.optim_gamma)
    scheduler_D = lr_scheduler.StepLR(D_local_optimizer, step_size=opt.optim_step_size, gamma=opt.optim_gamma)

    # ----------------------------TRAIN---------------------------------------
    print('training start!')
    start_time = time.time()

    for epoch in range(opt.train_epoch):
        D_local_losses = []
        G_local_losses = []

        y_real_ = torch.ones(opt.batch_size)
        y_fake_ = torch.zeros(opt.batch_size)
        y_real_, y_fake_ = Variable(y_real_.cuda()), Variable(y_fake_.cuda())

        epoch_start_time = time.time()
        data_iter = iter(train_loader)
        num_iter = 0

        while num_iter < len(train_loader):
            # --- critic_iter discriminator updates per generator update ---
            j = 0
            while j < opt.critic_iter and num_iter < len(train_loader):
                j += 1
                sample_batched = next(data_iter)
                num_iter += 1
                x_ = sample_batched['image']
                y_ = sample_batched['seg_mask']
                # union of all category masks, clamped to a binary-ish map
                y_reduced = torch.sum(y_, 1).view(y_.size(0), 1, y_.size(2), y_.size(3))
                y_reduced = torch.clamp(y_reduced, 0, 1)
                y_reduced = Variable(y_reduced.cuda())

                # Update discriminators - D
                # Real examples
                D_glob.zero_grad()
                mini_batch = x_.size()[0]

                # partial last batch: resize the target tensors instead of
                # skipping (unlike the foreground script)
                if mini_batch != opt.batch_size:
                    y_real_ = torch.ones(mini_batch)
                    y_fake_ = torch.zeros(mini_batch)
                    y_real_, y_fake_ = Variable(y_real_.cuda()), Variable(y_fake_.cuda())

                x_, y_ = Variable(x_.cuda()), Variable(y_.cuda())
                x_d = torch.cat([x_, y_], 1)

                D_result = D_glob(x_d).squeeze()
                D_real_loss = BCE_loss(D_result, y_real_)
                D_real_loss.backward()

                # Fake examples
                z_ = torch.randn((mini_batch, opt.noise_size))
                z_ = Variable(z_.cuda())

                # Generate fake images (full composite + background branch)
                G_result, G_result_bg = G_bg(z_, y_)
                G_result_d = torch.cat([G_result, y_], 1)

                # detach(): discriminator step must not backprop into G_bg
                D_result = D_glob(G_result_d.detach()).squeeze()
                D_fake_loss = BCE_loss(D_result, y_fake_)
                D_fake_loss.backward()
                D_local_optimizer.step()

                D_train_loss = D_real_loss + D_fake_loss
                D_local_losses.append(D_train_loss.item())

            # Update generator G (reuses the last critic iteration's batch)
            G_bg.zero_grad()

            D_result = D_glob(G_result_d).squeeze()
            G_train_loss = BCE_loss(D_result, y_real_)

            # Feature matching loss between generated image and corresponding
            # ground truth
            FM_loss = criterionVGG(G_result, x_)

            # Branch-similar loss: outside the mask, the composite and the
            # background branch must agree
            branch_sim_loss = mse_loss(torch.mul(G_result, (1 - y_reduced)),
                                       torch.mul(G_result_bg, (1 - y_reduced)))

            total_loss = G_train_loss + opt.lambda_FM * FM_loss + opt.lambda_branch * branch_sim_loss
            total_loss.backward()
            G_local_optimizer.step()
            G_local_losses.append(G_train_loss.item())

            print('loss_d: %.3f, loss_g: %.3f' % (D_train_loss.item(), G_train_loss.item()))
            if (num_iter % 100) == 0:
                print('%d - %d complete!' % ((epoch + 1), num_iter))
                print(result_folder_name)

        # FIX: step LR schedulers after the epoch's optimizer updates (the
        # original stepped them before training each epoch -- deprecated
        # order since PyTorch 1.1).
        scheduler_G.step()
        scheduler_D.step()

        epoch_end_time = time.time()
        per_epoch_ptime = epoch_end_time - epoch_start_time
        print('[%d/%d] - ptime: %.2f, loss_d: %.3f, loss_g: %.3f' %
              ((epoch + 1), opt.train_epoch, per_epoch_ptime,
               torch.mean(torch.FloatTensor(D_local_losses)),
               torch.mean(torch.FloatTensor(G_local_losses))))

        # Save images
        G_bg.eval()
        G_result, G_result_bg = G_bg(z_fixed, y_fixed)
        G_bg.train()
        if epoch % 10 == 0:
            for t in range(y_fixed.size()[1]):
                # NOTE(review): every t writes the same '_masked.png' path, so
                # only the last channel survives -- confirm whether a per-t
                # suffix (as in the foreground script) was intended.
                show_result((epoch + 1), y_fixed[:, t:t + 1, :, :], save=True,
                            path=root + result_folder_name + '/' + model + str(epoch + 1) + '_masked.png')
            show_result((epoch + 1), G_result, save=True,
                        path=root + result_folder_name + '/' + model + str(epoch + 1) + '.png')
            show_result((epoch + 1), G_result_bg, save=True,
                        path=root + result_folder_name + '/' + model + str(epoch + 1) + '_bg.png')

        # Save model params
        if opt.save_models and (epoch > 21 and epoch % 10 == 0):
            torch.save(G_bg.state_dict(),
                       root + model_folder_name + '/' + model + 'G_bg_epoch_' + str(epoch) + '.pth')
            torch.save(D_glob.state_dict(),
                       root + model_folder_name + '/' + model + 'D_glob_epoch_' + str(epoch) + '.pth')

    end_time = time.time()
    total_ptime = end_time - start_time
    print("Training finish!... save training results")
    print('Training time: ' + str(total_ptime))
def train_gmm(opt, train_loader, model, board):
    """Adversarially train the geometric matching module (GMM).

    Each step warps the in-shop cloth onto the person with the grid predicted
    by ``model``, then runs one discriminator update and one generator update
    (via the project helpers ``discriminator_train_step`` /
    ``generator_train_step``), logging visuals and losses to ``board``.

    Args:
        opt: parsed options; reads lr, keep_step, decay_step, batch_size,
            display_count, save_count, checkpoint_dir, name.
        train_loader: data source exposing ``next_batch()`` returning a dict
            with the keys read below.
        model: GMM network mapping (agnostic, cloth) -> (sampling grid, theta).
        board: tensorboard-style writer used by board_add_images/add_scalar.
    """
    model.cuda()
    discriminator = Discriminator()
    discriminator.cuda()
    model.train()
    discriminator.train()

    # criterion
    criterion = nn.BCELoss()      # adversarial (GAN) loss
    criterionL1 = nn.L1Loss()
    criterionPSC = PSCLoss()      # project-defined loss; semantics not visible in this file

    # optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr, betas=(0.5, 0.999))
    optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(0.5, 0.999))
    # Constant lr for keep_step steps, then linear decay to 0 over decay_step.
    # NOTE(review): scheduler.step() is never called in this loop, so the
    # decay schedule never takes effect — confirm whether that is intended.
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda step: 1.0 - max(0, step - opt.keep_step) / float(opt.decay_step + 1))

    #count = 0
    #base_name = os.path.basename(opt.checkpoint)
    # save_dir = os.path.join(opt.result_dir, opt.datamode)
    # if not os.path.exists(save_dir):
    #     os.makedirs(save_dir)
    # warp_cloth_dir = os.path.join(save_dir, 'warp-cloth')
    # if not os.path.exists(warp_cloth_dir):
    #     os.makedirs(warp_cloth_dir)

    for step in range(opt.keep_step + opt.decay_step):
        iter_start_time = time.time()
        inputs = train_loader.next_batch()

        c_names = inputs['c_name']                # NOTE(review): unused below
        im = inputs['image'].cuda()
        im_pose = inputs['pose_image'].cuda()
        im_h = inputs['head'].cuda()
        shape = inputs['shape'].cuda()
        agnostic = inputs['agnostic'].cuda()
        c = inputs['cloth'].cuda()
        cm = inputs['cloth_mask'].cuda()
        im_c = inputs['parse_cloth'].cuda()
        im_g = inputs['grid_image'].cuda()
        blank = inputs['blank'].cuda()

        # Predict a sampling grid from (agnostic, cloth) and warp the cloth,
        # its mask, and a reference grid image with it.
        grid, theta = model(agnostic, c)
        warped_cloth = F.grid_sample(c, grid, padding_mode='border')
        warped_mask = F.grid_sample(cm, grid, padding_mode='zeros')   # NOTE(review): unused below
        warped_grid = F.grid_sample(im_g, grid, padding_mode='zeros')

        #if (count < 14222):
        #    save_images(warped_cloth, c_names, warp_cloth_dir)
        #    print(warped_cloth.size()[0])
        #    count+=warped_cloth.size()[0]

        # 3x3 image panel logged to tensorboard every display_count steps.
        visuals = [[im_h, shape, im_pose],
                   [c, warped_cloth, im_c],
                   [warped_grid, (warped_cloth + im) * 0.5, im]]

        # One discriminator update, then one generator (GMM) update.
        discriminator_train_step(opt.batch_size, discriminator, model, optimizer_d,
                                 criterion, im_c, agnostic, c)
        res, loss, lossL1, lossPSC, lossGAN = generator_train_step(
            opt.batch_size, discriminator, model, optimizer, criterion,
            criterionL1, criterionPSC, blank, im_c, agnostic, c)

        # Periodic logging of visuals and scalar losses.
        if (step + 1) % opt.display_count == 0:
            board_add_images(board, 'combine', visuals, step + 1)
            board.add_scalar('lossL1', lossL1.item(), step + 1)
            board.add_scalar('lossPSC', lossPSC.item(), step + 1)
            board.add_scalar('lossGAN', lossGAN.item(), step + 1)
            board.add_scalar('loss', loss.item(), step + 1)
            t = time.time() - iter_start_time
            print('step: %8d, time: %.3f, lossL1: %4f' % (step + 1, t, lossL1.item()), flush=True)
            print('step: %8d, time: %.3f, lossPSC: %4f' % (step + 1, t, lossPSC.item()), flush=True)
            print('step: %8d, time: %.3f, lossGAN: %4f' % (step + 1, t, lossGAN.item()), flush=True)
            print('step: %8d, time: %.3f, loss: %4f' % (step + 1, t, loss.item()), flush=True)

        # Periodic checkpointing of the GMM only (discriminator is not saved).
        if (step + 1) % opt.save_count == 0:
            save_checkpoint(model, os.path.join(opt.checkpoint_dir, opt.name,
                                                'step_%06d.pth' % (step + 1)))
class Trainer:
    """Progressive-growing StyleGAN trainer.

    Owns the generator, an exponential-moving-average (EMA) copy of the
    generator used for evaluation snapshots, the discriminator, and a
    dataloader that grows image resolution phase by phase.  Training uses the
    non-saturating logistic loss (softplus) with an R1 gradient penalty on
    real images, and mixed precision via NVIDIA apex ``amp``.
    """

    def __init__(self, dataset_dir, generator_channels, discriminator_channels, nz,
                 style_depth, lrs, betas, eps, phase_iter, weights_halflife,
                 batch_size, n_cpu, opt_level):
        # NOTE(review): `eps` is accepted for interface compatibility but is
        # not used anywhere in this class.
        self.nz = nz
        # phase_iter * 2: presumably fade-in + stabilization samples per
        # resolution phase — confirm against Dataloader.
        self.dataloader = Dataloader(dataset_dir, batch_size, phase_iter * 2, n_cpu)

        self.generator = Generator(generator_channels, nz, style_depth).cuda()
        # EMA copy starts from weights identical to the live generator.
        self.generator_ema = Generator(generator_channels, nz, style_depth).cuda()
        self.generator_ema.load_state_dict(copy.deepcopy(self.generator.state_dict()))
        self.discriminator = Discriminator(discriminator_channels).cuda()

        self.tb = tensorboard.tf_recorder('StyleGAN')

        self.phase_iter = phase_iter
        self.lrs = lrs
        self.betas = betas
        self.weights_halflife = weights_halflife
        self.opt_level = opt_level

        self.ema = None  # created in run() via init_ema()

        # BUGFIX: the original set `torch.backends.cuda.benchmark = True`,
        # which merely creates an inert attribute on the module and has no
        # effect.  The cuDNN autotuner flag is `torch.backends.cudnn.benchmark`.
        torch.backends.cudnn.benchmark = True

    def generator_trainloop(self, batch_size, alpha):
        """Run one generator update and return the (unscaled) generator loss."""
        requires_grad(self.generator, True)
        requires_grad(self.discriminator, False)

        # mixing regularization: with probability 0.9 feed two latent codes so
        # styles are mixed across layers
        if random.random() < 0.9:
            z = [torch.randn(batch_size, self.nz).cuda(),
                 torch.randn(batch_size, self.nz).cuda()]
        else:
            z = torch.randn(batch_size, self.nz).cuda()

        fake = self.generator(z, alpha=alpha)
        d_fake = self.discriminator(fake, alpha=alpha)
        loss = F.softplus(-d_fake).mean()  # non-saturating logistic loss

        self.optimizer_g.zero_grad()
        with amp.scale_loss(loss, self.optimizer_g) as scaled_loss:
            scaled_loss.backward()
        self.optimizer_g.step()

        # Fold the freshly updated weights into the EMA shadow copy.
        for name, param in self.generator.named_parameters():
            if param.requires_grad:
                self.ema(name, param.data)

        return loss.item()

    def discriminator_trainloop(self, real, alpha):
        """Run one discriminator update.

        Returns:
            (total_loss, (mean real score, mean fake score)).
        """
        requires_grad(self.generator, False)
        requires_grad(self.discriminator, True)

        # Real images need gradients for the R1 penalty below.
        real.requires_grad = True
        self.optimizer_d.zero_grad()

        d_real = self.discriminator(real, alpha=alpha)
        loss_real = F.softplus(-d_real).mean()
        with amp.scale_loss(loss_real, self.optimizer_d) as scaled_loss_real:
            # retain_graph so d_real can be differentiated again for R1
            scaled_loss_real.backward(retain_graph=True)

        # R1 regularization: penalize ||grad_x D(x)||^2 at real samples.
        grad_real = grad(outputs=d_real.sum(), inputs=real, create_graph=True)[0]
        grad_penalty = (grad_real.view(grad_real.size(0), -1).norm(2, dim=1) ** 2).mean()
        grad_penalty = 10 / 2 * grad_penalty  # gamma / 2 with gamma = 10
        with amp.scale_loss(grad_penalty, self.optimizer_d) as scaled_grad_penalty:
            scaled_grad_penalty.backward()

        # mixing regularization (same scheme as the generator loop)
        if random.random() < 0.9:
            z = [torch.randn(real.size(0), self.nz).cuda(),
                 torch.randn(real.size(0), self.nz).cuda()]
        else:
            z = torch.randn(real.size(0), self.nz).cuda()

        fake = self.generator(z, alpha=alpha)
        d_fake = self.discriminator(fake, alpha=alpha)
        loss_fake = F.softplus(d_fake).mean()
        with amp.scale_loss(loss_fake, self.optimizer_d) as scaled_loss_fake:
            scaled_loss_fake.backward()

        # BUGFIX: report the UNSCALED losses.  The original summed the
        # amp-scaled tensors, so the logged value depended on the current
        # (dynamic) loss scale rather than on the actual loss.
        loss = loss_real + loss_fake + grad_penalty

        self.optimizer_d.step()

        return loss.item(), (d_real.mean().item(), d_fake.mean().item())

    def run(self, log_iter, checkpoint):
        """Main training loop; grows resolution after each dataloader pass."""
        global_iter = 0

        # Fixed latents so logged sample grids are comparable across time.
        test_z = torch.randn(4, self.nz).cuda()

        self.ema = self.init_ema()

        if checkpoint:
            self.load_checkpoint(checkpoint)
        else:
            self.grow()

        while True:
            print('train {}X{} images...'.format(self.dataloader.img_size, self.dataloader.img_size))

            for iter, ((data, _), n_trained_samples) in enumerate(tqdm(self.dataloader), 1):
                real = data.cuda()
                # Fade-in factor for the newly grown layers; resolutions of
                # 8x8 and below always use alpha = 1.
                alpha = min(1, n_trained_samples / self.phase_iter) if self.dataloader.img_size > 8 else 1

                loss_d, (real_score, fake_score) = self.discriminator_trainloop(real, alpha)
                loss_g = self.generator_trainloop(real.size(0), alpha)

                if global_iter % log_iter == 0:
                    self.save_ema()
                    self.log(loss_d, loss_g, real_score, fake_score, test_z, alpha)

                # save 3 times during training
                if iter % (len(self.dataloader) // 4 + 1) == 0:
                    self.save_ema()
                    self.save_checkpoint(n_trained_samples)

                global_iter += 1
                self.tb.iter(data.size(0))

            # End of phase: snapshot, checkpoint, and grow to the next size.
            self.save_ema()
            self.save_checkpoint()
            self.grow()

    def save_ema(self):
        """Copy the EMA shadow weights into generator_ema."""
        self.ema.set_weights(self.generator_ema)

    def init_ema(self):
        """Register every trainable generator parameter with a fresh EMA."""
        ema = EMA()
        for name, param in self.generator.named_parameters():
            if param.requires_grad:
                ema.register(name, param.data)
        return ema

    def log(self, loss_d, loss_g, real_score, fake_score, test_z, alpha):
        """Write scalar losses/scores and sample grids to tensorboard."""
        with torch.no_grad():
            fake = self.generator(test_z, alpha=alpha)
            fake = (fake + 1) * 0.5                       # [-1,1] -> [0,1]
            fake = torch.clamp(fake, min=0.0, max=1.0)

            # Swap models so only one generator occupies GPU memory at a time.
            self.generator.cpu()
            self.generator_ema.cuda()
            fake_ema = self.generator_ema(test_z, alpha=alpha)
            fake_ema = (fake_ema + 1) * 0.5
            fake_ema = torch.clamp(fake_ema, min=0.0, max=1.0)
            self.generator_ema.cpu()
            self.generator.cuda()

        self.tb.add_scalar('loss_d', loss_d)
        self.tb.add_scalar('loss_g', loss_g)
        self.tb.add_scalar('real_score', real_score)
        self.tb.add_scalar('fake_score', fake_score)
        self.tb.add_images('fake', fake)
        self.tb.add_images('fake_ema', fake_ema)

    def grow(self):
        """Grow all networks and the dataloader to the next resolution and rebuild optimizers."""
        self.discriminator.grow()
        self.generator.grow()
        self.generator_ema.grow()
        self.dataloader.grow()
        self.generator.cuda()
        self.discriminator.cuda()

        # EMA decay derived from the configured half-life (in samples);
        # 0 disables averaging.
        decay = 0.0
        if self.weights_halflife > 0:
            decay = 0.5 ** (float(self.dataloader.batch_size) / self.weights_halflife)
        self.ema.grow(self.generator, decay)

        self.tb.renew('{}x{}'.format(self.dataloader.img_size, self.dataloader.img_size))

        # Per-resolution learning rate; the style mapper trains 100x slower.
        self.lr = self.lrs.get(str(self.dataloader.img_size), 0.001)
        self.style_lr = self.lr * 0.01

        self.optimizer_d = optim.Adam(params=self.discriminator.parameters(), lr=self.lr, betas=self.betas)
        self.optimizer_g = optim.Adam([
            {'params': self.generator.model.parameters(), 'lr': self.lr},
            {'params': self.generator.style_mapper.parameters(), 'lr': self.style_lr},
        ], betas=self.betas)

        # Re-wrap models and the new optimizers for mixed precision.
        [self.generator, self.discriminator], [self.optimizer_g, self.optimizer_d] = amp.initialize(
            [self.generator, self.discriminator],
            [self.optimizer_g, self.optimizer_d],
            opt_level=self.opt_level
        )

    def save_checkpoint(self, tick='last'):
        """Save all networks, optimizers, and progress under checkpoints/."""
        torch.save({
            'generator': self.generator.state_dict(),
            'generator_ema': self.generator_ema.state_dict(),
            'discriminator': self.discriminator.state_dict(),
            'generator_optimizer': self.optimizer_g.state_dict(),
            'discriminator_optimizer': self.optimizer_d.state_dict(),
            'img_size': self.dataloader.img_size,
            'tick': tick,
        }, 'checkpoints/{}x{}_{}.pth'.format(self.dataloader.img_size, self.dataloader.img_size, tick))

    def load_checkpoint(self, filename):
        """Restore a checkpoint, growing networks until sizes match first."""
        checkpoint = torch.load(filename)
        print('load {}x{} checkpoint'.format(checkpoint['img_size'], checkpoint['img_size']))

        while self.dataloader.img_size < checkpoint['img_size']:
            self.grow()

        self.generator.load_state_dict(checkpoint['generator'])
        self.generator_ema.load_state_dict(checkpoint['generator_ema'])
        self.discriminator.load_state_dict(checkpoint['discriminator'])
        self.optimizer_g.load_state_dict(checkpoint['generator_optimizer'])
        self.optimizer_d.load_state_dict(checkpoint['discriminator_optimizer'])

        if checkpoint['tick'] == 'last':
            # Phase was complete: move on to the next resolution.
            self.grow()
        else:
            # Resume mid-phase at the recorded sample count.
            self.dataloader.set_checkpoint(checkpoint['tick'])
            self.tb.iter(checkpoint['tick'])
def training_procedure(FLAGS):
    """Adversarial training for a style/class disentangling VAE on paired MNIST.

    Per iteration this alternates three phases:
      A. auto-encoder step: KL + reconstruction losses update encoder+decoder;
      B. generator step (FLAGS.generator_times times): encoder+decoder try to
         fool the discriminator with cross/sampled reconstructions;
      C. discriminator step (FLAGS.discriminator_times times): real pair vs.
         generated pair, with the optimizer step skipped once discriminator
         accuracy reaches FLAGS.discriminator_limiting_accuracy.
    Checkpoints are written every 5 epochs and at the final epoch.
    """
    """
    model definition
    """
    encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    encoder.apply(weights_init)
    decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    decoder.apply(weights_init)
    discriminator = Discriminator()
    discriminator.apply(weights_init)

    # load saved models if load_saved flag is true
    if FLAGS.load_saved:
        encoder.load_state_dict(torch.load(os.path.join('checkpoints', FLAGS.encoder_save)))
        decoder.load_state_dict(torch.load(os.path.join('checkpoints', FLAGS.decoder_save)))
        discriminator.load_state_dict(torch.load(os.path.join('checkpoints', FLAGS.discriminator_save)))

    """
    variable definition
    """
    real_domain_labels = 1
    fake_domain_labels = 0

    # Pre-allocated input buffers, filled in-place with copy_ each iteration.
    X_1 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels, FLAGS.image_size, FLAGS.image_size)
    X_2 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels, FLAGS.image_size, FLAGS.image_size)
    X_3 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels, FLAGS.image_size, FLAGS.image_size)
    domain_labels = torch.LongTensor(FLAGS.batch_size)
    style_latent_space = torch.FloatTensor(FLAGS.batch_size, FLAGS.style_dim)

    """
    loss definitions
    """
    cross_entropy_loss = nn.CrossEntropyLoss()

    '''
    add option to run on GPU
    '''
    if FLAGS.cuda:
        encoder.cuda()
        decoder.cuda()
        discriminator.cuda()
        cross_entropy_loss.cuda()
        X_1 = X_1.cuda()
        X_2 = X_2.cuda()
        X_3 = X_3.cuda()
        domain_labels = domain_labels.cuda()
        style_latent_space = style_latent_space.cuda()

    """
    optimizer definition
    """
    auto_encoder_optimizer = optim.Adam(
        list(encoder.parameters()) + list(decoder.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2)
    )
    discriminator_optimizer = optim.Adam(
        list(discriminator.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2)
    )
    # Note: the "generator" here is the encoder+decoder pair acting adversarially.
    generator_optimizer = optim.Adam(
        list(encoder.parameters()) + list(decoder.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2)
    )

    """
    training
    """
    if torch.cuda.is_available() and not FLAGS.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")

    if not os.path.exists('checkpoints'):
        os.makedirs('checkpoints')

    # load_saved is false when training is started from 0th iteration
    if not FLAGS.load_saved:
        with open(FLAGS.log_file, 'w') as log:
            log.write('Epoch\tIteration\tReconstruction_loss\tKL_divergence_loss\t')
            log.write('Generator_loss\tDiscriminator_loss\tDiscriminator_accuracy\n')

    # load data set and create data loader instance
    print('Loading MNIST paired dataset...')
    paired_mnist = MNIST_Paired(root='mnist', download=True, train=True, transform=transform_config)
    loader = cycle(DataLoader(paired_mnist, batch_size=FLAGS.batch_size, shuffle=True, num_workers=0, drop_last=True))

    # initialise variables
    discriminator_accuracy = 0.

    # initialize summary writer
    writer = SummaryWriter()

    for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch):
        print('')
        print('Epoch #' + str(epoch) + '..........................................................................')

        for iteration in range(int(len(paired_mnist) / FLAGS.batch_size)):
            # A. run the auto-encoder reconstruction
            image_batch_1, image_batch_2, _ = next(loader)

            auto_encoder_optimizer.zero_grad()

            X_1.copy_(image_batch_1)
            X_2.copy_(image_batch_2)

            style_mu_1, style_logvar_1, class_1 = encoder(Variable(X_1))
            style_1 = reparameterize(training=True, mu=style_mu_1, logvar=style_logvar_1)

            # KL divergence of the style posterior against a unit Gaussian,
            # normalized by the number of pixels in the batch.
            kl_divergence_loss_1 = - 0.5 * torch.sum(1 + style_logvar_1 - style_mu_1.pow(2) - style_logvar_1.exp())
            kl_divergence_loss_1 /= (FLAGS.batch_size * FLAGS.num_channels * FLAGS.image_size * FLAGS.image_size)
            kl_divergence_loss_1.backward(retain_graph=True)

            _, __, class_2 = encoder(Variable(X_2))

            # Reconstruct X_1 from its own class code and from its pair's class
            # code; both use X_1's style, and both target X_1 (the pair shares
            # class content).
            reconstructed_X_1 = decoder(style_1, class_1)
            reconstructed_X_2 = decoder(style_1, class_2)

            reconstruction_error_1 = mse_loss(reconstructed_X_1, Variable(X_1))
            reconstruction_error_1.backward(retain_graph=True)

            reconstruction_error_2 = mse_loss(reconstructed_X_2, Variable(X_1))
            reconstruction_error_2.backward()

            reconstruction_error = reconstruction_error_1 + reconstruction_error_2
            kl_divergence_error = kl_divergence_loss_1

            auto_encoder_optimizer.step()

            # B. run the generator
            for i in range(FLAGS.generator_times):
                generator_optimizer.zero_grad()

                image_batch_1, _, __ = next(loader)
                image_batch_3, _, __ = next(loader)
                # Generator wants the discriminator to label fakes as real.
                domain_labels.fill_(real_domain_labels)

                X_1.copy_(image_batch_1)
                X_3.copy_(image_batch_3)

                style_mu_1, style_logvar_1, _ = encoder(Variable(X_1))
                style_1 = reparameterize(training=True, mu=style_mu_1, logvar=style_logvar_1)

                kl_divergence_loss_1 = - 0.5 * torch.sum(1 + style_logvar_1 - style_mu_1.pow(2) - style_logvar_1.exp())
                kl_divergence_loss_1 /= (FLAGS.batch_size * FLAGS.num_channels * FLAGS.image_size * FLAGS.image_size)
                kl_divergence_loss_1.backward(retain_graph=True)

                _, __, class_3 = encoder(Variable(X_3))

                # Cross-reconstruction: style from X_1, class from X_3.
                reconstructed_X_1_3 = decoder(style_1, class_3)
                output_1 = discriminator(Variable(X_3), reconstructed_X_1_3)
                generator_error_1 = cross_entropy_loss(output_1, Variable(domain_labels))
                generator_error_1.backward(retain_graph=True)

                # Same with a style vector sampled from the prior.
                style_latent_space.normal_(0., 1.)
                reconstructed_X_latent_3 = decoder(Variable(style_latent_space), class_3)
                output_2 = discriminator(Variable(X_3), reconstructed_X_latent_3)
                generator_error_2 = cross_entropy_loss(output_2, Variable(domain_labels))
                generator_error_2.backward()

                generator_error = generator_error_1 + generator_error_2
                kl_divergence_error += kl_divergence_loss_1

                generator_optimizer.step()

            # C. run the discriminator
            for i in range(FLAGS.discriminator_times):
                discriminator_optimizer.zero_grad()

                # train discriminator on real data
                domain_labels.fill_(real_domain_labels)

                image_batch_1, _, __ = next(loader)
                image_batch_2, image_batch_3, _ = next(loader)

                X_1.copy_(image_batch_1)
                X_2.copy_(image_batch_2)
                X_3.copy_(image_batch_3)

                # X_2 and X_3 come from the same paired sample -> a "real" pair.
                real_output = discriminator(Variable(X_2), Variable(X_3))
                discriminator_real_error = cross_entropy_loss(real_output, Variable(domain_labels))
                discriminator_real_error.backward()

                # train discriminator on fake data
                domain_labels.fill_(fake_domain_labels)

                style_mu_1, style_logvar_1, _ = encoder(Variable(X_1))
                style_1 = reparameterize(training=False, mu=style_mu_1, logvar=style_logvar_1)

                _, __, class_3 = encoder(Variable(X_3))

                reconstructed_X_1_3 = decoder(style_1, class_3)
                fake_output = discriminator(Variable(X_3), reconstructed_X_1_3)
                discriminator_fake_error = cross_entropy_loss(fake_output, Variable(domain_labels))
                discriminator_fake_error.backward()

                # total discriminator error
                discriminator_error = discriminator_real_error + discriminator_fake_error

                # calculate discriminator accuracy for this step
                target_true_labels = torch.cat((torch.ones(FLAGS.batch_size), torch.zeros(FLAGS.batch_size)), dim=0)
                if FLAGS.cuda:
                    target_true_labels = target_true_labels.cuda()

                discriminator_predictions = torch.cat((real_output, fake_output), dim=0)
                _, discriminator_predictions = torch.max(discriminator_predictions, 1)

                discriminator_accuracy = (discriminator_predictions.data == target_true_labels.long()
                                          ).sum().item() / (FLAGS.batch_size * 2)

                # Only update D while it is below the limiting accuracy, to
                # keep it from overpowering the generator.
                if discriminator_accuracy < FLAGS.discriminator_limiting_accuracy:
                    discriminator_optimizer.step()

            if (iteration + 1) % 50 == 0:
                print('')
                print('Epoch #' + str(epoch))
                print('Iteration #' + str(iteration))

                print('')
                print('Reconstruction loss: ' + str(reconstruction_error.data.storage().tolist()[0]))
                print('KL-Divergence loss: ' + str(kl_divergence_error.data.storage().tolist()[0]))

                print('')
                print('Generator loss: ' + str(generator_error.data.storage().tolist()[0]))
                print('Discriminator loss: ' + str(discriminator_error.data.storage().tolist()[0]))
                print('Discriminator accuracy: ' + str(discriminator_accuracy))

                print('..........')

                # write to log
                with open(FLAGS.log_file, 'a') as log:
                    log.write('{0}\t{1}\t{2}\t{3}\t{4}\t{5}\t{6}\n'.format(
                        epoch,
                        iteration,
                        reconstruction_error.data.storage().tolist()[0],
                        kl_divergence_error.data.storage().tolist()[0],
                        generator_error.data.storage().tolist()[0],
                        discriminator_error.data.storage().tolist()[0],
                        discriminator_accuracy
                    ))

                # write to tensorboard
                writer.add_scalar('Reconstruction loss', reconstruction_error.data.storage().tolist()[0],
                                  epoch * (int(len(paired_mnist) / FLAGS.batch_size) + 1) + iteration)
                writer.add_scalar('KL-Divergence loss', kl_divergence_error.data.storage().tolist()[0],
                                  epoch * (int(len(paired_mnist) / FLAGS.batch_size) + 1) + iteration)
                writer.add_scalar('Generator loss', generator_error.data.storage().tolist()[0],
                                  epoch * (int(len(paired_mnist) / FLAGS.batch_size) + 1) + iteration)
                writer.add_scalar('Discriminator loss', discriminator_error.data.storage().tolist()[0],
                                  epoch * (int(len(paired_mnist) / FLAGS.batch_size) + 1) + iteration)
                writer.add_scalar('Discriminator accuracy', discriminator_accuracy * 100,
                                  epoch * (int(len(paired_mnist) / FLAGS.batch_size) + 1) + iteration)

        # save model after every 5 epochs
        if (epoch + 1) % 5 == 0 or (epoch + 1) == FLAGS.end_epoch:
            torch.save(encoder.state_dict(), os.path.join('checkpoints', FLAGS.encoder_save))
            torch.save(decoder.state_dict(), os.path.join('checkpoints', FLAGS.decoder_save))
            torch.save(discriminator.state_dict(), os.path.join('checkpoints', FLAGS.discriminator_save))
num_workers=opt.n_workers) print(len(data_loader)) G = Generator(opt) D = Discriminator(opt) G.apply(weight_init) D.apply(weight_init) print(G) print(D) if USE_CUDA: G = G.cuda() D = D.cuda() G_optim = torch.optim.Adam(G.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2)) D_optim = torch.optim.Adam(D.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2)) GAN_loss = nn.BCELoss() L1_loss = nn.L1Loss() total_step = 0 for epoch in range(opt.n_epoch): epoch += 1 for i, (input, real) in enumerate(data_loader):
def training_procedure(FLAGS):
    """Adversarial training for an nv/nc (varying/common factor) auto-encoder on paired MNIST.

    Per iteration: (A) auto-encoder reconstruction with swapped common codes,
    (B.a) discriminator updates on real vs. cross-reconstructed pairs (step
    skipped above FLAGS.discriminator_limiting_accuracy), (B.b) generator
    updates to fool the discriminator.  FLAGS.train_auto_encoder /
    train_discriminator / train_generator gate each optimizer step.
    Saves checkpoints every 5 epochs and progress images each epoch.
    """
    """
    model definition
    """
    encoder = Encoder(nv_dim=FLAGS.nv_dim, nc_dim=FLAGS.nc_dim)
    encoder.apply(weights_init)
    decoder = Decoder(nv_dim=FLAGS.nv_dim, nc_dim=FLAGS.nc_dim)
    decoder.apply(weights_init)
    discriminator = Discriminator()
    discriminator.apply(weights_init)

    # load saved models if load_saved flag is true
    if FLAGS.load_saved:
        encoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.encoder_save)))
        decoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.decoder_save)))
        discriminator.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.discriminator_save)))

    """
    variable definition
    """
    real_domain_labels = 1
    fake_domain_labels = 0

    # Pre-allocated input buffers, filled in-place with copy_ each iteration.
    X_1 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels, FLAGS.image_size, FLAGS.image_size)
    X_2 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels, FLAGS.image_size, FLAGS.image_size)
    X_3 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels, FLAGS.image_size, FLAGS.image_size)
    domain_labels = torch.LongTensor(FLAGS.batch_size)

    """
    loss definitions
    """
    cross_entropy_loss = nn.CrossEntropyLoss()

    '''
    add option to run on GPU
    '''
    if FLAGS.cuda:
        encoder.cuda()
        decoder.cuda()
        discriminator.cuda()
        cross_entropy_loss.cuda()
        X_1 = X_1.cuda()
        X_2 = X_2.cuda()
        X_3 = X_3.cuda()
        domain_labels = domain_labels.cuda()

    """
    optimizer definition
    """
    auto_encoder_optimizer = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()),
                                        lr=FLAGS.initial_learning_rate,
                                        betas=(FLAGS.beta_1, FLAGS.beta_2))
    discriminator_optimizer = optim.Adam(list(discriminator.parameters()),
                                         lr=FLAGS.initial_learning_rate,
                                         betas=(FLAGS.beta_1, FLAGS.beta_2))
    # "Generator" = encoder+decoder acting adversarially against the discriminator.
    generator_optimizer = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()),
                                     lr=FLAGS.initial_learning_rate,
                                     betas=(FLAGS.beta_1, FLAGS.beta_2))

    """
    training
    """
    if torch.cuda.is_available() and not FLAGS.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    if not os.path.exists('checkpoints'):
        os.makedirs('checkpoints')

    if not os.path.exists('reconstructed_images'):
        os.makedirs('reconstructed_images')

    # load_saved is false when training is started from 0th iteration
    if not FLAGS.load_saved:
        with open(FLAGS.log_file, 'w') as log:
            log.write('Epoch\tIteration\tReconstruction_loss\t')
            log.write(
                'Generator_loss\tDiscriminator_loss\tDiscriminator_accuracy\n')

    # load data set and create data loader instance
    print('Loading MNIST paired dataset...')
    paired_mnist = MNIST_Paired(root='mnist', download=True, train=True, transform=transform_config)
    loader = cycle(
        DataLoader(paired_mnist, batch_size=FLAGS.batch_size, shuffle=True, num_workers=0, drop_last=True))

    # initialise variables
    discriminator_accuracy = 0.

    # initialize summary writer
    writer = SummaryWriter()

    for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch):
        print('')
        print(
            'Epoch #' + str(epoch) +
            '..........................................................................'
        )

        for iteration in range(int(len(paired_mnist) / FLAGS.batch_size)):
            # A. run the auto-encoder reconstruction
            image_batch_1, image_batch_2, labels_batch_1 = next(loader)

            auto_encoder_optimizer.zero_grad()

            X_1.copy_(image_batch_1)
            X_2.copy_(image_batch_2)

            # nv = varying factor code, nc = common factor code.
            nv_1, nc_1 = encoder(Variable(X_1))
            nv_2, nc_2 = encoder(Variable(X_2))

            # Swap the common codes within the pair: each image must be
            # reconstructable from its own nv and its partner's nc.
            reconstructed_X_1 = decoder(nv_1, nc_2)
            reconstructed_X_2 = decoder(nv_2, nc_1)

            reconstruction_error_1 = mse_loss(reconstructed_X_1, Variable(X_1))
            reconstruction_error_1.backward(retain_graph=True)

            reconstruction_error_2 = mse_loss(reconstructed_X_2, Variable(X_2))
            reconstruction_error_2.backward()

            reconstruction_error = reconstruction_error_1 + reconstruction_error_2

            if FLAGS.train_auto_encoder:
                auto_encoder_optimizer.step()

            # B. run the adversarial part of the architecture

            # B. a) run the discriminator
            for i in range(FLAGS.discriminator_times):
                discriminator_optimizer.zero_grad()

                # train discriminator on real data
                domain_labels.fill_(real_domain_labels)

                image_batch_1, image_batch_2, labels_batch_1 = next(loader)

                X_1.copy_(image_batch_1)
                X_2.copy_(image_batch_2)

                real_output = discriminator(Variable(X_1), Variable(X_2))
                discriminator_real_error = FLAGS.disc_coef * cross_entropy_loss(
                    real_output, Variable(domain_labels))
                discriminator_real_error.backward()

                # train discriminator on fake data
                domain_labels.fill_(fake_domain_labels)

                image_batch_3, _, labels_batch_3 = next(loader)
                X_3.copy_(image_batch_3)

                nv_3, nc_3 = encoder(Variable(X_3))

                # reconstruction is taking common factor from X_1 and varying factor from X_3
                reconstructed_X_3_1 = decoder(nv_3, encoder(Variable(X_1))[1])
                fake_output = discriminator(Variable(X_1), reconstructed_X_3_1)
                discriminator_fake_error = FLAGS.disc_coef * cross_entropy_loss(
                    fake_output, Variable(domain_labels))
                discriminator_fake_error.backward()

                # total discriminator error
                discriminator_error = discriminator_real_error + discriminator_fake_error

                # calculate discriminator accuracy for this step
                target_true_labels = torch.cat((torch.ones(
                    FLAGS.batch_size), torch.zeros(FLAGS.batch_size)), dim=0)
                if FLAGS.cuda:
                    target_true_labels = target_true_labels.cuda()

                discriminator_predictions = torch.cat(
                    (real_output, fake_output), dim=0)
                _, discriminator_predictions = torch.max(
                    discriminator_predictions, 1)

                discriminator_accuracy = (discriminator_predictions.data ==
                                          target_true_labels.long()).sum(
                                          ).item() / (FLAGS.batch_size * 2)

                # Update D only while it is below the limiting accuracy so it
                # does not overpower the generator.
                if discriminator_accuracy < FLAGS.discriminator_limiting_accuracy and FLAGS.train_discriminator:
                    discriminator_optimizer.step()

            # B. b) run the generator
            for i in range(FLAGS.generator_times):
                generator_optimizer.zero_grad()

                image_batch_1, _, labels_batch_1 = next(loader)
                image_batch_3, __, labels_batch_3 = next(loader)
                # Generator wants the discriminator to call its output real.
                domain_labels.fill_(real_domain_labels)

                X_1.copy_(image_batch_1)
                X_3.copy_(image_batch_3)

                nv_3, nc_3 = encoder(Variable(X_3))

                # reconstruction is taking common factor from X_1 and varying factor from X_3
                reconstructed_X_3_1 = decoder(nv_3, encoder(Variable(X_1))[1])
                output = discriminator(Variable(X_1), reconstructed_X_3_1)
                generator_error = FLAGS.gen_coef * cross_entropy_loss(
                    output, Variable(domain_labels))
                generator_error.backward()

                if FLAGS.train_generator:
                    generator_optimizer.step()

            # print progress after 10 iterations
            if (iteration + 1) % 10 == 0:
                print('')
                print('Epoch #' + str(epoch))
                print('Iteration #' + str(iteration))

                print('')
                print('Reconstruction loss: ' + str(reconstruction_error.data.storage().tolist()[0]))
                print('Generator loss: ' + str(generator_error.data.storage().tolist()[0]))

                print('')
                print('Discriminator loss: ' + str(discriminator_error.data.storage().tolist()[0]))
                print('Discriminator accuracy: ' + str(discriminator_accuracy))

                print('..........')

                # write to log
                with open(FLAGS.log_file, 'a') as log:
                    log.write('{0}\t{1}\t{2}\t{3}\t{4}\t{5}\n'.format(
                        epoch, iteration,
                        reconstruction_error.data.storage().tolist()[0],
                        generator_error.data.storage().tolist()[0],
                        discriminator_error.data.storage().tolist()[0],
                        discriminator_accuracy))

                # write to tensorboard
                writer.add_scalar(
                    'Reconstruction loss',
                    reconstruction_error.data.storage().tolist()[0],
                    epoch * (int(len(paired_mnist) / FLAGS.batch_size) + 1) + iteration)
                writer.add_scalar(
                    'Generator loss',
                    generator_error.data.storage().tolist()[0],
                    epoch * (int(len(paired_mnist) / FLAGS.batch_size) + 1) + iteration)
                writer.add_scalar(
                    'Discriminator loss',
                    discriminator_error.data.storage().tolist()[0],
                    epoch * (int(len(paired_mnist) / FLAGS.batch_size) + 1) + iteration)

        # save model after every 5 epochs
        if (epoch + 1) % 5 == 0 or (epoch + 1) == FLAGS.end_epoch:
            torch.save(encoder.state_dict(), os.path.join('checkpoints', FLAGS.encoder_save))
            torch.save(decoder.state_dict(), os.path.join('checkpoints', FLAGS.decoder_save))
            torch.save(discriminator.state_dict(), os.path.join('checkpoints', FLAGS.discriminator_save))

        """
        save reconstructed images and style swapped image generations to check progress
        """
        image_batch_1, image_batch_2, labels_batch_1 = next(loader)
        image_batch_3, _, __ = next(loader)

        X_1.copy_(image_batch_1)
        X_2.copy_(image_batch_2)
        X_3.copy_(image_batch_3)

        nv_1, nc_1 = encoder(Variable(X_1))
        nv_2, nc_2 = encoder(Variable(X_2))
        nv_3, nc_3 = encoder(Variable(X_3))

        reconstructed_X_1 = decoder(nv_1, nc_2)
        reconstructed_X_3_2 = decoder(nv_3, nc_2)

        # save input image batch
        # (grayscale channel is tripled so the grids render as RGB)
        image_batch = np.transpose(X_1.cpu().numpy(), (0, 2, 3, 1))
        image_batch = np.concatenate(
            (image_batch, image_batch, image_batch), axis=3)
        imshow_grid(image_batch, name=str(epoch) + '_original', save=True)

        # save reconstructed batch
        reconstructed_x = np.transpose(
            reconstructed_X_1.cpu().data.numpy(), (0, 2, 3, 1))
        reconstructed_x = np.concatenate(
            (reconstructed_x, reconstructed_x, reconstructed_x), axis=3)
        imshow_grid(reconstructed_x, name=str(epoch) + '_target', save=True)

        # save cross reconstructed batch
        style_batch = np.transpose(X_3.cpu().numpy(), (0, 2, 3, 1))
        style_batch = np.concatenate(
            (style_batch, style_batch, style_batch), axis=3)
        imshow_grid(style_batch, name=str(epoch) + '_style', save=True)

        reconstructed_style = np.transpose(
            reconstructed_X_3_2.cpu().data.numpy(), (0, 2, 3, 1))
        reconstructed_style = np.concatenate(
            (reconstructed_style, reconstructed_style, reconstructed_style), axis=3)
        imshow_grid(reconstructed_style, name=str(epoch) + '_style_target', save=True)
dataset = datasets.MNIST('.', transform=transform, download=True) dataloader = data.DataLoader(dataset, batch_size=4) # model g = Generator() d = Discriminator() # losses gan_loss = GANLoss() # use is_cuda = torch.cuda.is_available() if is_cuda: g = g.cuda() d = d.cuda() # optimizer optim_G = optim.Adam(g.parameters()) optim_D = optim.Adam(d.parameters()) # train for epoch in range(num_epoch): total_batch = len(dataloader) for idx, (image, label) in enumerate(dataloader): d.train() g.train() # fake image 생성
# Fixed evaluation inputs and network construction for the mask GAN.
print(y_fixed)

# Collapse the per-category segmentation masks of the first num_test_img
# samples into a single-channel mask (sum over what is presumably the
# category axis — TODO confirm seg_mask layout), shaped (N, 1, H, W).
x_fixed = sample_batched['seg_mask'][0:num_test_img]
x_fixed = torch.squeeze(x_fixed)
x_fixed = torch.sum(x_fixed, dim=1)
x_fixed = x_fixed.view(x_fixed.size()[0], 1, x_fixed.size()[1], x_fixed.size()[2])
x_fixed = Variable(x_fixed.cuda())

# Fixed noise so generated samples are comparable across epochs.
z_fixed = torch.randn((num_test_img, noise_size))
z_fixed = Variable(z_fixed.cuda())

#--------------------------Define Networks------------------------------------
G_local = Generator_Baseline_2(z_dim=noise_size, label_channel=len(category_names), num_res_blocks=num_res_blocks)
# Discriminator input: 1 image channel concatenated with the category mask channels.
D_local = Discriminator(channels=1 + len(category_names), input_size=img_size)
G_local.cuda()
D_local.cuda()

#Load parameters from pre-trained model
if load_params:
    G_local.load_state_dict(torch.load('result_mask/models_local_' + str(log_numb) + '/coco_model_G_glob_epoch_' + str(epoch_bias) + '.pth'))
    D_local.load_state_dict(torch.load('result_mask/models_local_' + str(log_numb) + '/coco_model_D_glob_epoch_' + str(epoch_bias) + '.pth'))
    print('Parameters are loaded from logFile: models_local_' + str(log_numb) + ' ---- Epoch: ' + str(epoch_bias))

# Binary Cross Entropy loss
# (MSE is substituted when the least-squares GAN objective is selected)
if use_LSGAN_loss:
    BCE_loss = nn.MSELoss()
else:
    BCE_loss = nn.BCELoss()