def main():
    """Run the trained generator G_A over the validation split and save outputs.

    Only G_A is needed at test time; gradients are disabled and the model is
    put into eval mode before inference.
    """
    opt = get_opt()
    # Define the generator; optionally swap in the U-Net variant.
    netG_A = networks.Generator(opt.input_nc, opt.output_nc, opt.ngf, opt.n_res, opt.dropout)
    if opt.u_net:
        netG_A = networks.U_net(opt.input_nc, opt.output_nc, opt.ngf)
    if opt.cuda:
        netG_A.cuda()
    # Do not track gradients during testing.
    utils.set_requires_grad(netG_A, False)
    netG_A.eval()
    netG_A.load_state_dict(torch.load(opt.net_GA))
    # Load the data: resize to (sizeh, sizew), then normalize to [-1, 1].
    transform = transforms.Compose([
        transforms.Resize((opt.sizeh, opt.sizew)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])
    dataloader = DataLoader(
        ImageDataset(opt.rootdir, transform=transform, mode='val'),
        batch_size=opt.batch_size, shuffle=False, num_workers=opt.n_cpu)
    Tensor = torch.cuda.FloatTensor if opt.cuda else torch.Tensor
    for i, batch in enumerate(dataloader):
        name, image = batch
        # BUG FIX: the original preallocated
        # Tensor(opt.batch_size, opt.input_nc, opt.size, opt.size) and used
        # input_A.copy_(image); that buffer used `opt.size` even though the
        # transform produces (sizeh, sizew) images, and copy_ would also fail
        # on a smaller final batch. Converting the batch directly fixes both.
        real_A = image.type(Tensor)
        fake_B = netG_A(real_A)
        # Save each generated image under its original file name.
        batch_size = len(name)
        for j in range(batch_size):
            image_name = name[j].split('/')[-1]
            path = 'generated_image/' + image_name
            utils.save_image(fake_B[j, :, :, :], path)
def train(self):
    """Run the conditional-GAN (pix2pix-style) training loop.

    Optionally resumes from ``self.args.checkpoint``, then alternates one
    discriminator update and one generator update per batch until
    ``self.args.epochs`` epochs are done. Checkpoints are written every
    ``save_freq`` epochs, plus a rolling 'last' checkpoint each epoch.
    """
    # load training info (resume) if a checkpoint path was supplied
    filepath = self.args.checkpoint
    if filepath is not None:
        self._load(filepath)
    while self.epoch < self.args.epochs:
        self.model.train()
        self.epoch += 1
        for iteration, batch in enumerate(self.train_dataloader):
            real_a, real_b = batch[0].to(self.device), batch[1].to(self.device)
            fake_b = self.generator(real_a)
            # ---- update discriminator ----
            set_requires_grad(self.discriminator, True)
            self.optimizer_d.zero_grad()
            # Conditional D judges (input, output) pairs, channel-concatenated.
            fake_ab = torch.cat((real_a, fake_b), 1)
            # detach() so the D step does not backprop into the generator.
            pred_fake = self.discriminator(fake_ab.detach())
            loss_fake = self.criterion_gan(pred_fake, False)
            real_ab = torch.cat((real_a, real_b), 1)
            pred_real = self.discriminator(real_ab)
            loss_real = self.criterion_gan(pred_real, True)
            # Halved so D learns at half the rate relative to G (pix2pix convention).
            loss_d = (loss_fake + loss_real)/2
            loss_d.backward()
            self.optimizer_d.step()
            # ---- update generator ----
            # Freeze D so the G backward pass does not accumulate grads in it.
            set_requires_grad(self.discriminator, False)
            self.optimizer_g.zero_grad()
            fake_ab = torch.cat((real_a, fake_b), 1)
            pred_fake = self.discriminator(fake_ab)
            # G wants D to call its output real; L1 term (weighted by args.lamb)
            # keeps the output close to the ground truth.
            loss_gan = self.criterion_gan(pred_fake, True)
            loss_l1 = self.criterion_l1(fake_b, real_b) * self.args.lamb
            loss_g = loss_gan + loss_l1
            loss_g.backward()
            self.optimizer_g.step()
            if (iteration+1) % self.args.print_loss_freq == 0:
                # NOTE(review): the format string opens 'Epoch[{0}]({1}/{2}'
                # but never closes the '(' — cosmetic only, output unaffected.
                print('Epoch[{0}]({1}/{2} - Loss_D: {3}, Loss_G: {4}, Loss: {5}'.format(
                    self.epoch, iteration+1, len(self.train_dataloader),
                    loss_d.item(), loss_g.item(), loss_d.item()+loss_g.item()
                ))
        # Per-epoch learning-rate schedule updates.
        self.scheduler_g.step()
        self.scheduler_d.step()
        if self.epoch % self.args.save_freq == 0:
            self._validate()
            self._save(str(self.epoch))
        # Rolling checkpoint written every epoch.
        self._save('last')
def optimize_parameters(self):
    """Run one optimization step: forward, compute losses, backward, update.

    Relies on ``self.loss_calculate()`` populating ``self.loss_all`` before
    the backward pass.
    """
    self.forward()
    # Ensure the network's parameters are trainable for this step.
    set_requires_grad(self.net, True)
    self.optimizer.zero_grad()
    self.loss_calculate()
    self.loss_all.backward()
    self.optimizer.step()
def run(self):
    """CycleGAN training loop: alternate G and D updates for ``train_iters`` steps.

    Both dataloaders are re-wound at each epoch boundary. Scalar losses,
    log lines, and sample images are emitted at the frequencies given in
    ``self.args``.
    """
    iter_A = iter(self.dataloader_A)
    iter_B = iter(self.dataloader_B)
    iter_per_epoch = min(len(iter_A), len(iter_B))
    for step in range(self.start_iter, self.start_iter + self.train_iters):
        # Restart both iterators when the shorter loader is exhausted.
        if step % iter_per_epoch == 0:
            iter_A = iter(self.dataloader_A)
            iter_B = iter(self.dataloader_B)
        # FIX: `iterator.next()` is Python 2 syntax and raises AttributeError
        # on Python 3 dataloader iterators — use the builtin next() instead.
        real_A = next(iter_A)
        real_B = next(iter_B)
        real_A, real_B = real_A.to(self.device), real_B.to(self.device)
        fake_B = self.G_AB(real_A)
        fake_A = self.G_BA(real_B)
        # train G (discriminators frozen so G's backward skips them)
        set_requires_grad([self.D_A, self.D_B], False)
        self.optimizer_G.zero_grad()
        loss_G_AB, loss_G_BA, loss_cycle_A, loss_cycle_B, loss_idt_A, loss_idt_B = (  # noqa
            self.backward_G(real_A, real_B, fake_A, fake_B))
        self.optimizer_G.step()
        # train D
        set_requires_grad([self.D_A, self.D_B], True)
        self.optimizer_D.zero_grad()
        loss_D_A = self.backward_D(self.D_A, real_A, fake_A)
        loss_D_B = self.backward_D(self.D_B, real_B, fake_B)
        self.optimizer_D.step()
        # logger
        if step % self.args.log_report_freq == 0:
            self.logger.info(
                '{} Train: Step {} Loss/G_AB {:.4f}'.format(
                    self.args.exp_name, step, loss_G_AB))
        # tensorboard scalars
        if step % self.args.scalar_report_freq == 0:
            self.writer.add_scalar('Train/Loss/G_AB', loss_G_AB, step)
            self.writer.add_scalar('Train/Loss/cycle_A', loss_cycle_A, step)
            self.writer.add_scalar('Train/Loss/idt_A', loss_idt_A, step)
            self.writer.add_scalar('Train/Loss/G_BA', loss_G_BA, step)
            self.writer.add_scalar('Train/Loss/cycle_B', loss_cycle_B, step)
            self.writer.add_scalar('Train/Loss/idt_B', loss_idt_B, step)
            self.writer.add_scalar('Train/Loss/D_A', loss_D_A, step)
            self.writer.add_scalar('Train/Loss/D_B', loss_D_B, step)
        # tensorboard images: last sample of the batch only
        if step % self.args.image_report_freq == 0:
            real_A = real_A[-1].detach().cpu()
            real_B = real_B[-1].detach().cpu()
            fake_A = fake_A[-1].detach().cpu()
            fake_B = fake_B[-1].detach().cpu()
            self.writer.add_image('Train/real_A', real_A, step)
            self.writer.add_image('Train/fake_A', fake_A, step)
            self.writer.add_image('Train/real_B', real_B, step)
            self.writer.add_image('Train/fake_B', fake_B, step)
def execute(all_models, envname, savepath, eval_hp=None):
    """Evaluate the given models in an environment and run planning.

    ``all_models`` is unpacked as (model, c_model, actor). Returns the tuple
    ``(total_dist, total_success, n_test_locs)`` from the planning run.
    """
    hp = get_hp(eval_hp)
    ### Inverse Model Evaluation ###
    # Freeze every model and switch it to eval mode.
    for net in all_models:
        set_requires_grad(net, False)
        net.eval()
    model, c_model, actor = all_models
    # Off-screen, fast viewer configuration.
    config = {
        "visible": False,
        "init_width": hp["init_width"],
        "init_height": hp["init_height"],
        "go_fast": True,
    }
    env = get_env(envname)
    env.reset()
    env.render(config=config)
    env.viewer_setup()
    # Warm up the renderer.
    for _ in range(100):
        env.render(config=config)
    # Build the test set: fixed validation locations, or random ones.
    test_mode = hp["test_mode"]
    if test_mode == "valid":
        n_test_locs = get_n_test_locs(envname)
        test_data = [set_test_loc(idx, envname, env, config)
                     for idx in range(n_test_locs)]
    else:
        n_test_locs = hp["n_test_locs"]
        test_data = [set_valid_loc(envname, env, config)
                     for _ in range(n_test_locs)]
    ### Planning ###
    if hp["run_inverse_model"]:
        run_inverse_model(env, config, test_data, n_test_locs, hp, actor)
    total_dist, total_success, n_test_locs = run_planning_and_inverse_model(
        env, envname, hp, n_test_locs, test_data, config, test_mode,
        model, c_model, actor, savepath)
    return total_dist, total_success, n_test_locs
def optimize_D(self):
    """Take one discriminator step on the conditional GAN loss.

    Stores the partial losses in ``self.loss_D_fake`` / ``self.loss_D_real``
    and their average in ``self.loss_D``; D is frozen again afterwards so
    the subsequent generator step does not touch it.
    """
    utils.set_requires_grad(self.D, True)
    self.optimizer_D.zero_grad()
    # Fake pair: detach so gradients do not flow back into the generator.
    fake_pair = torch.cat((self.real_in, self.fake_out), 1)
    fake_score = self.D(fake_pair.detach())
    self.loss_D_fake = self.criterionGAN(fake_score, False)
    # Real pair.
    real_pair = torch.cat((self.real_in, self.real_out), 1)
    real_score = self.D(real_pair)
    self.loss_D_real = self.criterionGAN(real_score, True)
    # Average the two terms (standard pix2pix halving) and update.
    self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
    self.loss_D.backward()
    self.optimizer_D.step()
    utils.set_requires_grad(self.D, False)
def main(args):
    """ADDA domain adaptation: MNIST (source) -> MNIST-M (target).

    The frozen source feature extractor defines the reference distribution;
    a target copy of the extractor is trained adversarially against a
    domain discriminator so its features become indistinguishable from the
    source's. The adapted model is saved to trained_models/adda.pt.
    """
    # Frozen, pretrained source model; keep the full net as the classifier.
    source_model = Net().to(device)
    source_model.load_state_dict(torch.load(args.MODEL_FILE))
    source_model.eval()
    set_requires_grad(source_model, requires_grad=False)
    clf = source_model
    source_model = source_model.feature_extractor
    # Target extractor starts from the same pretrained weights.
    target_model = Net().to(device)
    target_model.load_state_dict(torch.load(args.MODEL_FILE))
    target_model = target_model.feature_extractor
    target_clf = clf.classifier
    # Domain discriminator over 320-d features -> single logit.
    discriminator = nn.Sequential(
        nn.Linear(320, 50),
        nn.ReLU(),
        nn.Linear(50, 20),
        nn.ReLU(),
        nn.Linear(20, 1)
    ).to(device)
    # Half-sized batches so source+target concat matches args.batch_size.
    half_batch = args.batch_size // 2
    source_dataset = MNIST(config.DATA_DIR / 'mnist', train=True, download=True,
                           transform=Compose([GrayscaleToRgb(), ToTensor()]))
    source_loader = DataLoader(source_dataset, batch_size=half_batch,
                               shuffle=True, num_workers=1, pin_memory=True)
    # NOTE(review): the target loader uses train=False — presumably the
    # MNIST-M test split is intentionally used as unlabeled target data;
    # confirm against the experiment setup.
    target_dataset = MNISTM(train=False)
    target_loader = DataLoader(target_dataset, batch_size=half_batch,
                               shuffle=True, num_workers=1, pin_memory=True)
    discriminator_optim = torch.optim.Adam(discriminator.parameters())
    target_optim = torch.optim.Adam(target_model.parameters())
    criterion = nn.BCEWithLogitsLoss()
    for epoch in range(1, args.epochs + 1):
        batch_iterator = zip(loop_iterable(source_loader),
                             loop_iterable(target_loader))
        total_loss = 0
        total_accuracy = 0
        target_label_accuracy = 0
        for _ in trange(args.iterations, leave=False):
            # ---- Train discriminator (target extractor frozen) ----
            set_requires_grad(target_model, requires_grad=False)
            set_requires_grad(discriminator, requires_grad=True)
            for _ in range(args.k_disc):
                (source_x, _), (target_x, _) = next(batch_iterator)
                source_x, target_x = source_x.to(device), target_x.to(device)
                source_features = source_model(source_x).view(
                    source_x.shape[0], -1)
                target_features = target_model(target_x).view(
                    target_x.shape[0], -1)
                # Domain labels: source = 1, target = 0.
                discriminator_x = torch.cat([source_features, target_features])
                discriminator_y = torch.cat([
                    torch.ones(source_x.shape[0], device=device),
                    torch.zeros(target_x.shape[0], device=device)
                ])
                preds = discriminator(discriminator_x).squeeze()
                loss = criterion(preds, discriminator_y)
                discriminator_optim.zero_grad()
                loss.backward()
                discriminator_optim.step()
                total_loss += loss.item()
                # Accuracy from raw logits: > 0 means "source".
                total_accuracy += ((
                    preds > 0).long() == discriminator_y.long()).float().mean().item()
            # ---- Train target extractor (discriminator frozen) ----
            set_requires_grad(target_model, requires_grad=True)
            set_requires_grad(discriminator, requires_grad=False)
            for _ in range(args.k_clf):
                _, (target_x, target_labels) = next(batch_iterator)
                target_x = target_x.to(device)
                target_features = target_model(target_x).view(
                    target_x.shape[0], -1)
                # flipped labels: target extractor tries to look like "source".
                discriminator_y = torch.ones(target_x.shape[0], device=device)
                preds = discriminator(target_features).squeeze()
                loss = criterion(preds, discriminator_y)
                target_optim.zero_grad()
                loss.backward()
                target_optim.step()
                # Monitor label accuracy through the frozen source classifier.
                target_label_preds = target_clf(target_features)
                target_label_accuracy += (target_label_preds.cpu().max(1)[1]
                                          == target_labels).float().mean().item()
        # Per-epoch statistics, averaged over all inner steps.
        mean_loss = total_loss / (args.iterations * args.k_disc)
        mean_accuracy = total_accuracy / (args.iterations * args.k_disc)
        target_mean_accuracy = target_label_accuracy / (args.iterations * args.k_clf)
        tqdm.write(
            f'EPOCH {epoch:03d}: discriminator_loss={mean_loss:.4f}, '
            f'discriminator_accuracy={mean_accuracy:.4f}, target_accuracy={target_mean_accuracy:.4f}'
        )
    # Create the full target model and save it
    clf.feature_extractor = target_model
    torch.save(clf.state_dict(), 'trained_models/adda.pt')
def main(args):
    """WDGRL domain adaptation: MNIST (source) -> MNIST-M (target).

    A critic estimates the Wasserstein distance between source and target
    feature distributions (with gradient penalty); the classifier is then
    trained to minimize its task loss plus ``args.wd_clf`` times that
    distance. Saves the adapted model to trained_models/wdgrl.pt.
    """
    clf_model = Net().to(device)
    clf_model.load_state_dict(torch.load(args.MODEL_FILE))
    feature_extractor = clf_model.feature_extractor
    # 'discriminator' here is the task classifier head, not a domain critic.
    discriminator = clf_model.classifier
    # Wasserstein critic over 320-d features -> scalar score.
    critic = nn.Sequential(
        nn.Linear(320, 50),
        nn.ReLU(),
        nn.Linear(50, 20),
        nn.ReLU(),
        nn.Linear(20, 1)
    ).to(device)
    half_batch = args.batch_size // 2
    source_dataset = MNIST(config.DATA_DIR / 'mnist', train=True, download=True,
                           transform=Compose([GrayscaleToRgb(), ToTensor()]))
    # drop_last=True keeps h_s and h_t the same batch size for the critic.
    source_loader = DataLoader(source_dataset, batch_size=half_batch,
                               drop_last=True, shuffle=True, num_workers=0,
                               pin_memory=True)
    target_dataset = MNISTM(train=False)
    target_loader = DataLoader(target_dataset, batch_size=half_batch,
                               drop_last=True, shuffle=True, num_workers=0,
                               pin_memory=True)
    critic_optim = torch.optim.Adam(critic.parameters(), lr=1e-4)
    clf_optim = torch.optim.Adam(clf_model.parameters(), lr=1e-4)
    clf_criterion = nn.CrossEntropyLoss()
    for epoch in range(1, args.epochs + 1):
        batch_iterator = zip(loop_iterable(source_loader),
                             loop_iterable(target_loader))
        total_loss = 0
        total_accuracy = 0  # NOTE(review): never updated in this function.
        for _ in trange(args.iterations, leave=False):
            (source_x, source_y), (target_x, _) = next(batch_iterator)
            # ---- Train critic (feature extractor frozen) ----
            set_requires_grad(feature_extractor, requires_grad=False)
            set_requires_grad(critic, requires_grad=True)
            source_x, target_x = source_x.to(device), target_x.to(device)
            source_y = source_y.to(device)
            # Features are computed once, detached from the extractor's graph.
            with torch.no_grad():
                h_s = feature_extractor(source_x).data.view(
                    source_x.shape[0], -1)
                h_t = feature_extractor(target_x).data.view(
                    target_x.shape[0], -1)
            for _ in range(args.k_critic):
                gp = gradient_penalty(critic, h_s, h_t)
                critic_s = critic(h_s)
                critic_t = critic(h_t)
                wasserstein_distance = critic_s.mean() - critic_t.mean()
                # Critic maximizes the distance; gamma weights the penalty.
                critic_cost = -wasserstein_distance + args.gamma * gp
                critic_optim.zero_grad()
                critic_cost.backward()
                critic_optim.step()
                total_loss += critic_cost.item()
            # ---- Train classifier/extractor (critic frozen) ----
            set_requires_grad(feature_extractor, requires_grad=True)
            set_requires_grad(critic, requires_grad=False)
            for _ in range(args.k_clf):
                source_features = feature_extractor(source_x).view(
                    source_x.shape[0], -1)
                target_features = feature_extractor(target_x).view(
                    target_x.shape[0], -1)
                source_preds = discriminator(source_features)
                clf_loss = clf_criterion(source_preds, source_y)
                wasserstein_distance = critic(source_features).mean() - critic(
                    target_features).mean()
                # Task loss plus weighted domain distance.
                loss = clf_loss + args.wd_clf * wasserstein_distance
                clf_optim.zero_grad()
                loss.backward()
                clf_optim.step()
        mean_loss = total_loss / (args.iterations * args.k_critic)
        tqdm.write(f'EPOCH {epoch:03d}: critic_loss={mean_loss:.4f}')
    torch.save(clf_model.state_dict(), 'trained_models/wdgrl.pt')
# Batch inference for SRN-Deblur: read an image list, run the frozen network
# over each multi-scale batch. NOTE(review): this snippet appears truncated —
# psnr_list, output_list and tt are initialized but never used in the visible
# span; presumably the saving/metric code follows elsewhere.
if __name__ == '__main__':
    args = parse_args()
    # One image path per line.
    img_list = open(args.input_list, 'r').read().strip().split('\n')
    dataset = TestSaveDataset(img_list)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
                            drop_last=False, num_workers=8, pin_memory=True)
    net = SRNDeblurNet().cuda()
    # Inference only: no gradients anywhere in the network.
    set_requires_grad(net, False)
    last_epoch = load_model(net, args.resume, epoch=args.resume_epoch)
    psnr_list = []
    output_list = []
    tt = time()
    for step, batch in tqdm(enumerate(dataloader), total=len(dataloader)):
        # Move every image tensor in the batch dict onto the GPU.
        for k in batch:
            if 'img' in k:
                batch[k] = batch[k].cuda()
                batch[k].requires_grad = False
        # Network takes the three pyramid scales; keep the finest output only.
        y, _, _ = net(batch['img256'], batch['img128'], batch['img64'])
        y.detach_()
def main(args):
    """WDGRL domain adaptation for GTA-style models ('gta', 'gta-res', 'gta-vgg').

    Selects the pretrained source model per ``args.model``, then runs the
    Wasserstein-critic adaptation loop and saves the adapted weights to the
    per-model ``out_file``.
    """
    # ---- Model selection: set model_file/out_file, feature dim, extractor,
    # and classifier head for the chosen backbone. ----
    if args.model == 'gta':
        model_file = './trained_models/gta_source.pt'
        out_file = './trained_models/gta_wdgrl.pt'
        out_ftrs = 4375
        clf_model = GTANet().to(device)
        clf_model.load_state_dict(torch.load(model_file))
        feature_extractor = clf_model.feature_extractor
        discriminator = clf_model.classifier
    elif args.model == 'gta-res':
        model_file = './trained_models/gta_res_source.pt'
        out_file = './trained_models/gta_res_wdgrl.pt'
        clf_model = GTARes18Net(9, pretrained=False).to(device)
        out_ftrs = clf_model.fc.in_features
        clf_model.load_state_dict(torch.load(model_file))
        feature_extractor = clf_model.feature_extractor
        discriminator = clf_model.fc
    elif args.model == 'gta-vgg':
        model_file = './trained_models/gta_vgg_source.pt'
        out_file = './trained_models/gta_vgg_wdgrl.pt'
        clf_model = GTAVGG11Net(9, pretrained=False).to(device)
        out_ftrs = clf_model.classifier[0].in_features  # should be 512 * 7 * 7
        clf_model.load_state_dict(torch.load(model_file))
        # NOTE(review): only this branch freezes the whole model here; the
        # training loop below re-enables the extractor's grads anyway —
        # confirm the asymmetry with the other branches is intentional.
        set_requires_grad(clf_model, False)
        feature_extractor = clf_model.feature_extractor
        discriminator = clf_model.classifier
    else:
        raise ValueError(f'Unknown model type {args.model}')
    # Wasserstein critic over the backbone's feature dimension.
    critic = nn.Sequential(
        nn.Linear(out_ftrs, 64),
        nn.ReLU(),
        nn.Linear(64, 16),
        nn.ReLU(),
        nn.Linear(16, 1),
    ).to(device)
    half_batch = args.batch_size // 2
    target_dataset = ImageFolder('./data', transform=Compose([
        Resize((398, 224)),
        RandomCrop(224),
        RandomHorizontalFlip(),
        ToTensor(),
        Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]))
    target_loader = DataLoader(target_dataset, batch_size=half_batch,
                               shuffle=True, num_workers=1, pin_memory=True)
    source_dataset = ImageFolder('./t_data', transform=Compose([
        RandomCrop(224, pad_if_needed=True, padding_mode='reflect'),
        RandomHorizontalFlip(),
        ToTensor(),
        Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]))
    source_loader = DataLoader(source_dataset, batch_size=half_batch,
                               shuffle=True, num_workers=1, pin_memory=True)
    critic_optim = torch.optim.Adam(critic.parameters(), lr=1e-4)
    clf_optim = torch.optim.Adam(clf_model.parameters(), lr=1e-4)
    clf_criterion = nn.CrossEntropyLoss()
    for epoch in range(1, args.epochs + 1):
        batch_iterator = zip(loop_iterable(source_loader),
                             loop_iterable(target_loader))
        total_loss = 0
        for _ in trange(args.iterations, leave=False):
            (source_x, source_y), (target_x, _) = next(batch_iterator)
            # ---- Train critic (feature extractor frozen) ----
            set_requires_grad(feature_extractor, requires_grad=False)
            set_requires_grad(critic, requires_grad=True)
            source_x, target_x = source_x.to(device), target_x.to(device)
            source_y = source_y.to(device)
            # Features computed once, detached from the extractor's graph.
            with torch.no_grad():
                h_s = feature_extractor(source_x).data.view(
                    source_x.shape[0], -1)
                h_t = feature_extractor(target_x).data.view(
                    target_x.shape[0], -1)
            for _ in range(args.k_critic):
                gp = gradient_penalty(critic, h_s, h_t)
                critic_s = critic(h_s)
                critic_t = critic(h_t)
                wasserstein_distance = critic_s.mean() - critic_t.mean()
                # Critic maximizes the distance; gamma weights the penalty.
                critic_cost = -wasserstein_distance + args.gamma * gp
                critic_optim.zero_grad()
                critic_cost.backward()
                critic_optim.step()
                total_loss += critic_cost.item()
            # ---- Train classifier/extractor (critic frozen) ----
            set_requires_grad(feature_extractor, requires_grad=True)
            set_requires_grad(critic, requires_grad=False)
            for _ in range(args.k_clf):
                source_features = feature_extractor(source_x).view(
                    source_x.shape[0], -1)
                target_features = feature_extractor(target_x).view(
                    target_x.shape[0], -1)
                source_preds = discriminator(source_features)
                clf_loss = clf_criterion(source_preds, source_y)
                wasserstein_distance = critic(source_features).mean() - critic(
                    target_features).mean()
                # Task loss plus weighted domain distance.
                loss = clf_loss + args.wd_clf * wasserstein_distance
                clf_optim.zero_grad()
                loss.backward()
                clf_optim.step()
        mean_loss = total_loss / (args.iterations * args.k_critic)
        tqdm.write(f'EPOCH {epoch:03d}: critic_loss={mean_loss:.4f}')
    torch.save(clf_model.state_dict(), out_file)
posterior.batch_size = opt.batch_size prior.batch_size = opt.batch_size opt.g_dim = tmp['opt'].g_dim opt.z_dim = tmp['opt'].z_dim opt.num_digits = tmp['opt'].num_digits # --------- transfer to gpu ------------------------------------ frame_predictor.cuda() posterior.cuda() prior.cuda() encoder.cuda() decoder.cuda() nets = [frame_predictor, posterior, prior, encoder, decoder] # ---------------- discriminator ---------- utils.set_requires_grad(nets, False) #-------- load DSVG elif opt.svg_name == 'DSVG': opt.model_path = 'logs/lp/smmnist-2/drsvg=dcgan64x64-n_past=5-n_future=10-z^p_dim=8-z^c_dim=64-last_frame_skip=False-beta=0.000100-random_seqs=True' tmp = torch.load('%s/opt.pth' % opt.model_path) import models.lstm as lstm_models if opt.model == 'dcgan': if opt.image_width == 64: import models.dcgan_64 as model elif opt.image_width == 128: import models.dcgan_128 as model elif opt.model == 'vgg': if opt.image_width == 64: import models.vgg_64 as model
def train():
    """Train the Unsupervised Attention-Guided GAN (horse <-> zebra).

    Two generators with companion attention networks are trained against two
    discriminators with adversarial (MSE) and cycle (L1) losses. Sample
    images, checkpoints, a GIF, and loss plots are saved along the way.

    NOTE(review): indentation was reconstructed from a collapsed source;
    sample saving is assumed to happen under the print_every condition and
    schedulers/checkpoints per epoch — verify against the original layout.
    """
    # Fix Seed for Reproducibility #
    torch.manual_seed(9)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(9)
    # Samples, Weights and Results Path #
    paths = [config.samples_path, config.weights_path, config.plots_path]
    paths = [make_dirs(path) for path in paths]
    # Prepare Data Loader #
    train_horse_loader, train_zebra_loader = get_horse2zebra_loader('train', config.batch_size)
    val_horse_loader, val_zebra_loader = get_horse2zebra_loader('test', config.batch_size)
    total_batch = min(len(train_horse_loader), len(train_zebra_loader))
    # Image Pool (history of generated masked fakes for D updates) #
    masked_fake_A_pool = ImageMaskPool(config.pool_size)
    masked_fake_B_pool = ImageMaskPool(config.pool_size)
    # Prepare Networks #
    Attn_A = Attention()
    Attn_B = Attention()
    G_A2B = Generator()
    G_B2A = Generator()
    D_A = Discriminator()
    D_B = Discriminator()
    networks = [Attn_A, Attn_B, G_A2B, G_B2A, D_A, D_B]
    for network in networks:
        network.to(device)
    # Loss Function #
    criterion_Adversarial = nn.MSELoss()
    criterion_Cycle = nn.L1Loss()
    # Optimizers: one for both Ds, one for all generator-side nets #
    D_optim = torch.optim.Adam(chain(D_A.parameters(), D_B.parameters()),
                               lr=config.lr, betas=(0.5, 0.999))
    G_optim = torch.optim.Adam(chain(Attn_A.parameters(), Attn_B.parameters(),
                                     G_A2B.parameters(), G_B2A.parameters()),
                               lr=config.lr, betas=(0.5, 0.999))
    D_optim_scheduler = get_lr_scheduler(D_optim)
    G_optim_scheduler = get_lr_scheduler(G_optim)
    # Lists #
    D_A_losses, D_B_losses = [], []
    G_A_losses, G_B_losses = [], []
    # Train #
    print("Training Unsupervised Attention-Guided GAN started with total epoch of {}.".format(config.num_epochs))
    for epoch in range(config.num_epochs):
        for i, (real_A, real_B) in enumerate(zip(train_horse_loader, train_zebra_loader)):
            # Data Preparation #
            real_A = real_A.to(device)
            real_B = real_B.to(device)
            # Initialize Optimizers #
            D_optim.zero_grad()
            G_optim.zero_grad()
            ###################
            # Train Generator #
            ###################
            set_requires_grad([D_A, D_B], requires_grad=False)
            # Adversarial Loss using real A #
            attn_A = Attn_A(real_A)
            fake_B = G_A2B(real_A)
            # Foreground from the generator, background kept from the input.
            masked_fake_B = fake_B * attn_A + real_A * (1-attn_A)
            # NOTE(review): the extra *= attn_A multiplies the blend by the
            # attention map again — confirm this double weighting is intended.
            masked_fake_B *= attn_A
            prob_real_A = D_A(masked_fake_B)
            real_labels = torch.ones(prob_real_A.size()).to(device)
            G_loss_A = criterion_Adversarial(prob_real_A, real_labels)
            # Adversarial Loss using real B #
            attn_B = Attn_B(real_B)
            fake_A = G_B2A(real_B)
            masked_fake_A = fake_A * attn_B + real_B * (1-attn_B)
            masked_fake_A *= attn_B
            prob_real_B = D_B(masked_fake_A)
            real_labels = torch.ones(prob_real_B.size()).to(device)
            G_loss_B = criterion_Adversarial(prob_real_B, real_labels)
            # Cycle Consistency Loss using real A #
            attn_ABA = Attn_B(masked_fake_B)
            fake_ABA = G_B2A(masked_fake_B)
            masked_fake_ABA = fake_ABA * attn_ABA + masked_fake_B * (1 - attn_ABA)
            # Cycle Consistency Loss using real B #
            attn_BAB = Attn_A(masked_fake_A)
            fake_BAB = G_A2B(masked_fake_A)
            masked_fake_BAB = fake_BAB * attn_BAB + masked_fake_A * (1 - attn_BAB)
            # Cycle Consistency Loss #
            G_cycle_loss_A = config.lambda_cycle * criterion_Cycle(masked_fake_ABA, real_A)
            G_cycle_loss_B = config.lambda_cycle * criterion_Cycle(masked_fake_BAB, real_B)
            # Total Generator Loss #
            G_loss = G_loss_A + G_loss_B + G_cycle_loss_A + G_cycle_loss_B
            # Back Propagation and Update #
            G_loss.backward()
            G_optim.step()
            #######################
            # Train Discriminator #
            #######################
            set_requires_grad([D_A, D_B], requires_grad=True)
            # Train Discriminator A using real A #
            # NOTE(review): real_B is fed here although the comment says
            # "real A" — presumably D_A judges domain-B images; confirm.
            prob_real_A = D_A(real_B)
            real_labels = torch.ones(prob_real_A.size()).to(device)
            D_loss_real_A = criterion_Adversarial(prob_real_A, real_labels)
            # Add Pooling (sample from the history of generated fakes) #
            masked_fake_B, attn_A = masked_fake_B_pool.query(masked_fake_B, attn_A)
            masked_fake_B *= attn_A
            # Train Discriminator A using fake B #
            prob_fake_B = D_A(masked_fake_B.detach())
            fake_labels = torch.zeros(prob_fake_B.size()).to(device)
            D_loss_fake_A = criterion_Adversarial(prob_fake_B, fake_labels)
            D_loss_A = (D_loss_real_A + D_loss_fake_A).mean()
            # Train Discriminator B using real B #
            prob_real_B = D_B(real_A)
            real_labels = torch.ones(prob_real_B.size()).to(device)
            D_loss_real_B = criterion_Adversarial(prob_real_B, real_labels)
            # Add Pooling #
            masked_fake_A, attn_B = masked_fake_A_pool.query(masked_fake_A, attn_B)
            masked_fake_A *= attn_B
            # Train Discriminator B using fake A #
            prob_fake_A = D_B(masked_fake_A.detach())
            fake_labels = torch.zeros(prob_fake_A.size()).to(device)
            D_loss_fake_B = criterion_Adversarial(prob_fake_A, fake_labels)
            D_loss_B = (D_loss_real_B + D_loss_fake_B).mean()
            # Calculate Total Discriminator Loss #
            D_loss = D_loss_A + D_loss_B
            # Back Propagation and Update #
            D_loss.backward()
            D_optim.step()
            # Add items to Lists #
            D_A_losses.append(D_loss_A.item())
            D_B_losses.append(D_loss_B.item())
            G_A_losses.append(G_loss_A.item())
            G_B_losses.append(G_loss_B.item())
            ####################
            # Print Statistics #
            ####################
            if (i+1) % config.print_every == 0:
                print("UAG-GAN | Epoch [{}/{}] | Iteration [{}/{}] | D A Losses {:.4f} | D B Losses {:.4f} | G A Losses {:.4f} | G B Losses {:.4f}".
                      format(epoch+1, config.num_epochs, i+1, total_batch,
                             np.average(D_A_losses), np.average(D_B_losses),
                             np.average(G_A_losses), np.average(G_B_losses)))
                # Save Sample Images #
                save_samples(val_horse_loader, val_zebra_loader, G_A2B, G_B2A,
                             Attn_A, Attn_B, epoch, config.samples_path)
        # Adjust Learning Rate #
        D_optim_scheduler.step()
        G_optim_scheduler.step()
        # Save Model Weights #
        if (epoch + 1) % config.save_every == 0:
            torch.save(G_A2B.state_dict(),
                       os.path.join(config.weights_path, 'UAG-GAN_Generator_A2B_Epoch_{}.pkl'.format(epoch+1)))
            torch.save(G_B2A.state_dict(),
                       os.path.join(config.weights_path, 'UAG-GAN_Generator_B2A_Epoch_{}.pkl'.format(epoch+1)))
            torch.save(Attn_A.state_dict(),
                       os.path.join(config.weights_path, 'UAG-GAN_Attention_A_Epoch_{}.pkl'.format(epoch+1)))
            torch.save(Attn_B.state_dict(),
                       os.path.join(config.weights_path, 'UAG-GAN_Attention_B_Epoch_{}.pkl'.format(epoch+1)))
    # Make a GIF file #
    make_gifs_train("UAG-GAN", config.samples_path)
    # Plot Losses #
    plot_losses(D_A_losses, D_B_losses, G_A_losses, G_B_losses,
                config.num_epochs, config.plots_path)
    print("Training finished.")
lr=args.lr, weight_decay=args.weight_decay) sched = lr_scheduler.StepLR(optimizer, step_size=50) ''' setup tensorboard ''' writer = SummaryWriter(os.path.join(args.save_dir, 'train_info')) ''' load train and val features ''' print('===> load train and val features and labels ...') train_features, train_label, valid_features, valid_labels = utils.load_features( args) ''' train model ''' print('===> start training ...') iters = 0 best_acc = 0 for epoch in range(1, args.epoch + 1): FC.train() utils.set_requires_grad(FC, True) total_length = train_features.shape[0] perm_index = torch.randperm(total_length) train_X_sfl = train_features[perm_index] train_y_sfl = train_label[perm_index] # construct training batch for index in range(0, total_length, args.train_batch): train_info = 'Epoch: [{0}][{1}/{2}]'.format( epoch, index + 1, len(train_loader)) iters += 1 optimizer.zero_grad() if index + args.train_batch > total_length: #break input_X = train_X_sfl[index:] input_y = train_y_sfl[index:] else:
lr=args.lr, weight_decay=args.weight_decay) sched = lr_scheduler.StepLR(optimizer, step_size=50) ''' setup tensorboard ''' writer = SummaryWriter(os.path.join(args.save_dir, 'train_info')) ''' load train and val features ''' print('===> load train and val features and labels ...') train_features, train_labels, valid_features, valid_labels = utils.load_features( args) ''' train model ''' print('===> start training ...') iters = 0 best_acc = 0 for epoch in range(1, args.epoch + 1): BiRNN.train() utils.set_requires_grad(BiRNN, True) total_length = train_features.shape[0] # shuffle perm_index = np.random.permutation(len(train_features)) train_X_sfl = [train_features[i] for i in perm_index] train_y_sfl = np.array(train_labels)[perm_index] # construct training batch for index in range(0, total_length, args.train_batch): train_info = 'Epoch: [{0}][{1}/{2}]'.format( epoch, index + 1, len(train_loader)) iters += 1 optimizer.zero_grad() if index + args.train_batch > total_length: input_X = train_X_sfl[index:] input_y = train_y_sfl[index:] else:
def train():
    """Train CycleGAN on horse2zebra.

    Two generators and two discriminators trained with adversarial (MSE),
    cycle (L1), and identity (L1) losses; samples, checkpoints, a GIF, and
    loss plots are written out during training.

    NOTE(review): indentation was reconstructed from a collapsed source;
    sample saving is assumed to happen under the print_every condition and
    schedulers/checkpoints per epoch — verify against the original layout.
    """
    # Fix Seed for Reproducibility #
    torch.manual_seed(9)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(9)
    # Samples, Weights and Results Path #
    paths = [config.samples_path, config.weights_path, config.plots_path]
    paths = [make_dirs(path) for path in paths]
    # Prepare Data Loader #
    train_horse_loader, train_zebra_loader = get_horse2zebra_loader(
        purpose='train', batch_size=config.batch_size)
    test_horse_loader, test_zebra_loader = get_horse2zebra_loader(
        purpose='test', batch_size=config.val_batch_size)
    total_batch = min(len(train_horse_loader), len(train_zebra_loader))
    # Prepare Networks #
    D_A = Discriminator()
    D_B = Discriminator()
    G_A2B = Generator()
    G_B2A = Generator()
    networks = [D_A, D_B, G_A2B, G_B2A]
    for network in networks:
        network.to(device)
    # Loss Function #
    criterion_Adversarial = nn.MSELoss()
    criterion_Cycle = nn.L1Loss()
    criterion_Identity = nn.L1Loss()
    # Optimizers: separate for each D, joint for both Gs #
    D_A_optim = torch.optim.Adam(D_A.parameters(), lr=config.lr, betas=(0.5, 0.999))
    D_B_optim = torch.optim.Adam(D_B.parameters(), lr=config.lr, betas=(0.5, 0.999))
    G_optim = torch.optim.Adam(chain(G_A2B.parameters(), G_B2A.parameters()),
                               lr=config.lr, betas=(0.5, 0.999))
    D_A_optim_scheduler = get_lr_scheduler(D_A_optim)
    D_B_optim_scheduler = get_lr_scheduler(D_B_optim)
    G_optim_scheduler = get_lr_scheduler(G_optim)
    # Lists #
    D_losses_A, D_losses_B, G_losses = [], [], []
    # Training #
    print("Training CycleGAN started with total epoch of {}.".format(
        config.num_epochs))
    for epoch in range(config.num_epochs):
        for i, (horse, zebra) in enumerate(zip(train_horse_loader, train_zebra_loader)):
            # Data Preparation #
            real_A = horse.to(device)
            real_B = zebra.to(device)
            # Initialize Optimizers #
            G_optim.zero_grad()
            D_A_optim.zero_grad()
            D_B_optim.zero_grad()
            ###################
            # Train Generator #
            ###################
            # Freeze both Ds so the G backward pass skips them.
            set_requires_grad([D_A, D_B], requires_grad=False)
            # Adversarial Loss #
            fake_A = G_B2A(real_B)
            prob_fake_A = D_A(fake_A)
            real_labels = torch.ones(prob_fake_A.size()).to(device)
            G_mse_loss_B2A = criterion_Adversarial(prob_fake_A, real_labels)
            fake_B = G_A2B(real_A)
            prob_fake_B = D_B(fake_B)
            real_labels = torch.ones(prob_fake_B.size()).to(device)
            G_mse_loss_A2B = criterion_Adversarial(prob_fake_B, real_labels)
            # Identity Loss (generator should leave same-domain input alone) #
            identity_A = G_B2A(real_A)
            G_identity_loss_A = config.lambda_identity * criterion_Identity(
                identity_A, real_A)
            identity_B = G_A2B(real_B)
            G_identity_loss_B = config.lambda_identity * criterion_Identity(
                identity_B, real_B)
            # Cycle Loss (A -> B -> A and B -> A -> B reconstruction) #
            reconstructed_A = G_B2A(fake_B)
            G_cycle_loss_ABA = config.lambda_cycle * criterion_Cycle(
                reconstructed_A, real_A)
            reconstructed_B = G_A2B(fake_A)
            G_cycle_loss_BAB = config.lambda_cycle * criterion_Cycle(
                reconstructed_B, real_B)
            # Calculate Total Generator Loss #
            G_loss = G_mse_loss_B2A + G_mse_loss_A2B + G_identity_loss_A + G_identity_loss_B + G_cycle_loss_ABA + G_cycle_loss_BAB
            # Back Propagation and Update #
            # retain_graph=True because fake_A/fake_B are reused by the D steps.
            G_loss.backward(retain_graph=True)
            G_optim.step()
            #######################
            # Train Discriminator #
            #######################
            set_requires_grad([D_A, D_B], requires_grad=True)
            ## Train Discriminator A ##
            # Real Loss #
            prob_real_A = D_A(real_A)
            real_labels = torch.ones(prob_real_A.size()).to(device)
            D_real_loss_A = criterion_Adversarial(prob_real_A, real_labels)
            # Fake Loss (detach: no gradient into the generator) #
            prob_fake_A = D_A(fake_A.detach())
            fake_labels = torch.zeros(prob_fake_A.size()).to(device)
            D_fake_loss_A = criterion_Adversarial(prob_fake_A, fake_labels)
            # Calculate Total Discriminator A Loss #
            # NOTE(review): scaling the D loss by lambda_identity is unusual
            # (0.5 is the CycleGAN convention) — confirm this is intended.
            D_loss_A = config.lambda_identity * (D_real_loss_A + D_fake_loss_A).mean()
            # Back propagation and Update #
            D_loss_A.backward(retain_graph=True)
            D_A_optim.step()
            ## Train Discriminator B ##
            # Real Loss #
            prob_real_B = D_B(real_B)
            real_labels = torch.ones(prob_real_B.size()).to(device)
            loss_real_B = criterion_Adversarial(prob_real_B, real_labels)
            # Fake Loss #
            prob_fake_B = D_B(fake_B.detach())
            fake_labels = torch.zeros(prob_fake_B.size()).to(device)
            loss_fake_B = criterion_Adversarial(prob_fake_B, fake_labels)
            # Calculate Total Discriminator B Loss #
            D_loss_B = config.lambda_identity * (loss_real_B + loss_fake_B).mean()
            # Back propagation and Update #
            D_loss_B.backward(retain_graph=True)
            D_B_optim.step()
            # Add items to Lists #
            D_losses_A.append(D_loss_A.item())
            D_losses_B.append(D_loss_B.item())
            G_losses.append(G_loss.item())
            ####################
            # Print Statistics #
            ####################
            if (i + 1) % config.print_every == 0:
                print(
                    "CycleGAN | Epoch [{}/{}] | Iterations [{}/{}] | D_A Loss {:.4f} | D_B Loss {:.4f} | G Loss {:.4f}"
                    .format(epoch + 1, config.num_epochs, i + 1, total_batch,
                            np.average(D_losses_A), np.average(D_losses_B),
                            np.average(G_losses)))
                # Save Sample Images #
                sample_images(test_horse_loader, test_zebra_loader, G_A2B,
                              G_B2A, epoch, config.samples_path)
        # Adjust Learning Rate #
        D_A_optim_scheduler.step()
        D_B_optim_scheduler.step()
        G_optim_scheduler.step()
        # Save Model Weights #
        if (epoch + 1) % config.save_every == 0:
            torch.save(
                G_A2B.state_dict(),
                os.path.join(
                    config.weights_path,
                    'CycleGAN_Generator_A2B_Epoch_{}.pkl'.format(epoch + 1)))
            torch.save(
                G_B2A.state_dict(),
                os.path.join(
                    config.weights_path,
                    'CycleGAN_Generator_B2A_Epoch_{}.pkl'.format(epoch + 1)))
    # Make a GIF file #
    make_gifs_train("CycleGAN", config.samples_path)
    # Plot Losses #
    plot_losses(D_losses_A, D_losses_B, G_losses, config.num_epochs,
                config.plots_path)
    print("Training finished.")
def main(args):
    """Adversarial discriminative domain adaptation (ADDA) for the GTA models.

    A frozen copy of the source network supplies source-domain features and
    the classification head, while a second, trainable copy (``model_2``) is
    adapted so a small domain discriminator cannot distinguish its target
    features from the source features.  The adapted feature extractor is then
    spliced back into the classifier and the result saved to ``out_file``.

    Args:
        args: parsed CLI options; uses ``model`` ('gta' | 'gta-res' |
            'gta-vgg'), ``batch_size``, ``epochs``, ``iterations``,
            ``k_disc`` and ``k_clf``.
    """
    # --- Build the frozen source model and the trainable target copy, both
    # initialised from the same source-trained checkpoint. ---
    if args.model == 'gta':
        model_file = './trained_models/gta_source.pt'
        out_file = './trained_models/gta_adda.pt'
        out_ftrs = 4375  # flattened feature size of GTANet's extractor
        model = GTANet().to(device)
        model.load_state_dict(torch.load(model_file))
        model.eval()
        set_requires_grad(model, False)  # source side stays frozen
        source_model = model.feature_extractor
        clf = model
        model_2 = GTANet().to(device)
        model_2.load_state_dict(torch.load(model_file))
        target_model = model_2.feature_extractor
    elif args.model == 'gta-res':
        model_file = './trained_models/gta_res_source.pt'
        out_file = './trained_models/gta_res_adda.pt'
        model = GTARes18Net(9, pretrained=False).to(device)
        out_ftrs = model.fc.in_features
        model.load_state_dict(torch.load(model_file))
        model.eval()
        set_requires_grad(model, False)
        source_model = model.feature_extractor
        clf = model
        model_2 = GTARes18Net(9, pretrained=False).to(device)
        model_2.load_state_dict(torch.load(model_file))
        target_model = model_2.feature_extractor
    elif args.model == 'gta-vgg':
        model_file = './trained_models/gta_vgg_source.pt'
        out_file = './trained_models/gta_vgg_adda.pt'
        model = GTAVGG11Net(9, pretrained=False).to(device)
        out_ftrs = model.classifier[0].in_features  # should be 512 * 7 * 7
        model.load_state_dict(torch.load(model_file))
        model.eval()
        set_requires_grad(model, False)

        # VGG has no single feature_extractor attribute, so wrap the
        # convolutional trunk + pooling + flatten in a closure.
        def source_model(x):
            x = model.features(x)
            x = model.avgpool(x)
            x = torch.flatten(x, 1)
            return x

        clf = model
        model_2 = GTAVGG11Net(9, pretrained=False).to(device)
        model_2.load_state_dict(torch.load(model_file))

        def target_model(x):
            x = model_2.features(x)
            x = model_2.avgpool(x)
            x = torch.flatten(x, 1)
            return x
    else:
        raise ValueError(f'Unknown model type {args.model}')

    # Domain discriminator: emits a single logit (no Sigmoid here — the
    # BCEWithLogitsLoss below applies it internally).
    discriminator = nn.Sequential(
        nn.Linear(out_ftrs, 64),
        nn.ReLU(),
        nn.Linear(64, 16),
        nn.ReLU(),
        nn.Linear(16, 1),
    ).to(device)

    # Each discriminator batch is half source, half target.
    half_batch = args.batch_size // 2

    target_dataset = ImageFolder('./data', transform=Compose([
        Resize((398, 224)),
        RandomCrop(224),
        RandomHorizontalFlip(),
        ToTensor(),
        Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]))
    target_loader = DataLoader(target_dataset, batch_size=half_batch,
                               shuffle=True, num_workers=1, pin_memory=True)

    source_dataset = ImageFolder('./t_data', transform=Compose([
        RandomCrop(224, pad_if_needed=True, padding_mode='reflect'),
        RandomHorizontalFlip(),
        ToTensor(),
        Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]))
    source_loader = DataLoader(source_dataset, batch_size=half_batch,
                               shuffle=True, num_workers=1, pin_memory=True)

    discriminator_optim = torch.optim.Adam(discriminator.parameters())
    # Optimise the whole trainable copy (covers the closure case for VGG).
    target_optim = torch.optim.Adam(model_2.parameters())
    criterion = nn.BCEWithLogitsLoss()

    for epoch in range(1, args.epochs + 1):
        batch_iterator = zip(loop_iterable(source_loader),
                             loop_iterable(target_loader))
        total_loss = 0
        total_accuracy = 0
        for _ in trange(args.iterations, leave=False):
            # --- Train discriminator: target extractor frozen. ---
            set_requires_grad(model_2, requires_grad=False)
            set_requires_grad(discriminator, requires_grad=True)
            for _ in range(args.k_disc):
                (source_x, _), (target_x, _) = next(batch_iterator)
                source_x, target_x = source_x.to(device), target_x.to(device)

                source_features = source_model(source_x).view(
                    source_x.shape[0], -1)
                target_features = target_model(target_x).view(
                    target_x.shape[0], -1)

                # Source features labelled 1, target features labelled 0.
                discriminator_x = torch.cat([source_features, target_features])
                discriminator_y = torch.cat([
                    torch.ones(source_x.shape[0], device=device),
                    torch.zeros(target_x.shape[0], device=device)
                ])

                preds = discriminator(discriminator_x).squeeze()
                loss = criterion(preds, discriminator_y)

                discriminator_optim.zero_grad()
                loss.backward()
                discriminator_optim.step()

                total_loss += loss.item()
                # Logit > 0 <=> predicted probability > 0.5.
                total_accuracy += ((preds > 0).long() == discriminator_y.long()
                                   ).float().mean().item()

            # --- Train target extractor: discriminator frozen. ---
            set_requires_grad(model_2, requires_grad=True)
            set_requires_grad(discriminator, requires_grad=False)
            for _ in range(args.k_clf):
                _, (target_x, _) = next(batch_iterator)
                target_x = target_x.to(device)

                target_features = target_model(target_x).view(
                    target_x.shape[0], -1)

                # Flipped labels: push target features toward "looks source".
                discriminator_y = torch.ones(target_x.shape[0], device=device)

                preds = discriminator(target_features).squeeze()
                loss = criterion(preds, discriminator_y)

                target_optim.zero_grad()
                loss.backward()
                target_optim.step()

        mean_loss = total_loss / (args.iterations * args.k_disc)
        mean_accuracy = total_accuracy / (args.iterations * args.k_disc)
        tqdm.write(f'EPOCH {epoch:03d}: discriminator_loss={mean_loss:.4f}, '
                   f'discriminator_accuracy={mean_accuracy:.4f}')

    # --- Create the full target model and save it: splice the adapted
    # feature extractor of model_2 into the frozen classifier network. ---
    if args.model == 'gta':
        clf.feature_extractor = target_model
    elif args.model == 'gta-res':
        clf.conv1 = model_2.conv1
        clf.bn1 = model_2.bn1
        clf.relu = model_2.relu
        clf.maxpool = model_2.maxpool
        clf.layer1 = model_2.layer1
        clf.layer2 = model_2.layer2
        clf.layer3 = model_2.layer3
        clf.layer4 = model_2.layer4
        clf.avgpool = model_2.avgpool
    elif args.model == 'gta-vgg':
        # BUG FIX: this branch was missing, so for 'gta-vgg' the adapted
        # weights were never copied into clf and the saved model was
        # identical to the source model.
        clf.features = model_2.features
        clf.avgpool = model_2.avgpool
    torch.save(clf.state_dict(), out_file)
def main():
    """Train a CycleGAN (two generators + two discriminators) end to end.

    Saves per-epoch checkpoints and loss curves under ``model/``; optionally
    resumes from ``pretrained/`` weights when ``opt.pretrained`` is set.
    """
    # Get training options
    opt = get_opt()

    # Define the networks.
    # netG_A: transfers images from domain A to domain B
    # netG_B: transfers images from domain B to domain A
    netG_A = networks.Generator(opt.input_nc, opt.output_nc, opt.ngf,
                                opt.n_res, opt.dropout)
    netG_B = networks.Generator(opt.output_nc, opt.input_nc, opt.ngf,
                                opt.n_res, opt.dropout)
    if opt.u_net:
        netG_A = networks.U_net(opt.input_nc, opt.output_nc, opt.ngf)
        netG_B = networks.U_net(opt.output_nc, opt.input_nc, opt.ngf)

    # netD_A: judges whether an image belongs to domain A (sees real_A/fake_A)
    # netD_B: judges whether an image belongs to domain B (sees real_B/fake_B)
    netD_A = networks.Discriminator(opt.input_nc, opt.ndf)
    netD_B = networks.Discriminator(opt.output_nc, opt.ndf)

    # Initialize the networks
    if opt.cuda:
        netG_A.cuda()
        netG_B.cuda()
        netD_A.cuda()
        netD_B.cuda()
    utils.init_weight(netG_A)
    utils.init_weight(netG_B)
    utils.init_weight(netD_A)
    utils.init_weight(netD_B)
    if opt.pretrained:
        netG_A.load_state_dict(torch.load('pretrained/netG_A.pth'))
        netG_B.load_state_dict(torch.load('pretrained/netG_B.pth'))
        netD_A.load_state_dict(torch.load('pretrained/netD_A.pth'))
        netD_B.load_state_dict(torch.load('pretrained/netD_B.pth'))

    # Define the loss functions
    criterion_GAN = utils.GANLoss()
    if opt.cuda:
        criterion_GAN.cuda()
    criterion_cycle = torch.nn.L1Loss()
    # Alternatively, can try MSE cycle consistency loss
    #criterion_cycle = torch.nn.MSELoss()
    criterion_identity = torch.nn.L1Loss()

    # Define the optimizers (both generators share one optimizer)
    optimizer_G = torch.optim.Adam(itertools.chain(netG_A.parameters(),
                                                   netG_B.parameters()),
                                   lr=opt.lr,
                                   betas=(opt.beta1, 0.999))
    optimizer_D_A = torch.optim.Adam(netD_A.parameters(),
                                     lr=opt.lr,
                                     betas=(opt.beta1, 0.999))
    optimizer_D_B = torch.optim.Adam(netD_B.parameters(),
                                     lr=opt.lr,
                                     betas=(opt.beta1, 0.999))

    # Create learning rate schedulers (constant then linear decay)
    lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
        optimizer_G,
        lr_lambda=utils.Lambda_rule(opt.epoch, opt.n_epochs,
                                    opt.n_epochs_decay).step)
    lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_A,
        lr_lambda=utils.Lambda_rule(opt.epoch, opt.n_epochs,
                                    opt.n_epochs_decay).step)
    lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_B,
        lr_lambda=utils.Lambda_rule(opt.epoch, opt.n_epochs,
                                    opt.n_epochs_decay).step)

    # Pre-allocated input buffers, filled in place each iteration
    Tensor = torch.cuda.FloatTensor if opt.cuda else torch.Tensor
    input_A = Tensor(opt.batch_size, opt.input_nc, opt.sizeh, opt.sizew)
    input_B = Tensor(opt.batch_size, opt.output_nc, opt.sizeh, opt.sizew)

    # Define two image pools to store generated images
    fake_A_pool = utils.ImagePool()
    fake_B_pool = utils.ImagePool()

    # Define the transform, and load the data
    transform = transforms.Compose([
        transforms.Resize((opt.sizeh, opt.sizew)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, ), (0.5, ))
    ])
    dataloader = DataLoader(ImageDataset(opt.rootdir, transform=transform,
                                         mode='train'),
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.n_cpu)

    # numpy arrays to store the loss of each epoch
    loss_G_array = np.zeros(opt.n_epochs + opt.n_epochs_decay)
    loss_D_A_array = np.zeros(opt.n_epochs + opt.n_epochs_decay)
    loss_D_B_array = np.zeros(opt.n_epochs + opt.n_epochs_decay)

    # Training
    for epoch in range(opt.epoch, opt.n_epochs + opt.n_epochs_decay):
        start = time.strftime("%H:%M:%S")
        print("current epoch :", epoch, " start time :", start)

        # Empty lists to store the loss of each mini-batch
        loss_G_list = []
        loss_D_A_list = []
        loss_D_B_list = []
        for i, batch in enumerate(dataloader):
            if i % 50 == 1:
                print("current step: ", i)
                current = time.strftime("%H:%M:%S")
                print("current time :", current)
                print("last loss G:", loss_G_list[-1], "last loss D_A",
                      loss_D_A_list[-1], "last loss D_B", loss_D_B_list[-1])
            real_A = input_A.copy_(batch['A'])
            real_B = input_B.copy_(batch['B'])

            # ---- Train the generators ----
            optimizer_G.zero_grad()

            # Compute fake images and (optionally) identity images
            fake_B = netG_A(real_A)
            fake_A = netG_B(real_B)
            if opt.identity_loss != 0:
                same_B = netG_A(real_B)
                same_A = netG_B(real_A)

            # discriminators require no gradients when optimizing generators
            utils.set_requires_grad([netD_A, netD_B], False)

            # Identity loss
            if opt.identity_loss != 0:
                loss_identity_A = criterion_identity(
                    same_A, real_A) * opt.identity_loss
                loss_identity_B = criterion_identity(
                    same_B, real_B) * opt.identity_loss

            # GAN loss (generators try to fool the discriminators)
            prediction_fake_B = netD_B(fake_B)
            loss_gan_B = criterion_GAN(prediction_fake_B, True)
            prediction_fake_A = netD_A(fake_A)
            loss_gan_A = criterion_GAN(prediction_fake_A, True)

            # Cycle consistency loss
            recA = netG_B(fake_B)
            recB = netG_A(fake_A)
            loss_cycle_A = criterion_cycle(recA, real_A) * opt.cycle_loss
            loss_cycle_B = criterion_cycle(recB, real_B) * opt.cycle_loss

            # total loss, adding the identity term only when enabled
            loss_G = loss_gan_B + loss_gan_A + loss_cycle_A + loss_cycle_B
            if opt.identity_loss != 0:
                loss_G += loss_identity_A + loss_identity_B
            loss_G_list.append(loss_G.item())
            loss_G.backward()
            optimizer_G.step()

            # ---- Train the discriminators ----
            utils.set_requires_grad([netD_A, netD_B], True)

            # Train the discriminator D_A
            optimizer_D_A.zero_grad()
            # real images
            pred_real = netD_A(real_A)
            loss_D_real = criterion_GAN(pred_real, True)
            # fake images, drawn from the history pool to stabilise training
            fake_A = fake_A_pool.query(fake_A)
            pred_fake = netD_A(fake_A.detach())
            loss_D_fake = criterion_GAN(pred_fake, False)
            # total loss
            loss_D_A = (loss_D_real + loss_D_fake) * 0.5
            loss_D_A_list.append(loss_D_A.item())
            loss_D_A.backward()
            optimizer_D_A.step()

            # Train the discriminator D_B
            optimizer_D_B.zero_grad()
            # real images
            pred_real = netD_B(real_B)
            loss_D_real = criterion_GAN(pred_real, True)
            # fake images
            fake_B = fake_B_pool.query(fake_B)
            pred_fake = netD_B(fake_B.detach())
            loss_D_fake = criterion_GAN(pred_fake, False)
            # total loss
            loss_D_B = (loss_D_real + loss_D_fake) * 0.5
            loss_D_B_list.append(loss_D_B.item())
            loss_D_B.backward()
            optimizer_D_B.step()

        # Update the learning rate
        lr_scheduler_G.step()
        lr_scheduler_D_A.step()
        lr_scheduler_D_B.step()

        # Save models checkpoints
        torch.save(netG_A.state_dict(), 'model/netG_A.pth')
        torch.save(netG_B.state_dict(), 'model/netG_B.pth')
        torch.save(netD_A.state_dict(), 'model/netD_A.pth')
        torch.save(netD_B.state_dict(), 'model/netD_B.pth')

        # Save other checkpoint information
        checkpoint = {
            'epoch': epoch,
            'optimizer_G': optimizer_G.state_dict(),
            'optimizer_D_A': optimizer_D_A.state_dict(),
            'optimizer_D_B': optimizer_D_B.state_dict(),
            'lr_scheduler_G': lr_scheduler_G.state_dict(),
            'lr_scheduler_D_A': lr_scheduler_D_A.state_dict(),
            'lr_scheduler_D_B': lr_scheduler_D_B.state_dict()
        }
        torch.save(checkpoint, 'model/checkpoint.pth')

        # Update the numpy arrays that record the loss
        loss_G_array[epoch] = sum(loss_G_list) / len(loss_G_list)
        loss_D_A_array[epoch] = sum(loss_D_A_list) / len(loss_D_A_list)
        loss_D_B_array[epoch] = sum(loss_D_B_list) / len(loss_D_B_list)
        np.savetxt('model/loss_G.txt', loss_G_array)
        np.savetxt('model/loss_D_A.txt', loss_D_A_array)
        # BUG FIX: was 'model/loss_D_b.txt', inconsistent with loss_D_A.txt
        np.savetxt('model/loss_D_B.txt', loss_D_B_array)

        # Periodically keep numbered snapshots as well
        if epoch % 10 == 9:
            torch.save(netG_A.state_dict(),
                       'model/netG_A' + str(epoch) + '.pth')
            torch.save(netG_B.state_dict(),
                       'model/netG_B' + str(epoch) + '.pth')
            torch.save(netD_A.state_dict(),
                       'model/netD_A' + str(epoch) + '.pth')
            torch.save(netD_B.state_dict(),
                       'model/netD_B' + str(epoch) + '.pth')

        end = time.strftime("%H:%M:%S")
        print("current epoch :", epoch, " end time :", end)
        print("G loss :", loss_G_array[epoch], "D_A loss :",
              loss_D_A_array[epoch], "D_B loss :", loss_D_B_array[epoch])
def main(args):
    """ADDA adaptation of a single source-trained ``Net`` to one target person.

    The frozen source network provides source features and the classifier
    head; the target copy's feature extractor is trained adversarially
    against a domain discriminator.  The best-performing adapted model is
    saved, and per-epoch target accuracies are dumped to a JSON file.
    """
    # Frozen source model: also serves as the full classifier ``clf``.
    source_model = Net().to(device)
    source_model.load_state_dict(torch.load(args.MODEL_FILE))
    source_model.eval()
    set_requires_grad(source_model, requires_grad=False)
    clf = source_model
    source_model = source_model.feature_extractor

    # Trainable target copy, initialised from the same checkpoint.
    target_model = Net().to(device)
    target_model.load_state_dict(torch.load(args.MODEL_FILE))
    target_model = target_model.feature_extractor
    classifier = clf.classifier

    # Domain discriminator.  BUG FIX: the original ended with nn.Sigmoid()
    # while the loss below is BCEWithLogitsLoss, which applies sigmoid
    # internally — a double sigmoid that saturates the loss.  The network now
    # emits a raw logit, and accuracy is thresholded at logit > 0.
    discriminator = nn.Sequential(nn.Linear(EXTRACTED_FEATURE_DIM, 64),
                                  nn.ReLU(), nn.BatchNorm1d(64),
                                  nn.Linear(64, 1)).to(device)

    #half_batch = args.batch_size // 2
    batch_size = args.batch_size

    # X_source, y_source = preprocess_train()
    X_source, y_source = preprocess_train_single(1)
    source_dataset = torch.utils.data.TensorDataset(X_source, y_source)
    source_loader = DataLoader(source_dataset,
                               batch_size=batch_size,
                               shuffle=False,
                               num_workers=1,
                               pin_memory=True)
    X_target, y_target = preprocess_test(args.person)
    target_dataset = torch.utils.data.TensorDataset(X_target, y_target)
    target_loader = DataLoader(target_dataset,
                               batch_size=batch_size,
                               shuffle=False,
                               num_workers=1,
                               pin_memory=True)

    discriminator_optim = torch.optim.Adam(discriminator.parameters())
    target_optim = torch.optim.Adam(target_model.parameters(), lr=3e-6)
    criterion = nn.BCEWithLogitsLoss()
    criterion_class = nn.CrossEntropyLoss()

    # Baseline target accuracy of the unadapted source model.
    best_tar_acc = test(args, clf)
    final_accs = []
    for epoch in range(1, args.epochs + 1):
        # Re-wrap the datasets so each epoch reshuffles both domains.
        source_loader = DataLoader(source_loader.dataset,
                                   batch_size=batch_size,
                                   shuffle=True)
        target_loader = DataLoader(target_loader.dataset,
                                   batch_size=batch_size,
                                   shuffle=True)
        batch_iterator = zip(loop_iterable(source_loader),
                             loop_iterable(target_loader))

        total_loss = 0
        adv_loss = 0
        total_accuracy = 0
        second_acc = 0
        total_class_loss = 0
        for _ in trange(args.iterations, leave=False):
            # ---- Train discriminator: target extractor frozen ----
            set_requires_grad(target_model, requires_grad=False)
            set_requires_grad(discriminator, requires_grad=True)
            discriminator.train()
            for _ in range(args.k_disc):
                (source_x, source_y), (target_x, _) = next(batch_iterator)
                source_y = source_y.to(device).view(-1)
                source_x, target_x = source_x.to(device), target_x.to(device)

                source_features = source_model(source_x).view(
                    source_x.shape[0], -1)
                target_features = target_model(target_x).view(
                    target_x.shape[0], -1)

                # Source features labelled 1, target features labelled 0.
                discriminator_x = torch.cat([source_features, target_features])
                discriminator_y = torch.cat([
                    torch.ones(source_x.shape[0], device=device),
                    torch.zeros(target_x.shape[0], device=device)
                ])

                preds = discriminator(discriminator_x).squeeze()
                loss = criterion(preds, discriminator_y)

                discriminator_optim.zero_grad()
                loss.backward()
                discriminator_optim.step()

                total_loss += loss.item()
                # Logit > 0 corresponds to probability > 0.5.
                total_accuracy += ((preds > 0).long() == discriminator_y.
                                   long()).float().mean().item()

            # ---- Train feature extractor: discriminator frozen ----
            set_requires_grad(target_model, requires_grad=True)
            set_requires_grad(discriminator, requires_grad=False)
            target_model.train()
            for _ in range(args.k_clf):
                _, (target_x, _) = next(batch_iterator)
                target_x = target_x.to(device)
                target_features = target_model(target_x).view(
                    target_x.shape[0], -1)
                # NOTE(review): source_x/source_y are left over from the last
                # discriminator batch; the class loss computed from them is
                # tracked for logging only and excluded from the update.
                source_features = target_model(source_x).view(
                    source_x.shape[0], -1)
                source_pred = classifier(source_features)  # (batch_size, 4)

                # flipped labels: push target features toward "looks source"
                discriminator_y = torch.ones(target_x.shape[0], device=device)

                preds = discriminator(target_features).squeeze()
                second_acc += ((preds > 0).long() == discriminator_y.long()
                               ).float().mean().item()
                loss_adv = criterion(preds, discriminator_y)
                adv_loss += loss_adv.item()
                loss_class = criterion_class(source_pred, source_y)
                total_class_loss += loss_class.item()
                loss = loss_adv  #+ 0.001*loss_class
                target_optim.zero_grad()
                loss.backward()
                target_optim.step()

        mean_loss = total_loss / (args.iterations * args.k_disc)
        mean_adv_loss = adv_loss / (args.iterations * args.k_clf)
        total_class_loss = total_class_loss / (args.iterations * args.k_clf)
        dis_accuracy = total_accuracy / (args.iterations * args.k_disc)
        sec_acc = second_acc / (args.iterations * args.k_clf)

        # Evaluate the adapted model on the target person; keep the best.
        clf.feature_extractor = target_model
        tar_accuarcy = test(args, clf)
        final_accs.append(tar_accuarcy)
        if tar_accuarcy > best_tar_acc:
            best_tar_acc = tar_accuarcy
            torch.save(clf.state_dict(), 'trained_models/adda.pt')
        tqdm.write(
            f'EPOCH {epoch:03d}: discriminator_loss={mean_loss:.4f}, adv_loss = {mean_adv_loss:.4f}, '
            f'discriminator_accuracy={dis_accuracy:.4f}, tar_accuary = {tar_accuarcy:.4f}, best_accuracy = {best_tar_acc:.4f}, '
            f'sec_acc = {sec_acc:.4f}, total_class_loss: {total_class_loss:.4f}'
        )

    # Create the full target model and save it
    clf.feature_extractor = target_model
    #torch.save(clf.state_dict(), 'trained_models/adda.pt')

    # Persist per-epoch target accuracies.
    jd = {"test_acc": final_accs}
    with open(str(args.seed) + '/acc' + str(args.person) + '.json', 'w') as f:
        json.dump(jd, f)
def _do_epoch(self):
    """Run one training epoch over all source domains, then evaluate.

    Per iteration: (1) update only the auxiliary classifier ``c_model`` on
    detached features; (2) update the remaining models on the combined
    main / discriminator / c_others / cp losses.  After the loop, logs
    val/test accuracy for every loader in ``self.test_loaders``.
    """
    # Put every sub-model into training mode and set their loss weights.
    set_mode(self.main_model, "train")
    set_mode(self.dis_model, "train")
    set_mode(self.c_model, "train")
    set_mode(self.cp_model, "train")
    set_lambda([self.dis_model], [self.args.lbd_d])
    set_lambda([self.c_model, self.cp_model],
               [self.args.lbd_c, self.args.lbd_cp])
    loader_iter_list = []
    loader_size_list = []
    # During warm-up, both the auxiliary and main losses are down-weighted.
    if self.current_epoch < self.args.warmup_step:
        aux_weight = self.args.warmup_weight
        main_weight = self.args.warmup_weight
    else:
        aux_weight = 1
        main_weight = 1
    # One iterator per source-domain loader; iterate for as many steps as
    # the longest loader needs.
    for loader in self.source_loader_list:
        loader_iter_list.append(enumerate(loader))
        loader_size_list.append(len(loader))
    for it in range(max(loader_size_list)):
        data = []
        labels = []
        domains = []
        # Draw one batch from every domain; restart any exhausted iterator
        # so shorter loaders cycle.
        for idx, iter_ in zip(range(self.num_domains), loader_iter_list):
            try:
                item = iter_.__next__()
            except StopIteration:
                loader_iter_list[idx] = enumerate(
                    self.source_loader_list[idx])
                item = loader_iter_list[idx].__next__()
            data.append(item[1][0])
            labels.append(item[1][1])
            # Domain label = loader index, one per sample in the batch.
            domains.append(torch.ones(labels[-1].size(0)).long() * idx)
        data = torch.cat(data, dim=0).to(self.device)
        labels = torch.cat(labels, dim=0).to(self.device)
        domains = torch.cat(domains, dim=0).to(self.device)

        # --- Step 1: train c_model alone on detached features, so this
        # update cannot touch the backbone. ---
        set_requires_grad(self.main_model, False)
        set_requires_grad(self.c_model, True)
        _, feature = self.main_model(data)
        c_loss_self = self._compute_cls_loss(
            self.c_model, feature.detach(), labels, domains,
            mode="self") * aux_weight
        self.optimizer.zero_grad()
        c_loss_self.backward()
        self.optimizer.step()

        # --- Step 2: fresh forward pass (weights changed above), then joint
        # update of everything except c_model. ---
        set_requires_grad(
            [self.main_model, self.dis_model, self.c_model, self.cp_model],
            True)
        class_logit, feature = self.main_model(data)
        main_loss = F.cross_entropy(class_logit, labels) * main_weight
        dis_loss = self._compute_dis_loss(feature, domains) * aux_weight
        # c_model is frozen here: its "others" loss only shapes the backbone.
        set_requires_grad(self.c_model, False)
        c_loss_others = self._compute_cls_loss(
            self.c_model, feature, labels, domains,
            mode="others") * aux_weight
        cp_loss = self._compute_cls_loss(
            self.cp_model, feature, labels, domains,
            mode="self") * aux_weight
        loss = dis_loss + c_loss_others + cp_loss + main_loss
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        # Only for logging: report the total including the step-1 loss.
        loss += c_loss_self
        message = "epoch %d iter %d: all %.6f main %.6f dis %.6f c_self %.6f c_others %.6f cp %.6f\n" % (
            self.current_epoch, it, loss.data, main_loss.data, dis_loss.data,
            c_loss_self.data, c_loss_others.data, cp_loss.data)
        with open(self.log_file, "a") as fid:
            fid.write(message)
        print(message)
        # Free the graph-holding tensors before the next iteration.
        del loss, main_loss, dis_loss, c_loss_self, c_loss_others, cp_loss

    # --- Evaluation phase: accuracy on every registered test loader. ---
    self.main_model.eval()
    with torch.no_grad():
        with open(self.log_file, "a") as fid:
            for phase, loader in self.test_loaders.items():
                class_correct, all_domains = self.do_test(loader)
                class_correct = class_correct.float()
                class_acc = class_correct.mean() * 100.0
                self.results[phase][self.current_epoch] = class_acc
                if phase == "val":
                    # Overall val accuracy plus a per-source-domain breakdown.
                    message = "epoch %d: val_all_acc %.5f" % (
                        self.current_epoch, class_acc)
                    for i in range(self.num_domains):
                        cc_i = class_correct[all_domains == i]
                        ca_i = cc_i.mean() * 100.0
                        message += " val_%s_acc %.5f" % (
                            self.args.source[i], ca_i)
                    message += "\n"
                    fid.write(message)
                    print(message)
                elif phase == "test":
                    message = "epoch %d: test_acc %.5f\n" % (
                        self.current_epoch, class_acc)
                    fid.write(message)
                    print(message)
parser.add_argument("--model-path", type=str, required=True) parser.add_argument("--config-path", type=str, required=True) parser.add_argument("--save-dir-path", type=str, default=".") parser.add_argument("--tol", type=float, default=0) parser.add_argument("--batch-size", type=int, default=512) parser.add_argument("--distance", default="l1", choices=["l1", "l2"]) args = parser.parse_args() cfg, G, lidar, device = utils.setup( args.model_path, args.config_path, ema=True, fix_noise=True, ) utils.set_requires_grad(G, False) G = DP(G) # hyperparameters num_step = 1000 perturb_latent = True noise_ratio = 0.75 noise_sigma = 1.0 lr_rampup_ratio = 0.05 lr_rampdown_ratio = 0.25 # prepare reference dataset = define_dataset(cfg.dataset, phase="test") loader = torch.utils.data.DataLoader( dataset, batch_size=args.batch_size,
def main():
    """Train a conditional GAN (pix2pix-style) translating domain A to B.

    A single generator netG_A is trained with a GAN loss from the conditional
    discriminator netD_B (which sees the input image concatenated with the
    real/fake output) plus an L1 reconstruction loss.  Checkpoints and loss
    curves are written under ``model/``.
    """
    # Get training options
    opt = get_opt()
    device = torch.device("cuda") if opt.cuda else torch.device("cpu")

    # Define the networks
    # netG_A: used to transfer image from domain A to domain B
    netG_A = networks.Generator(opt.input_nc, opt.output_nc, opt.ngf,
                                opt.n_res, opt.dropout)
    if opt.u_net:
        netG_A = networks.U_net(opt.input_nc, opt.output_nc, opt.ngf)
    # netD_B: conditional discriminator — input is the (image, condition)
    # pair, hence input_nc + output_nc channels.
    netD_B = networks.Discriminator(opt.input_nc + opt.output_nc, opt.ndf)

    # Initialize the networks
    if opt.cuda:
        netG_A.cuda()
        netD_B.cuda()
    utils.init_weight(netG_A)
    utils.init_weight(netD_B)
    if opt.pretrained:
        netG_A.load_state_dict(torch.load('pretrained/netG_A.pth'))
        netD_B.load_state_dict(torch.load('pretrained/netD_B.pth'))

    # Define the loss functions
    criterion_GAN = utils.GANLoss()
    if opt.cuda:
        criterion_GAN.cuda()
    criterion_l1 = torch.nn.L1Loss()

    # Define the optimizers
    optimizer_G = torch.optim.Adam(netG_A.parameters(), lr=opt.lr,
                                   betas=(opt.beta1, 0.999))
    optimizer_D_B = torch.optim.Adam(netD_B.parameters(), lr=opt.lr,
                                     betas=(opt.beta1, 0.999))

    # Create learning rate schedulers (constant, then linear decay)
    lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
        optimizer_G,
        lr_lambda = utils.Lambda_rule(opt.epoch, opt.n_epochs,
                                      opt.n_epochs_decay).step)
    lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_B,
        lr_lambda = utils.Lambda_rule(opt.epoch, opt.n_epochs,
                                      opt.n_epochs_decay).step)

    # Define the transform, and load the paired data
    transform = transforms.Compose([
        transforms.Resize((opt.sizeh, opt.sizew)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    dataloader = DataLoader(PairedImage(opt.rootdir, transform = transform,
                                        mode = 'train'),
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.n_cpu)

    # numpy arrays to store the loss of each epoch
    loss_G_array = np.zeros(opt.n_epochs + opt.n_epochs_decay)
    loss_D_B_array = np.zeros(opt.n_epochs + opt.n_epochs_decay)

    # Training
    for epoch in range(opt.epoch, opt.n_epochs + opt.n_epochs_decay):
        start = time.strftime("%H:%M:%S")
        print("current epoch :", epoch, " start time :", start)
        # Empty lists to store the loss of each mini-batch
        loss_G_list = []
        loss_D_B_list = []
        for i, batch in enumerate(dataloader):
            # Periodic progress print; i == 1 guarantees both lists are
            # already non-empty.
            if i % 20 == 1:
                print("current step: ", i)
                current = time.strftime("%H:%M:%S")
                print("current time :", current)
                print("last loss G_A:", loss_G_list[-1], "last loss D_B:",
                      loss_D_B_list[-1])
            real_A = batch['A'].to(device)
            real_B = batch['B'].to(device)

            # ---- Train the generator ----
            utils.set_requires_grad([netG_A], True)
            optimizer_G.zero_grad()
            # Compute fake images
            fake_B = netG_A(real_A)
            # discriminators require no gradients when optimizing generators
            utils.set_requires_grad([netD_B], False)
            # GAN loss: discriminator scores the (fake, condition) pair
            prediction_fake_B = netD_B(torch.cat((fake_B, real_A), dim=1))
            loss_gan = criterion_GAN(prediction_fake_B, True)
            # L1 loss, weighted by opt.l1_loss
            loss_l1 = criterion_l1(real_B, fake_B) * opt.l1_loss
            # total generator loss
            loss_G = loss_gan + loss_l1
            loss_G_list.append(loss_G.item())
            loss_G.backward()
            optimizer_G.step()

            # ---- Train the discriminator ----
            utils.set_requires_grad([netG_A], False)
            utils.set_requires_grad([netD_B], True)
            # Train the discriminator D_B
            optimizer_D_B.zero_grad()
            # real images
            pred_real = netD_B(torch.cat((real_B, real_A), dim=1))
            loss_D_real = criterion_GAN(pred_real, True)
            # fake images — recomputed here with the just-updated generator
            # (netG_A has requires_grad False, so no generator gradients flow)
            fake_B = netG_A(real_A)
            pred_fake = netD_B(torch.cat((fake_B, real_A), dim=1))
            loss_D_fake = criterion_GAN(pred_fake, False)
            # total loss
            loss_D_B = (loss_D_real + loss_D_fake) * 0.5
            loss_D_B_list.append(loss_D_B.item())
            loss_D_B.backward()
            optimizer_D_B.step()

        # Update the learning rate
        lr_scheduler_G.step()
        lr_scheduler_D_B.step()

        # Save models checkpoints
        torch.save(netG_A.state_dict(), 'model/netG_A_pix.pth')
        torch.save(netD_B.state_dict(), 'model/netD_B_pix.pth')
        # Save other checkpoint information
        checkpoint = {'epoch': epoch,
                      'optimizer_G': optimizer_G.state_dict(),
                      'optimizer_D_B': optimizer_D_B.state_dict(),
                      'lr_scheduler_G': lr_scheduler_G.state_dict(),
                      'lr_scheduler_D_B': lr_scheduler_D_B.state_dict()}
        torch.save(checkpoint, 'model/checkpoint.pth')

        # Update the numpy arrays that record the loss
        loss_G_array[epoch] = sum(loss_G_list) / len(loss_G_list)
        loss_D_B_array[epoch] = sum(loss_D_B_list) / len(loss_D_B_list)
        np.savetxt('model/loss_G.txt', loss_G_array)
        np.savetxt('model/loss_D_B.txt', loss_D_B_array)

        end = time.strftime("%H:%M:%S")
        print("current epoch :", epoch, " end time :", end)
        print("G loss :", loss_G_array[epoch], "D_B loss :",
              loss_D_B_array[epoch])
def main(args):
    """ADDA adaptation of an ensemble of 10 source models to one target person.

    Each of the 10 source-trained networks gets its own target feature
    extractor, domain discriminator and optimizers.  Every epoch each member
    is adapted adversarially, evaluated individually, and the ensemble voting
    accuracy is tracked; per-epoch voting accuracies are dumped to JSON.
    """
    final_accs = []
    # Frozen source models; the same instances double as the full
    # classifiers ("clfs") whose feature extractors get swapped later.
    source_models = [Net().to(device) for _ in range(10)]
    for idx in range(len(source_models)):
        source_models[idx].load_state_dict(torch.load(args.MODEL_FILE))
        source_models[idx].eval()
        set_requires_grad(source_models[idx], requires_grad=False)
    clfs = [source_model for source_model in source_models]
    source_models = [
        source_model.feature_extractor for source_model in source_models
    ]

    # Trainable target copies, initialised from the same checkpoint.
    target_models = [Net().to(device) for _ in range(10)]
    for idx in range(len(target_models)):
        target_models[idx].load_state_dict(torch.load(args.MODEL_FILE))
        target_models[idx] = target_models[idx].feature_extractor

    # One domain discriminator per ensemble member.  BUG FIX: the original
    # ended each with nn.Sigmoid() while the loss is BCEWithLogitsLoss,
    # which applies sigmoid internally — a double sigmoid.  The nets now
    # emit raw logits; accuracy is thresholded at logit > 0.
    discriminators = [
        nn.Sequential(nn.Linear(EXTRACTED_FEATURE_DIM, 64), nn.ReLU(),
                      nn.BatchNorm1d(64), nn.Linear(64, 1)).to(device)
        for _ in range(10)
    ]

    batch_size = args.batch_size
    discriminator_optims = [
        torch.optim.Adam(discriminators[idx].parameters(), lr=1e-5)
        for idx in range(10)
    ]
    target_optims = [
        torch.optim.Adam(target_models[idx].parameters(), lr=1e-5)
        for idx in range(10)
    ]
    criterion = nn.BCEWithLogitsLoss()

    # Per-member data loaders: member idx trains against source subject idx;
    # all members share the same target person's data.
    source_loaders = []
    target_loaders = []
    for idx in range(10):
        X_source, y_source = preprocess_train_single(idx)
        source_dataset = torch.utils.data.TensorDataset(X_source, y_source)
        source_loader = DataLoader(source_dataset,
                                   batch_size=batch_size,
                                   shuffle=False,
                                   num_workers=1,
                                   pin_memory=True)
        source_loaders.append(source_loader)
        X_target, y_target = preprocess_test(args.person)
        target_dataset = torch.utils.data.TensorDataset(X_target, y_target)
        target_loader = DataLoader(target_dataset,
                                   batch_size=batch_size,
                                   shuffle=False,
                                   num_workers=1,
                                   pin_memory=True)
        target_loaders.append(target_loader)

    # Baseline ensemble (voting) accuracy before adaptation.
    best_voting_acc = test_all(clfs)
    best_tar_accs = [0.0] * 10
    for epoch in range(1, args.epochs + 1):
        # Re-wrap the datasets so each epoch reshuffles both domains.
        source_loaders = [
            DataLoader(source_loaders[idx].dataset,
                       batch_size=batch_size,
                       shuffle=True) for idx in range(10)
        ]
        target_loaders = [
            DataLoader(target_loaders[idx].dataset,
                       batch_size=batch_size,
                       shuffle=True) for idx in range(10)
        ]
        for idx in range(10):
            source_loader = source_loaders[idx]
            target_loader = target_loaders[idx]
            batch_iterator = zip(loop_iterable(source_loader),
                                 loop_iterable(target_loader))
            target_model = target_models[idx]
            discriminator = discriminators[idx]
            source_model = source_models[idx]
            clf = clfs[idx]

            total_loss = 0
            adv_loss = 0
            total_accuracy = 0
            second_acc = 0
            for _ in trange(args.iterations, leave=False):
                # ---- Train discriminator: target extractor frozen ----
                set_requires_grad(target_model, requires_grad=False)
                set_requires_grad(discriminator, requires_grad=True)
                discriminator.train()
                for _ in range(args.k_disc):
                    (source_x, _), (target_x, _) = next(batch_iterator)
                    source_x, target_x = source_x.to(device), target_x.to(
                        device)

                    source_features = source_model(source_x).view(
                        source_x.shape[0], -1)
                    target_features = target_model(target_x).view(
                        target_x.shape[0], -1)

                    # Source features labelled 1, target features labelled 0.
                    discriminator_x = torch.cat(
                        [source_features, target_features])
                    discriminator_y = torch.cat([
                        torch.ones(source_x.shape[0], device=device),
                        torch.zeros(target_x.shape[0], device=device)
                    ])

                    preds = discriminator(discriminator_x).squeeze()
                    loss = criterion(preds, discriminator_y)

                    discriminator_optims[idx].zero_grad()
                    loss.backward()
                    discriminator_optims[idx].step()

                    total_loss += loss.item()
                    # Logit > 0 corresponds to probability > 0.5.
                    total_accuracy += ((preds > 0).long(
                    ) == discriminator_y.long()).float().mean().item()

                # ---- Train target extractor: discriminator frozen ----
                set_requires_grad(target_model, requires_grad=True)
                set_requires_grad(discriminator, requires_grad=False)
                target_model.train()
                for _ in range(args.k_clf):
                    _, (target_x, _) = next(batch_iterator)
                    target_x = target_x.to(device)
                    target_features = target_model(target_x).view(
                        target_x.shape[0], -1)

                    # flipped labels: push target toward "looks source"
                    discriminator_y = torch.ones(target_x.shape[0],
                                                 device=device)

                    preds = discriminator(target_features).squeeze()
                    second_acc += ((preds > 0).long() == discriminator_y.
                                   long()).float().mean().item()
                    loss = criterion(preds, discriminator_y)
                    adv_loss += loss.item()
                    target_optims[idx].zero_grad()
                    loss.backward()
                    target_optims[idx].step()

            mean_loss = total_loss / (args.iterations * args.k_disc)
            mean_adv_loss = adv_loss / (args.iterations * args.k_clf)
            dis_accuracy = total_accuracy / (args.iterations * args.k_disc)
            sec_acc = second_acc / (args.iterations * args.k_clf)

            # Evaluate this member with its adapted extractor; keep the best.
            clf.feature_extractor = target_model
            tar_accuarcy = test(args, clf)
            if tar_accuarcy > best_tar_accs[idx]:
                best_tar_accs[idx] = tar_accuarcy
                torch.save(clf.state_dict(),
                           'trained_models/adda' + str(idx) + '.pt')
            tqdm.write(
                f'EPOCH {epoch:03d}: discriminator_loss={mean_loss:.4f}, adv_loss = {mean_adv_loss:.4f}, '
                f'discriminator_accuracy={dis_accuracy:.4f}, tar_accuary = {tar_accuarcy:.4f}, best_accuracy = {best_tar_accs[idx]:.4f}, sec_acc = {sec_acc:.4f}'
            )
            # Create the full target model and save it
            clf.feature_extractor = target_model
            #torch.save(clf.state_dict(), 'trained_models/adda.pt')

        # Ensemble voting accuracy after this epoch's round of adaptation.
        acc = test_all(clfs)
        final_accs.append(acc)
        if acc > best_voting_acc:
            best_voting_acc = acc
        print("In epoch %d, voting_acc: %.4f, best_voting_acc: %.4f" %
              (epoch, acc, best_voting_acc))

    # Persist per-epoch ensemble accuracies.
    jd = {"test_acc": final_accs}
    with open(str(args.seed) + '/acc' + str(args.person) + '.json', 'w') as f:
        json.dump(jd, f)
def train_models(self): self._load_checkpoint() # loop over the dataset multiple times for self.epoch_id in range(self.epoch_to_start, self.max_num_epochs): ################## train ################# ########################################## self._clear_cache() self.is_training = True self.net_G.train() # Set model to training mode self.net_D1.train() # Set model to training mode self.net_D2.train() # Set model to training mode self.net_D3.train() # Set model to training mode # Iterate over data. for self.batch_id, batch in enumerate(self.dataloaders['train'], 0): self._forward_pass(batch) # update D1 and D2 and D3 utils.set_requires_grad(self.net_D1, True) utils.set_requires_grad(self.net_D2, True) utils.set_requires_grad(self.net_D3, True) self.optimizer_D1.zero_grad() self.optimizer_D2.zero_grad() self.optimizer_D3.zero_grad() self._backward_D() self.optimizer_D1.step() self.optimizer_D2.step() self.optimizer_D3.step() # update G utils.set_requires_grad(self.net_D1, False) utils.set_requires_grad(self.net_D2, False) utils.set_requires_grad(self.net_D3, False) self.optimizer_G.zero_grad() self._backward_G() self.optimizer_G.step() self._collect_running_batch_states() self._collect_epoch_states() self._update_lr_schedulers() ################## Eval ################## ########################################## print('Begin evaluation...') self._clear_cache() self.is_training = False # Set model to evaluate mode self.net_G.eval() self.net_D1.eval() self.net_D2.eval() self.net_D3.eval() # Iterate over data. for self.batch_id, batch in enumerate(self.dataloaders['val'], 0): with torch.no_grad(): self._forward_pass(batch) self._collect_running_batch_states() self._collect_epoch_states() ########### Update_Checkpoints ########### ########################################## self._update_checkpoints()
def _create_target_network(self):
    """Build the frozen target network used for bootstrapped value estimates.

    Chooses a feed-forward body when ``self._body_type == 'ff'`` and a
    convolutional body otherwise, wraps it in a ``Head`` that maps to the
    action outputs, and disables gradients (target nets are updated by
    copying, not by backprop).
    """
    body_cls = ffBody if self._body_type == 'ff' else convBody
    body = body_cls(self._obs_num_features_or_obs_in_channels,
                    self._fc_hidden_layer_size)
    self._target_net = Head(body, self._output_actions)
    # Target parameters are never trained directly.
    set_requires_grad(self._target_net, False)
opt.lr = opt.lr * 0.1 for param_group in g_optimizer.param_groups: param_group['lr'] = opt.lr for param_group in d_optimizer.param_groups: param_group['lr'] = opt.lr for i, (X, Y) in enumerate(dataloader): X, Y = X.to(device), Y.to(device) # --------train Generator-------# X_fake = F(Y) X_rec = G(X_fake) Y_fake = G(X) Y_rec = F(Y_fake) set_requires_grad([D_X, D_Y], False) g_optimizer.zero_grad() G_ad_loss = torch.mean((D_X(X_fake) - 1)**2) F_ad_loss = torch.mean((D_Y(Y_fake) - 1)**2) G_cyc_loss = torch.mean((X.detach() - X_rec)**2) F_cyc_loss = torch.mean((Y.detach() - Y_rec)**2) if opt.lambda_idt > 0: G_idt_loss = torch.mean(torch.abs(G(X) - X)) * opt.lambda_idt F_idt_loss = torch.mean(torch.abs(F(Y) - Y)) * opt.lambda_idt else: G_idt_loss = 0 F_idt_loss = 0 gen_loss = G_ad_loss + F_ad_loss + opt.lambda_ * ( G_cyc_loss + F_cyc_loss + G_idt_loss + F_idt_loss)
elif mode == "c": training_models = [c_model] elif mode == "c-a": training_models = [c_model, actor] elif mode == "a": training_models = [actor] elif mode == "e": training_models = [] else: raise NotImplementedError ### Define Parameters ### training_params = [] for m in all_models: if m not in training_models: set_requires_grad(m, False) m.eval() else: training_params += list(m.parameters()) solver = None if len(training_params) > 0: solver = optim.Adam(training_params, lr=1e-3) ### Visual Planning & Acting ### if mode == "e": from eval import execute execute(all_models, kwargs["env"], savepath, eval_hp=kwargs["eval_hp"]) else: train(all_models, training_models, solver, training_params, log_every, **kwargs)
def train():
    """Train the U-GAT-IT selfie<->anime translation model.

    Sets up two generators (G_A2B, G_B2A), two global discriminators
    (D_A, D_B, 7 layers) and two local discriminators (L_A, L_B, 5 layers),
    then alternates discriminator and generator updates with LSGAN
    adversarial losses, cycle/identity L1 losses and CAM BCE losses.
    Saves samples, weights, a loss plot and a training GIF.
    """
    # Fix Seed for Reproducibility #
    torch.manual_seed(9)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(9)

    # Samples, Weights, and Plots Path #
    paths = [config.samples_path, config.weights_path, config.plots_path]
    paths = [make_dirs(path) for path in paths]

    # Prepare Data Loader #
    train_loader_selfie, train_loader_anime = get_selfie2anime_loader(
        'train', config.batch_size)
    # NOTE(review): the epoch iterates zip(...), which stops at the SHORTER
    # loader, while total_batch reports the longer one — progress prints may
    # overstate the per-epoch iteration count. Confirm intended.
    total_batch = max(len(train_loader_selfie), len(train_loader_anime))

    test_loader_selfie, test_loader_anime = get_selfie2anime_loader(
        'test', config.val_batch_size)

    # Prepare Networks #
    D_A = Discriminator(num_layers=7)   # global discriminator, domain A
    D_B = Discriminator(num_layers=7)   # global discriminator, domain B
    L_A = Discriminator(num_layers=5)   # local (shallower) discriminator, domain A
    L_B = Discriminator(num_layers=5)   # local (shallower) discriminator, domain B

    G_A2B = Generator(image_size=config.crop_size,
                      num_blocks=config.num_blocks)
    G_B2A = Generator(image_size=config.crop_size,
                      num_blocks=config.num_blocks)

    networks = [D_A, D_B, L_A, L_B, G_A2B, G_B2A]
    for network in networks:
        network.to(device)

    # Loss Function #
    Adversarial_loss = nn.MSELoss()      # LSGAN objective
    Cycle_loss = nn.L1Loss()             # cycle & identity reconstruction
    BCE_loss = nn.BCEWithLogitsLoss()    # CAM auxiliary classifier loss

    # Optimizers #
    D_optim = torch.optim.Adam(chain(D_A.parameters(), D_B.parameters(),
                                     L_A.parameters(), L_B.parameters()),
                               lr=config.lr,
                               betas=(0.5, 0.999),
                               weight_decay=0.0001)
    G_optim = torch.optim.Adam(chain(G_A2B.parameters(), G_B2A.parameters()),
                               lr=config.lr,
                               betas=(0.5, 0.999),
                               weight_decay=0.0001)

    D_optim_scheduler = get_lr_scheduler(D_optim)
    G_optim_scheduler = get_lr_scheduler(G_optim)

    # Rho Clipper to constraint the value of rho in AdaILN and ILN #
    Rho_Clipper = RhoClipper(0, 1)

    # Running loss histories (grow across the whole run; the printed averages
    # are cumulative over all iterations so far, not per-epoch). #
    D_losses = []
    G_losses = []

    # Train #
    print("Training U-GAT-IT started with total epoch of {}.".format(
        config.num_epochs))

    for epoch in range(config.num_epochs):
        for i, (selfie, anime) in enumerate(
                zip(train_loader_selfie, train_loader_anime)):

            # Data Preparation #
            real_A = selfie.to(device)
            real_B = anime.to(device)

            # Initialize Optimizers #
            D_optim.zero_grad()
            G_optim.zero_grad()

            #######################
            # Train Discriminator #
            #######################
            set_requires_grad([D_A, D_B, L_A, L_B], requires_grad=True)

            # Forward Data (translations only; CAM logits unused here) #
            fake_B, _, _ = G_A2B(real_A)
            fake_A, _, _ = G_B2A(real_B)

            G_real_A, G_real_A_cam, _ = D_A(real_A)
            L_real_A, L_real_A_cam, _ = L_A(real_A)
            G_real_B, G_real_B_cam, _ = D_B(real_B)
            L_real_B, L_real_B_cam, _ = L_B(real_B)

            G_fake_A, G_fake_A_cam, _ = D_A(fake_A)
            L_fake_A, L_fake_A_cam, _ = L_A(fake_A)
            G_fake_B, G_fake_B_cam, _ = D_B(fake_B)
            L_fake_B, L_fake_B_cam, _ = L_B(fake_B)
            # NOTE(review): fake_A / fake_B are not detached before being fed
            # to the discriminators; generator grads are discarded by
            # G_optim.zero_grad() below, but this wastes backward computation
            # through G. Confirm intended.

            # Adversarial Loss of global Discriminators (LSGAN: real->1, fake->0) #
            real_labels = torch.ones(G_real_A.shape).to(device)
            D_ad_real_loss_GA = Adversarial_loss(G_real_A, real_labels)
            fake_labels = torch.zeros(G_fake_A.shape).to(device)
            D_ad_fake_loss_GA = Adversarial_loss(G_fake_A, fake_labels)
            D_ad_loss_GA = D_ad_real_loss_GA + D_ad_fake_loss_GA

            real_labels = torch.ones(G_real_A_cam.shape).to(device)
            D_ad_cam_real_loss_GA = Adversarial_loss(G_real_A_cam, real_labels)
            fake_labels = torch.zeros(G_fake_A_cam.shape).to(device)
            D_ad_cam_fake_loss_GA = Adversarial_loss(G_fake_A_cam, fake_labels)
            D_ad_cam_loss_GA = D_ad_cam_real_loss_GA + D_ad_cam_fake_loss_GA

            real_labels = torch.ones(G_real_B.shape).to(device)
            D_ad_real_loss_GB = Adversarial_loss(G_real_B, real_labels)
            fake_labels = torch.zeros(G_fake_B.shape).to(device)
            D_ad_fake_loss_GB = Adversarial_loss(G_fake_B, fake_labels)
            D_ad_loss_GB = D_ad_real_loss_GB + D_ad_fake_loss_GB

            real_labels = torch.ones(G_real_B_cam.shape).to(device)
            D_ad_cam_real_loss_GB = Adversarial_loss(G_real_B_cam, real_labels)
            fake_labels = torch.zeros(G_fake_B_cam.shape).to(device)
            D_ad_cam_fake_loss_GB = Adversarial_loss(G_fake_B_cam, fake_labels)
            D_ad_cam_loss_GB = D_ad_cam_real_loss_GB + D_ad_cam_fake_loss_GB

            # Adversarial Loss of local discriminators L (same scheme) #
            real_labels = torch.ones(L_real_A.shape).to(device)
            D_ad_real_loss_LA = Adversarial_loss(L_real_A, real_labels)
            fake_labels = torch.zeros(L_fake_A.shape).to(device)
            D_ad_fake_loss_LA = Adversarial_loss(L_fake_A, fake_labels)
            D_ad_loss_LA = D_ad_real_loss_LA + D_ad_fake_loss_LA

            real_labels = torch.ones(L_real_A_cam.shape).to(device)
            D_ad_cam_real_loss_LA = Adversarial_loss(L_real_A_cam, real_labels)
            fake_labels = torch.zeros(L_fake_A_cam.shape).to(device)
            D_ad_cam_fake_loss_LA = Adversarial_loss(L_fake_A_cam, fake_labels)
            D_ad_cam_loss_LA = D_ad_cam_real_loss_LA + D_ad_cam_fake_loss_LA

            real_labels = torch.ones(L_real_B.shape).to(device)
            D_ad_real_loss_LB = Adversarial_loss(L_real_B, real_labels)
            fake_labels = torch.zeros(L_fake_B.shape).to(device)
            D_ad_fake_loss_LB = Adversarial_loss(L_fake_B, fake_labels)
            D_ad_loss_LB = D_ad_real_loss_LB + D_ad_fake_loss_LB

            real_labels = torch.ones(L_real_B_cam.shape).to(device)
            D_ad_cam_real_loss_LB = Adversarial_loss(L_real_B_cam, real_labels)
            fake_labels = torch.zeros(L_fake_B_cam.shape).to(device)
            D_ad_cam_fake_loss_LB = Adversarial_loss(L_fake_B_cam, fake_labels)
            D_ad_cam_loss_LB = D_ad_cam_real_loss_LB + D_ad_cam_fake_loss_LB

            # Calculate Each Discriminator Loss #
            D_loss_A = D_ad_loss_GA + D_ad_cam_loss_GA + D_ad_loss_LA + D_ad_cam_loss_LA
            D_loss_B = D_ad_loss_GB + D_ad_cam_loss_GB + D_ad_loss_LB + D_ad_cam_loss_LB

            # Calculate Total Discriminator Loss #
            D_loss = D_loss_A + D_loss_B

            # Back Propagation and Update #
            D_loss.backward()
            D_optim.step()

            ###################
            # Train Generator #
            ###################
            # Freeze all discriminators so only the generators get gradients.
            set_requires_grad([D_A, D_B, L_A, L_B], requires_grad=False)

            # Forward Data (fresh forward pass; D step freed the old graph) #
            fake_B, fake_B_cam, _ = G_A2B(real_A)
            fake_A, fake_A_cam, _ = G_B2A(real_B)

            # Cycle reconstructions A->B->A and B->A->B.
            fake_ABA, _, _ = G_B2A(fake_B)
            fake_BAB, _, _ = G_A2B(fake_A)

            # Identity mappings.
            # NOTE(review): reference U-GAT-IT computes the identity terms as
            # G_B2A(real_A) and G_A2B(real_B) (generator fed an image from its
            # *output* domain). Here the generators appear swapped
            # (G_A2B(real_A), G_B2A(real_B)); confirm against the intended
            # formulation before relying on the identity/CAM losses.
            fake_A2A, fake_A2A_cam, _ = G_A2B(real_A)
            fake_B2B, fake_B2B_cam, _ = G_B2A(real_B)

            G_fake_A, G_fake_A_cam, _ = D_A(fake_A)
            L_fake_A, L_fake_A_cam, _ = L_A(fake_A)
            G_fake_B, G_fake_B_cam, _ = D_B(fake_B)
            L_fake_B, L_fake_B_cam, _ = L_B(fake_B)

            # Adversarial Loss of Generator (fool D: fakes labelled 1) #
            real_labels = torch.ones(G_fake_A.shape).to(device)
            G_adv_fake_loss_A = Adversarial_loss(G_fake_A, real_labels)
            real_labels = torch.ones(G_fake_A_cam.shape).to(device)
            G_adv_cam_fake_loss_A = Adversarial_loss(G_fake_A_cam, real_labels)
            G_adv_loss_A = G_adv_fake_loss_A + G_adv_cam_fake_loss_A

            real_labels = torch.ones(G_fake_B.shape).to(device)
            G_adv_fake_loss_B = Adversarial_loss(G_fake_B, real_labels)
            real_labels = torch.ones(G_fake_B_cam.shape).to(device)
            G_adv_cam_fake_loss_B = Adversarial_loss(G_fake_B_cam, real_labels)
            G_adv_loss_B = G_adv_fake_loss_B + G_adv_cam_fake_loss_B

            # Adversarial Loss of L (local discriminators) #
            real_labels = torch.ones(L_fake_A.shape).to(device)
            L_adv_fake_loss_A = Adversarial_loss(L_fake_A, real_labels)
            real_labels = torch.ones(L_fake_A_cam.shape).to(device)
            L_adv_cam_fake_loss_A = Adversarial_loss(L_fake_A_cam, real_labels)
            L_adv_loss_A = L_adv_fake_loss_A + L_adv_cam_fake_loss_A

            real_labels = torch.ones(L_fake_B.shape).to(device)
            L_adv_fake_loss_B = Adversarial_loss(L_fake_B, real_labels)
            real_labels = torch.ones(L_fake_B_cam.shape).to(device)
            L_adv_cam_fake_loss_B = Adversarial_loss(L_fake_B_cam, real_labels)
            L_adv_loss_B = L_adv_fake_loss_B + L_adv_cam_fake_loss_B

            # Cycle Consistency Loss (plus identity term folded in below) #
            G_recon_loss_A = Cycle_loss(fake_ABA, real_A)
            G_recon_loss_B = Cycle_loss(fake_BAB, real_B)

            G_identity_loss_A = Cycle_loss(fake_A2A, real_A)
            G_identity_loss_B = Cycle_loss(fake_B2B, real_B)

            G_cycle_loss_A = G_recon_loss_A + G_identity_loss_A
            G_cycle_loss_B = G_recon_loss_B + G_identity_loss_B

            # CAM Loss: translation CAM logits pushed to 1, identity CAM
            # logits pushed to 0. #
            real_labels = torch.ones(fake_A_cam.shape).to(device)
            G_cam_real_loss_A = BCE_loss(fake_A_cam, real_labels)
            fake_labels = torch.zeros(fake_A2A_cam.shape).to(device)
            G_cam_fake_loss_A = BCE_loss(fake_A2A_cam, fake_labels)
            G_cam_loss_A = G_cam_real_loss_A + G_cam_fake_loss_A

            real_labels = torch.ones(fake_B_cam.shape).to(device)
            G_cam_real_loss_B = BCE_loss(fake_B_cam, real_labels)
            fake_labels = torch.zeros(fake_B2B_cam.shape).to(device)
            G_cam_fake_loss_B = BCE_loss(fake_B2B_cam, fake_labels)
            G_cam_loss_B = G_cam_real_loss_B + G_cam_fake_loss_B

            # Calculate Each Generator Loss #
            G_loss_A = G_adv_loss_A + L_adv_loss_A + config.lambda_cycle * G_cycle_loss_A + config.lambda_cam * G_cam_loss_A
            G_loss_B = G_adv_loss_B + L_adv_loss_B + config.lambda_cycle * G_cycle_loss_B + config.lambda_cam * G_cam_loss_B

            # Calculate Total Generator Loss #
            G_loss = G_loss_A + G_loss_B

            # Back Propagation and Update #
            G_loss.backward()
            G_optim.step()

            # Apply Rho Clipper to keep AdaILN/ILN rho in [0, 1] #
            G_A2B.apply(Rho_Clipper)
            G_B2A.apply(Rho_Clipper)

            # Add items to Lists #
            D_losses.append(D_loss.item())
            G_losses.append(G_loss.item())

            ####################
            # Print Statistics #
            ####################
            if (i + 1) % config.print_every == 0:
                print(
                    "U-GAT-IT | Epochs [{}/{}] | Iterations [{}/{}] | D Loss {:.4f} | G Loss {:.4f}"
                    .format(epoch + 1, config.num_epochs, i + 1, total_batch,
                            np.average(D_losses), np.average(G_losses)))

                # Save Sample Images #
                save_samples(test_loader_selfie, G_A2B, epoch,
                             config.samples_path)

        # Adjust Learning Rate (once per epoch) #
        D_optim_scheduler.step()
        G_optim_scheduler.step()

        # Save Model Weights #
        if (epoch + 1) % config.save_every == 0:
            torch.save(
                D_A.state_dict(),
                os.path.join(config.weights_path,
                             'U-GAT-IT_D_A_Epoch_{}.pkl'.format(epoch + 1)))
            torch.save(
                D_B.state_dict(),
                os.path.join(config.weights_path,
                             'U-GAT-IT_D_B_Epoch_{}.pkl'.format(epoch + 1)))
            torch.save(
                L_A.state_dict(),
                os.path.join(config.weights_path,
                             'U-GAT-IT_L_A_Epoch_{}.pkl'.format(epoch + 1)))
            torch.save(
                L_B.state_dict(),
                os.path.join(config.weights_path,
                             'U-GAT-IT_L_B_Epoch_{}.pkl'.format(epoch + 1)))
            torch.save(
                G_A2B.state_dict(),
                os.path.join(config.weights_path,
                             'U-GAT-IT_G_A2B_Epoch_{}.pkl'.format(epoch + 1)))
            torch.save(
                G_B2A.state_dict(),
                os.path.join(config.weights_path,
                             'U-GAT-IT_G_B2A_Epoch_{}.pkl'.format(epoch + 1)))

    # Plot Losses #
    plot_losses(D_losses, G_losses, config.num_epochs, config.plots_path)

    # Make a GIF file #
    make_gifs_train('U-GAT-IT', config.samples_path)

    print("Training finished.")
def train_srgans(train_loader, val_loader, generator, discriminator, device, args):
    """Adversarially train a super-resolution GAN.

    Supports two objectives selected by ``args.model``:
    - 'srgan':  hinge-like adversarial term + MSE + perceptual + TV losses.
    - 'esrgan': relativistic-average BCE adversarial term + perceptual
                + L1 content losses.

    Alternates a discriminator step and a generator step per batch; saves
    sample grids during training and the generator weights (plus an
    inference pass) every ``args.save_every`` epochs.
    """
    # Loss Function #
    criterion_Perceptual = PerceptualLoss(args.model).to(device)

    # For SRGAN #
    criterion_MSE = nn.MSELoss()
    criterion_TV = TVLoss()

    # For ESRGAN #
    criterion_BCE = nn.BCEWithLogitsLoss()
    criterion_Content = nn.L1Loss()

    # Optimizers #
    D_optim = torch.optim.Adam(discriminator.parameters(),
                               lr=args.lr,
                               betas=(0.9, 0.999))
    G_optim = torch.optim.Adam(generator.parameters(),
                               lr=args.lr,
                               betas=(0.9, 0.999))

    D_optim_scheduler = get_lr_scheduler(D_optim, args)
    G_optim_scheduler = get_lr_scheduler(G_optim, args)

    # Running loss histories; printed averages are cumulative over the run. #
    D_losses, G_losses = list(), list()

    # Train #
    print("Training {} started with total epoch of {}.".format(
        str(args.model).upper(), args.num_epochs))

    for epoch in range(args.num_epochs):
        for i, (high, low) in enumerate(train_loader):
            discriminator.train()
            # NOTE(review): generator is switched to train mode only for
            # 'srgan'; in the 'esrgan' branch it stays in whatever mode the
            # caller left it (likely eval). Confirm this is intentional.
            if args.model == "srgan":
                generator.train()

            # Data Preparation #
            high = high.to(device)   # ground-truth high-resolution batch
            low = low.to(device)     # low-resolution input batch

            # Initialize Optimizers #
            D_optim.zero_grad()
            G_optim.zero_grad()

            #######################
            # Train Discriminator #
            #######################
            set_requires_grad(discriminator, requires_grad=True)

            # Generate Fake HR Images (reused, detached, for the D step) #
            fake_high = generator(low)

            if args.model == 'srgan':
                # Forward Data #
                prob_real = discriminator(high)
                prob_fake = discriminator(fake_high.detach())

                # Calculate Total Discriminator Loss #
                D_loss = 1 - prob_real.mean() + prob_fake.mean()

            elif args.model == 'esrgan':
                # Forward Data #
                prob_real = discriminator(high)
                prob_fake = discriminator(fake_high.detach())

                # Relativistic Discriminator: each side compared against the
                # mean score of the other side. #
                diff_r2f = prob_real - prob_fake.mean()
                diff_f2r = prob_fake - prob_real.mean()

                # Labels #
                real_labels = torch.ones(diff_r2f.size()).to(device)
                fake_labels = torch.zeros(diff_f2r.size()).to(device)

                # Adversarial Loss #
                D_loss_real = criterion_BCE(diff_r2f, real_labels)
                D_loss_fake = criterion_BCE(diff_f2r, fake_labels)

                # Calculate Total Discriminator Loss #
                D_loss = (D_loss_real + D_loss_fake).mean()

            # Back Propagation and Update #
            D_loss.backward()
            D_optim.step()

            ###################
            # Train Generator #
            ###################
            # Freeze D so only the generator receives gradients.
            set_requires_grad(discriminator, requires_grad=False)

            if args.model == 'srgan':
                # Adversarial Loss (push D's score on fakes toward 1) #
                prob_fake = discriminator(fake_high).mean()
                G_loss_adversarial = torch.mean(1 - prob_fake)

                G_loss_mse = criterion_MSE(fake_high, high)

                # Perceptual Loss (fixed weight from the SRGAN paper) #
                lambda_perceptual = 6e-3
                G_loss_perceptual = criterion_Perceptual(fake_high, high)

                # Total Variation Loss #
                G_loss_tv = criterion_TV(fake_high)

                # Calculate Total Generator Loss #
                G_loss = args.lambda_adversarial * G_loss_adversarial + G_loss_mse + lambda_perceptual * G_loss_perceptual + args.lambda_tv * G_loss_tv

            elif args.model == 'esrgan':
                # Forward Data (fresh D pass, no detach, for G gradients) #
                prob_real = discriminator(high)
                prob_fake = discriminator(fake_high)

                # Relativistic Discriminator #
                diff_r2f = prob_real - prob_fake.mean()
                diff_f2r = prob_fake - prob_real.mean()

                # Labels #
                real_labels = torch.ones(diff_r2f.size()).to(device)
                fake_labels = torch.zeros(diff_f2r.size()).to(device)

                # Adversarial Loss (roles reversed relative to the D step) #
                G_loss_bce_real = criterion_BCE(diff_f2r, real_labels)
                G_loss_bce_fake = criterion_BCE(diff_r2f, fake_labels)
                G_loss_bce = (G_loss_bce_real + G_loss_bce_fake).mean()

                # Perceptual Loss (fixed weight) #
                lambda_perceptual = 1e-2
                G_loss_perceptual = criterion_Perceptual(fake_high, high)

                # Content Loss #
                G_loss_content = criterion_Content(fake_high, high)

                # Calculate Total Generator Loss #
                G_loss = args.lambda_bce * G_loss_bce + lambda_perceptual * G_loss_perceptual + args.lambda_content * G_loss_content

            # Back Propagation and Update #
            G_loss.backward()
            G_optim.step()

            # Add items to Lists #
            D_losses.append(D_loss.item())
            G_losses.append(G_loss.item())

            ####################
            # Print Statistics #
            ####################
            if (i + 1) % args.print_every == 0:
                print(
                    "{} | Epoch [{}/{}] | Iterations [{}/{}] | D Loss {:.4f} | G Loss {:.4f}"
                    .format(
                        str(args.model).upper(), epoch + 1, args.num_epochs,
                        i + 1, len(train_loader), np.average(D_losses),
                        np.average(G_losses)))

                # Save Sample Images #
                sample_images(val_loader, args.batch_size, args.scale_factor,
                              generator, epoch, args.samples_path, device)

        # Adjust Learning Rate (once per epoch) #
        D_optim_scheduler.step()
        G_optim_scheduler.step()

        # Save Model Weights and Inference #
        if (epoch + 1) % args.save_every == 0:
            torch.save(
                generator.state_dict(),
                os.path.join(
                    args.weights_path,
                    '{}_Epoch_{}.pkl'.format(generator.__class__.__name__,
                                             epoch + 1)))
            # NOTE(review): sampling uses args.scale_factor but inference uses
            # args.upscale_factor — verify both attributes exist and agree.
            inference(val_loader, generator, args.upscale_factor, epoch,
                      args.inference_path, device)
def train(self, verbose=True):
    """Train the StyleGAN model, resuming from the newest checkpoint in
    ``saves7/`` if any exist.

    Per batch: one discriminator step (label-smoothed BCE on real=0.9 /
    fake=0.1 plus a gradient penalty) followed by one generator step
    (BCE toward 1.0, with path-length regularization every 10th batch).
    Uses NVIDIA apex AMP when available, clips both networks' gradients,
    and logs losses/images to TensorBoard every 100 batches, saving a
    checkpoint each time.
    """
    torch.cuda.empty_cache()

    # load last iteration if training was started but not finished
    if len(listdir("saves7")) > 0:
        # the [4:-3] slicing extracts the checkpoint number from the
        # save-file name format ("saveNNN.pt"-style names — confirm).
        self.loadModel(
            sorted(listdir('saves7'), key=lambda x: int(x[4:-3]))[-1][4:-3])
        self.checkpoint = int(
            sorted(listdir('saves7'), key=lambda x: int(x[4:-3]))[-1][4:-3])
        print("Loading from checkpoint: ", self.checkpoint)
        self.checkpoint = self.checkpoint + 1
        print("New checkpoint starts at: ", self.checkpoint)
    else:
        self.StyleGan.init_weights()
        self.checkpoint = 0

    # training loop
    for epoch in range(0, self.epochs):
        for batch_num, batch in enumerate(self.dataLoader):
            # Expand single-channel images to 3 channels and move to device.
            batch = batch[0].expand(-1, 3, -1, -1).to(self.device)
            # Real images need grads for the gradient penalty below.
            batch.requires_grad = True
            # NOTE(review): hard-coded batch size of 128 — partial final
            # batches are skipped. Consider deriving from self.batch_size.
            if batch.shape[0] != 128:
                print("SKIPPING")
                continue
            w_space = []

            # Train Discriminator: maximize log(D(x)) + log(1 - D(G(z)))
            utils.set_requires_grad(self.StyleGan.discriminator, True)
            self.StyleGan.discriminator.train()
            self.StyleGan.generator.eval()
            utils.set_requires_grad(self.StyleGan.generator, False)

            # Style mixing with probability self.mixed_probability.
            if np.random.random() < self.mixed_probability:
                style_noise = utils.createStyleMixedNoiseList(
                    self.batch_size, self.latent_dim, self.num_layers,
                    self.StyleGan.styleNetwork, self.device)
            else:
                style_noise = utils.createStyleNoiseList(
                    self.batch_size, self.latent_dim, self.num_layers,
                    self.StyleGan.styleNetwork, self.device)
            image_noise = utils.create_image_noise(self.batch_size,
                                                   self.image_size,
                                                   self.device)

            self.StyleGan.discriminatorOptimizer.zero_grad()

            # One-sided label smoothing: real targets are 0.9, not 1.0.
            real_labels = (torch.ones(self.batch_size) * 0.9).to(
                self.device)

            # reshape(-1) flattens D's output to match the label tensor.
            discriminator_real_output = self.StyleGan.discriminator(
                batch).reshape(-1).to(self.device)
            discriminator_real_loss = self.loss_fn(
                discriminator_real_output, real_labels).mean()
            # Explicit dels throughout keep GPU memory pressure down.
            del real_labels

            generated_images = self.StyleGan.generator(
                style_noise.detach(), image_noise.detach()).to(self.device)
            del style_noise
            del image_noise

            # Smoothed fake targets of 0.1 (rather than 0.0).
            fake_labels = (torch.ones(self.batch_size) * 0.1).to(
                self.device)

            discriminator_fake_output = self.StyleGan.discriminator(
                generated_images.detach()).reshape(-1).to(self.device)
            discriminator_fake_loss = self.loss_fn(
                discriminator_fake_output, fake_labels).mean()
            del fake_labels
            del discriminator_fake_output

            discriminator_total_loss = discriminator_fake_loss + discriminator_real_loss

            # Gradient penalty on real samples (applied every batch here,
            # despite an older comment suggesting every 4 steps).
            discriminator_total_loss = discriminator_total_loss + utils.gradientPenalty(
                batch, discriminator_real_output, self.device)
            del discriminator_real_output

            # Abort the epoch on numerical blow-up.
            if isnan(discriminator_total_loss):
                print("IS NAN discriminator")
                break
            if self.apex_available:
                # apex AMP: scale the loss before backward.
                with amp.scale_loss(discriminator_total_loss,
                                    self.StyleGan.discriminatorOptimizer
                                    ) as scaled_loss:
                    scaled_loss.backward()
            else:
                discriminator_total_loss.backward()
            torch.nn.utils.clip_grad_norm_(
                self.StyleGan.discriminator.parameters(), 5, norm_type=2)
            self.StyleGan.discriminatorOptimizer.step()

            # Train Generator: maximize log(D(G(z)))
            utils.set_requires_grad(self.StyleGan.discriminator, False)
            self.StyleGan.discriminator.eval()
            self.StyleGan.generator.train()
            utils.set_requires_grad(self.StyleGan.generator, True)

            # Fresh noise for the generator step (same mixing policy).
            if np.random.random() < self.mixed_probability:
                style_noise = utils.createStyleMixedNoiseList(
                    self.batch_size, self.latent_dim, self.num_layers,
                    self.StyleGan.styleNetwork, self.device)
            else:
                style_noise = utils.createStyleNoiseList(
                    self.batch_size, self.latent_dim, self.num_layers,
                    self.StyleGan.styleNetwork, self.device)
            image_noise = utils.create_image_noise(self.batch_size,
                                                   self.image_size,
                                                   self.device)

            self.StyleGan.generatorOptimizer.zero_grad()

            # Generator wants D to output 1.0 on its samples.
            generator_labels = torch.ones(self.batch_size).to(self.device)

            # Generate images
            generated_images = self.StyleGan.generator(
                style_noise, image_noise).to(self.device)
            del image_noise

            generator_output = self.StyleGan.discriminator(
                generated_images).reshape(-1).to(self.device)
            generator_loss = self.loss_fn(generator_output,
                                          generator_labels).mean()
            if isnan(generator_loss):
                print("isnan generator")
                break
            del generator_output
            del generator_labels

            # Kept separate for logging: adversarial loss without the PL term.
            generator_loss_no_pl = generator_loss

            # Apply Path Length Regularization every 10 batches
            # (older comment said 16 — the code uses 10).
            if batch_num % 10 == 0:
                num_pixels = generated_images.shape[
                    2] * generated_images.shape[3]
                noise_to_add = (torch.randn(generated_images.shape) /
                                math.sqrt(num_pixels)).to(self.device)
                outputs = (generated_images * noise_to_add)
                # Jacobian of noise-perturbed images w.r.t. the style input.
                pl_gradient = grad(outputs=outputs,
                                   inputs=style_noise,
                                   grad_outputs=torch.ones(
                                       outputs.shape).to(self.device),
                                   create_graph=True,
                                   retain_graph=True,
                                   only_inputs=True)[0]
                del num_pixels
                del noise_to_add
                del outputs
                pl_length = torch.sqrt(torch.sum(
                    torch.square(pl_gradient)))
                # Penalize deviation from the running mean path length.
                if self.average_pl_length is not None:
                    pl_regularizer = ((pl_length -
                                       self.average_pl_length)**2).mean()
                else:
                    pl_regularizer = (pl_length**2).mean()
                del pl_gradient
                del style_noise
                # Exponential moving average of the path length.
                if self.average_pl_length == None:
                    self.average_pl_length = pl_length.detach().item()
                else:
                    self.average_pl_length = self.average_pl_length * self.pl_beta + (
                        1 - self.pl_beta) * pl_length.detach().item()
                del pl_length
                generator_loss = generator_loss + pl_regularizer

            if self.apex_available:
                with amp.scale_loss(
                        generator_loss,
                        self.StyleGan.generatorOptimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                generator_loss.backward(retain_graph=True)
            torch.nn.utils.clip_grad_norm_(
                self.StyleGan.generator.parameters(), 5, norm_type=2)
            self.StyleGan.generatorOptimizer.step()

            # Console logging every 100 batches (when verbose).
            if verbose == True:
                if batch_num % 100 == 0 and batch_num != 0:
                    print("Checkpoint")
                    print("Batch: ", batch_num)
                    print("Path Length Mean: ", self.average_pl_length)
                    print("Discriminator Mean Real Loss: ",
                          discriminator_real_loss.item())
                    print("Discriminator Mean Fake Loss: ",
                          discriminator_fake_loss.item())
                    print("Discriminator Total Loss: ",
                          discriminator_total_loss.item())
                    print("Generator Loss (no pl)", generator_loss_no_pl)
                    print("Generator Loss: ", generator_loss.item())
                    print("PL difference:", pl_regularizer.item())

            # TensorBoard logging + checkpoint save every 100 batches.
            # (batch_num % 100 == 0 implies % 10 == 0, so pl_regularizer
            # is guaranteed to exist here.)
            if batch_num % 100 == 0 and batch_num != 0:
                print("Current Checkpoint is: ", self.checkpoint)
                img_grid = make_grid(generated_images)
                self.tensorboard_summary.add_scalar(
                    'Path Length Mean', self.average_pl_length,
                    self.checkpoint)
                self.tensorboard_summary.add_scalar(
                    'Discriminator Mean Real Loss ',
                    discriminator_real_loss, self.checkpoint)
                self.tensorboard_summary.add_scalar(
                    'Discriminator Mean Fake Loss ',
                    discriminator_fake_loss, self.checkpoint)
                self.tensorboard_summary.add_scalar(
                    'Discriminator Total Loss ',
                    discriminator_total_loss.item(), self.checkpoint)
                self.tensorboard_summary.add_scalar(
                    'Generator Loss', generator_loss.item(),
                    self.checkpoint)
                self.tensorboard_summary.add_scalar(
                    'Path Length Difference', pl_regularizer.item(),
                    self.checkpoint)
                self.tensorboard_summary.add_scalar(
                    'Generator Loss (No PL)',
                    generator_loss_no_pl.item(), self.checkpoint)
                self.tensorboard_summary.add_image(
                    f'generated_image{self.checkpoint}', img_grid)
                del generated_images
                del img_grid
                del generator_loss_no_pl
                del discriminator_total_loss
                del generator_loss
                del pl_regularizer
                del discriminator_real_loss
                del discriminator_fake_loss
                self.saveModel(self.checkpoint)
                self.checkpoint = self.checkpoint + 1

    # Close TensorBoard at the end
    self.tensorboard_summary.close()
def step(self, i):
    """Run one distributed training step (D phase then G phase) with
    gradient accumulation and AMP, and return averaged loss scalars.

    The discriminator phase draws fresh real batches and latents per
    accumulation micro-batch, optionally adding an R1 gradient penalty;
    the generator phase reuses the cached reals/fakes/latents from the D
    phase and optionally adds StyleGAN2-style path-length regularization.
    Scalars are reduced (mean) across all distributed workers.
    """
    self.G.train()
    scalars = defaultdict(list)
    xs_real = []   # real micro-batches, cached for the G phase
    xs_fake = []   # generator outputs, cached for the G phase
    zs = []        # sampled latents (kept, though unused after D phase)

    #############################################################
    # train D
    #############################################################
    set_requires_grad(self.D, True)
    self.optim_D.zero_grad(set_to_none=True)
    for j in gradient_accumulation(self.cfg.solver.num_accumulation, True,
                                   (self.G, self.D)):
        # input data
        x_real, m_real = self.fetch_reals(next(self.loader))
        xs_real.append({"depth": x_real, "mask": m_real})
        B = x_real.shape[0]

        # sample z
        z = self.sample_latents(B)
        zs.append(z)

        loss_D = 0

        # discriminator loss
        with torch.cuda.amp.autocast(enabled=self.enable_amp):
            synth = self.G(latent=z)
            xs_fake.append(synth)
            # augment; real side re-enables grad so R1 can differentiate
            # w.r.t. the (augmented) real images.
            x_real_aug = self.A(x_real).detach().requires_grad_()
            x_fake_aug = self.A(synth["depth"]).detach()
            # forward D
            y_real = self.D(x_real_aug)
            y_fake = self.D(x_fake_aug)
            scalars["loss/D/output/real"].append(y_real.mean().detach())
            scalars["loss/D/output/fake"].append(y_fake.mean().detach())
            # adversarial loss
            loss_GAN = self.criterion["gan"](y_real, y_fake, "D")
            loss_D += self.loss_weight["gan"] * loss_GAN
            scalars["loss/D/adversarial"].append(loss_GAN.detach())

        # r1 gradient penalty
        if "gp" in self.criterion:
            # Scale outputs before grad so the penalty is computed in the
            # scaler's range, then unscale the gradients below.
            (grads, ) = torch.autograd.grad(
                outputs=self.scaler.scale(y_real.sum()),
                inputs=[x_real_aug],
                create_graph=True,
                only_inputs=True,
            )
            # unscale
            grads = grads / self.scaler.get_scale()
            with torch.cuda.amp.autocast(enabled=self.enable_amp):
                r1_penalty = (grads**2).sum(dim=[1, 2, 3]).mean()
                scalars["loss/D/gradient_penalty"].append(
                    r1_penalty.detach())
                loss_D += (self.loss_weight["gp"] / 2) * r1_penalty

        # Zero-weight term keeps all D outputs in the graph (DDP-friendly
        # trick so every parameter participates in backward).
        loss_D += 0.0 * y_real.squeeze()[0]

        # Normalize so accumulated micro-batches average, not sum.
        loss_D /= float(self.cfg.solver.num_accumulation)
        self.scaler.scale(loss_D).backward()

    # update D parameters
    self.scaler.step(self.optim_D)

    #############################################################
    # train G
    #############################################################
    set_requires_grad(self.D, False)
    self.optim_G.zero_grad(set_to_none=True)
    for j in gradient_accumulation(self.cfg.solver.num_accumulation, True,
                                   (self.G, self.D)):
        loss_G = 0

        # generator loss (reuses cached micro-batch j from the D phase)
        with torch.cuda.amp.autocast(enabled=self.enable_amp):
            # augment
            x_real_aug = self.A(xs_real[j]["depth"]).detach()
            x_fake_aug = self.A(xs_fake[j]["depth"])
            # forward D
            y_real = self.D(x_real_aug)
            y_fake = self.D(x_fake_aug)
            # adversarial loss
            loss_GAN = self.criterion["gan"](y_real, y_fake, "G")
            loss_G += self.loss_weight["gan"] * loss_GAN
            scalars["loss/G/adversarial"].append(loss_GAN.detach())

        # path length regularization
        if "pl" in self.criterion:
            # forward G with smaller batch (half size, as in StyleGAN2)
            B_pl = len(xs_real[j]["depth"]) // 2
            z_pl = self.sample_latents(B_pl).requires_grad_()
            # perturb images
            with torch.cuda.amp.autocast(enabled=self.enable_amp):
                synth_pl = self.G(latent=z_pl)
                x_pl = synth_pl["depth"]
                noise_pl = torch.randn_like(x_pl)
                noise_pl /= np.sqrt(np.prod(x_pl.shape[2:]))
                outputs = (x_pl * noise_pl).sum()
            (grads, ) = torch.autograd.grad(
                outputs=self.scaler.scale(outputs),
                inputs=[z_pl],
                create_graph=True,
                only_inputs=True,
            )
            # unscale
            grads = grads / self.scaler.get_scale()
            with torch.cuda.amp.autocast(enabled=self.enable_amp):
                # compute |J*y|
                pl_lengths = grads.pow(2).sum(dim=-1)
                pl_lengths = torch.sqrt(pl_lengths)
                # ema of |J*y|
                pl_ema = self.pl_ema.lerp(pl_lengths.mean(), 0.01)
                self.pl_ema.copy_(pl_ema.detach())
                # calculate (|J*y|-a)^2
                pl_penalty = (pl_lengths - pl_ema).pow(2).mean()
                scalars["loss/G/path_length/baseline"].append(
                    self.pl_ema.detach())
                scalars["loss/G/path_length"].append(pl_penalty.detach())
                loss_G += self.loss_weight["pl"] * pl_penalty
            # Zero-weight term to keep the PL graph attached (DDP trick).
            loss_G += 0.0 * x_pl[0, 0, 0, 0]

        # Normalize so accumulated micro-batches average, not sum.
        loss_G /= float(self.cfg.solver.num_accumulation)
        self.scaler.scale(loss_G).backward()

    # update G parameters (single scaler serves both optimizers; update()
    # is deliberately called once, after both steps)
    self.scaler.step(self.optim_G)
    self.scaler.update()

    # EMA copy of the generator for evaluation/sampling.
    ema_inplace(self.G_ema, self.G.module, self.ema_decay)

    # gather scalars from all devices
    for key, scalar_list in scalars.items():
        scalar = torch.mean(torch.stack(scalar_list))
        dist.all_reduce(scalar)  # sum over gpus
        scalar /= dist.get_world_size()
        scalars[key] = scalar.item()
    return scalars