def train():
    """Adversarial fine-tuning loop for a cartoon-style GAN under mixed precision.

    Alternates a discriminator (D) step and a generator (G) step per batch,
    using torch.cuda.amp autocast + a single GradScaler shared by both
    optimizers.  Periodically writes sample images, a loss log, and model
    checkpoints to ./images and ./checkpoints.

    Relies on module-level names defined elsewhere in this file/project:
    Discriminator, Generator, ContentLoss, AdversialLoss, AdamW,
    get_dataloader, get_pair_transforms, unnormalize, vutils, np, datetime.
    """
    torch.manual_seed(1337)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Config
    batch_size = 32
    image_size = 256
    learning_rate = 1e-4
    beta1, beta2 = (.5, .99)
    weight_decay = 1e-4
    epochs = 1000

    # Models
    netD = Discriminator().to(device)
    netG = Generator().to(device)
    # Here you should load the pretrained G.
    # NOTE(review): calling .state_dict() on the result of torch.load implies
    # the checkpoint was saved as a whole Module (torch.save(model, ...)); if it
    # was saved as a state_dict, this line raises AttributeError — confirm the
    # checkpoint format.
    netG.load_state_dict(torch.load("./checkpoints/pretrained_netG.pth").state_dict())

    optimizerD = AdamW(netD.parameters(), lr=learning_rate, betas=(beta1, beta2), weight_decay=weight_decay)
    optimizerG = AdamW(netG.parameters(), lr=learning_rate, betas=(beta1, beta2), weight_decay=weight_decay)

    # One GradScaler shared by both optimizers (scale both losses, step both
    # optimizers, then a single update() at the end of the iteration).
    scaler = torch.cuda.amp.GradScaler()

    # Patch-discriminator targets: D outputs a (batch, 1, H/4, W/4) map.
    # NOTE(review): fixed batch_size here assumes the dataloaders drop the last
    # partial batch — otherwise the final batch's predictions won't match these
    # label tensors; verify drop_last in get_dataloader.
    cartoon_labels = torch.ones (batch_size, 1, image_size // 4, image_size // 4).to(device)
    fake_labels    = torch.zeros(batch_size, 1, image_size // 4, image_size // 4).to(device)

    # Loss functions
    content_loss = ContentLoss().to(device)
    adv_loss = AdversialLoss(cartoon_labels, fake_labels).to(device)
    BCE_loss = nn.BCEWithLogitsLoss().to(device)

    # Dataloaders.  The cartoon loader yields images with the cartoon and its
    # edge-smoothed pair concatenated side by side along width (split below).
    real_dataloader    = get_dataloader("./datasets/real_images/flickr30k_images/", size = image_size, bs = batch_size)
    cartoon_dataloader = get_dataloader("./datasets/cartoon_images_smoothed/Studio Ghibli", size = image_size, bs = batch_size, trfs=get_pair_transforms(image_size))

    # --------------------------------------------------------------------- #
    # Training Loop

    # Lists to keep track of progress
    img_list = []
    G_losses = []
    D_losses = []
    iters = 0
    # Fixed batch of real images used to visualize G's progress over training.
    tracked_images = next(iter(real_dataloader)).to(device)

    print("Starting Training Loop...")
    # For each epoch.
    for epoch in range(epochs):
        print("training epoch ", epoch)
        # For each batch in the dataloader.
        for i, (cartoon_edge_data, real_data) in enumerate(zip(cartoon_dataloader, real_dataloader)):

            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ############################
            # Reset Discriminator gradient and re-enable its params (they are
            # frozen during the G step below).
            netD.zero_grad()
            for param in netD.parameters():
                param.requires_grad = True

            # Format batch: split the side-by-side cartoon|edge pair on width.
            cartoon_data = cartoon_edge_data[:, :, :, :image_size].to(device)
            edge_data    = cartoon_edge_data[:, :, :, image_size:].to(device)
            real_data    = real_data.to(device)

            with torch.cuda.amp.autocast():
                # Generate image
                generated_data = netG(real_data)

                # Forward pass all batches through D.
                cartoon_pred   = netD(cartoon_data)                    #.view(-1)
                edge_pred      = netD(edge_data)                       #.view(-1)
                # detach() so the D update does not backprop into G's graph.
                generated_pred = netD(generated_data.detach())         #.view(-1)

                # Calculate discriminator loss on all batches.
                errD = adv_loss(cartoon_pred, generated_pred, edge_pred)

            # Calculate gradients for D in backward pass (scaled for AMP).
            scaler.scale(errD).backward()
            D_x = cartoon_pred.mean().item()  # Should be close to 1

            # Update D
            scaler.step(optimizerD)

            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ############################
            # Reset Generator gradient; freeze D so G's backward pass does not
            # accumulate gradients into D.
            netG.zero_grad()
            for param in netD.parameters():
                param.requires_grad = False

            with torch.cuda.amp.autocast():
                # Since we just updated D, perform another forward pass of the
                # all-fake batch through D.
                generated_pred = netD(generated_data)                  #.view(-1)

                # G's loss: fool D (target = cartoon_labels) + content fidelity
                # to the input photo.
                errG = BCE_loss(generated_pred, cartoon_labels) + content_loss(generated_data, real_data)

            # Calculate gradients for G
            scaler.scale(errG).backward()
            D_G_z2 = generated_pred.mean().item()  # Should be close to 1

            # Update G, then advance the shared scaler once per iteration.
            scaler.step(optimizerG)
            scaler.update()

            # ------------------------------------------------------------- #

            # Save Losses for plotting later
            G_losses.append(errG.item())
            D_losses.append(errD.item())

            # Check how the generator is doing by saving G's output on
            # tracked_images; log the mean losses since the last snapshot and
            # reset the accumulators.
            if iters % 200 == 0:
                with torch.no_grad():
                    fake = netG(tracked_images)
                    vutils.save_image(unnormalize(fake), f"images/{epoch}_{i}.png", padding=2)
                with open("images/log.txt", "a+") as f:
                    f.write(f"{datetime.now().isoformat(' ', 'seconds')}\tD: {np.mean(D_losses)}\tG: {np.mean(G_losses)}\n")
                D_losses = []
                G_losses = []

            # Periodic checkpointing (loss value embedded in the filename).
            if iters % 1000 == 0:
                torch.save(netG.state_dict(), f"checkpoints/netG_e{epoch}_i{iters}_l{errG.item()}.pth")
                torch.save(netD.state_dict(), f"checkpoints/netD_e{epoch}_i{iters}_l{errG.item()}.pth")

            iters += 1
# NOTE(review): orphan fragment of a separate GAN training loop.  Its enclosing
# function is not visible here, it references names defined elsewhere
# (data_wi_gt, batch_size, counter, netD, netG, criterion, ones, zero, np,
# torch), and it is cut off mid-iteration right after the fake-batch
# discriminator loss (no backward/step for errD_fake, and no generator update,
# are visible).  Documented as-is; do not assume it is complete.
errDiscriminitor = []
errSegnet = []
errGenerator = []
for (x,y_bin) in data_wi_gt:
    # Convert NHWC numpy batches to NCHW.
    x = x.transpose([0,3,1,2])
    y = y_bin.transpose([0,3,1,2])
    # Flatten the label map to (batch, N, 1, 1) and scale by 1/7.
    # NOTE(review): the 7.0 divisor presumably normalizes 7 class ids to [0, 1]
    # — confirm against the dataset's label encoding.
    y = np.reshape(y,(batch_size,y[0].size,1,1))/7.0
    y = torch.from_numpy(y).cuda().float()
    #print("counter: "+str(counter))
    counter += 1
    real_rgb = torch.from_numpy(x).cuda().float()

    # DISCRIMINATOR step.
    netD.zero_grad()
    # Forward with the desired (real) distribution; target = ones.
    disc_ouput = netD(real_rgb).view(-1)
    errD_real = criterion(disc_ouput, ones)
    errD_real.backward()
    D_x = disc_ouput.mean().item()

    # Generated distribution: conditional input = labels concatenated with
    # a (batch, 4096, 1, 1) standard-normal noise vector.
    noise = torch.FloatTensor(batch_size, 4096, 1, 1).normal_(0, 1).cuda()
    rgb_generator_input = torch.cat((y,noise),1).cuda()
    fake_rgb = netG(rgb_generator_input)
    # Forward with the generated distribution; detach() keeps the D update
    # from backpropagating into G.  Target = zero.
    disc_ouput = netD(fake_rgb.detach()).view(-1)
    errD_fake = criterion(disc_ouput, zero)
def train():
    """Train the cartoon GAN: alternate discriminator (D) and generator (G) updates.

    Iterates real photos, cartoon images, and edge-smoothed cartoons in
    lockstep; each iteration performs one D step (real cartoons + edges +
    detached fakes) and one G step (adversarial BCE + content loss).  Tracks
    losses and periodically snapshots G's output on a fixed batch.

    Relies on module-level names defined elsewhere in this file/project:
    Discriminator, Generator, ContentLoss, AdversialLoss, AdamW,
    get_dataloader, vutils.

    Fixes over the previous version:
      * D's forward on fakes now uses generated_data.detach() — without it,
        errD.backward() freed G's graph and errG.backward() raised
        "Trying to backward through the graph a second time".
      * D_G_z1 is now assigned (was a NameError in the stats print).
      * The final-snapshot condition used an undefined name `dataloader`
        (now len(real_dataloader)).
      * Removed leftover `.is_cuda` debug prints.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Config
    batch_size = 9
    image_size = 256
    learning_rate = 1e-3
    beta1, beta2 = (.5, .99)
    weight_decay = 1e-3
    epochs = 10

    # Models
    netD = Discriminator().to(device)
    netG = Generator().to(device)

    optimizerD = AdamW(netD.parameters(), lr=learning_rate, betas=(beta1, beta2), weight_decay=weight_decay)
    optimizerG = AdamW(netG.parameters(), lr=learning_rate, betas=(beta1, beta2), weight_decay=weight_decay)

    # Patch-discriminator targets: D outputs a (batch, 1, H/4, W/4) map.
    # NOTE(review): fixed batch_size assumes the dataloaders drop the last
    # partial batch (drop_last=True) — verify in get_dataloader, otherwise the
    # final batch will not match these label tensors.
    cartoon_labels = torch.ones(batch_size, 1, image_size // 4, image_size // 4).to(device)
    fake_labels = torch.zeros(batch_size, 1, image_size // 4, image_size // 4).to(device)

    # Loss functions
    content_loss = ContentLoss(device)
    adv_loss = AdversialLoss(cartoon_labels, fake_labels)
    BCE_loss = nn.BCELoss().to(device)

    # Dataloaders
    real_dataloader = get_dataloader("./datasets/real_images", size=image_size, bs=batch_size)
    cartoon_dataloader = get_dataloader("./datasets/cartoon_images", size=image_size, bs=batch_size)
    edge_dataloader = get_dataloader("./datasets/cartoon_images_smooth", size=image_size, bs=batch_size)

    # --------------------------------------------------------------------- #
    # Training Loop

    # Lists to keep track of progress
    img_list = []
    G_losses = []
    D_losses = []
    iters = 0
    # Fixed batch of real images used to visualize G's progress over training.
    tracked_images = next(iter(real_dataloader))[0].to(device)

    print("Starting Training Loop...")
    # For each epoch.
    for epoch in range(epochs):
        # For each batch triple (loaders iterate in lockstep; labels unused).
        for i, ((cartoon_data, _), (edge_data, _), (real_data, _)) in enumerate(
                zip(cartoon_dataloader, edge_dataloader, real_dataloader)):

            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ############################
            # Reset Discriminator gradient.
            netD.zero_grad()

            # Format batch.
            cartoon_data = cartoon_data.to(device)
            edge_data = edge_data.to(device)
            real_data = real_data.to(device)

            # Generate image.
            generated_data = netG(real_data)

            # Forward pass all batches through D.
            cartoon_pred = netD(cartoon_data)
            edge_pred = netD(edge_data)
            # detach() so the D update neither backprops into G nor frees the
            # graph that the G step below still needs.
            generated_pred = netD(generated_data.detach())

            # Calculate discriminator loss on all batches.
            errD = adv_loss(cartoon_pred, generated_pred, edge_pred)
            # Calculate gradients for D in backward pass.
            errD.backward()
            D_x = cartoon_pred.mean().item()       # Should be close to 1
            D_G_z1 = generated_pred.mean().item()  # D's score on fakes, pre-G-update
            # Update D.
            optimizerD.step()

            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ############################
            # Reset Generator gradient.
            netG.zero_grad()

            # Since we just updated D, perform another forward pass of the
            # all-fake batch through D.
            generated_pred = netD(generated_data)
            # G's loss: fool D (target = cartoon_labels) + content fidelity.
            errG = BCE_loss(generated_pred, cartoon_labels) + content_loss(generated_data, real_data)
            # Calculate gradients for G.
            errG.backward()
            D_G_z2 = generated_pred.mean().item()  # Should be close to 1
            # Update G.
            optimizerG.step()

            # ------------------------------------------------------------- #

            # Output training stats.
            if i % 50 == 0:
                print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                      % (epoch, epochs, i, len(real_dataloader),
                         errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

            # Save Losses for plotting later.
            G_losses.append(errG.item())
            D_losses.append(errD.item())

            # Check how the generator is doing by saving G's output on
            # tracked_images (every 500 iters, and on the very last batch).
            if (iters % 500 == 0) or ((epoch == epochs - 1) and (i == len(real_dataloader) - 1)):
                with torch.no_grad():
                    fake = netG(tracked_images).detach().cpu()
                img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

            iters += 1