Example #1
import numpy as np
import torch
import torch.nn as nn
import torchvision.utils as vutils
from datetime import datetime
from torch.optim import AdamW

# Project-specific pieces (Discriminator, Generator, ContentLoss, AdversialLoss,
# get_dataloader, get_pair_transforms, unnormalize) are assumed to be importable
# from the surrounding repository; the module paths below are placeholders.
# from models import Discriminator, Generator
# from losses import ContentLoss, AdversialLoss
# from data import get_dataloader, get_pair_transforms, unnormalize


def train():
    torch.manual_seed(1337)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Config
    batch_size = 32
    image_size = 256
    learning_rate = 1e-4
    beta1, beta2 = (.5, .99)
    weight_decay = 1e-4
    epochs = 1000

    # Models
    netD = Discriminator().to(device)
    netG = Generator().to(device)
    # Load the pretrained generator weights (the checkpoint is assumed to hold a state_dict).
    netG.load_state_dict(torch.load("./checkpoints/pretrained_netG.pth"))

    optimizerD = AdamW(netD.parameters(), lr=learning_rate, betas=(beta1, beta2), weight_decay=weight_decay)
    optimizerG = AdamW(netG.parameters(), lr=learning_rate, betas=(beta1, beta2), weight_decay=weight_decay)

    scaler = torch.cuda.amp.GradScaler()

    # Labels
    cartoon_labels = torch.ones(batch_size, 1, image_size // 4, image_size // 4).to(device)
    fake_labels    = torch.zeros(batch_size, 1, image_size // 4, image_size // 4).to(device)

    # Loss functions
    content_loss = ContentLoss().to(device)
    adv_loss     = AdversialLoss(cartoon_labels, fake_labels).to(device)
    BCE_loss     = nn.BCEWithLogitsLoss().to(device)

    # Dataloaders
    real_dataloader    = get_dataloader("./datasets/real_images/flickr30k_images/", size=image_size, bs=batch_size)
    cartoon_dataloader = get_dataloader("./datasets/cartoon_images_smoothed/Studio Ghibli", size=image_size, bs=batch_size, trfs=get_pair_transforms(image_size))

    # --------------------------------------------------------------------------------------------- #
    # Training Loop

    # Lists to keep track of progress
    img_list = []
    G_losses = []
    D_losses = []
    iters = 0

    tracked_images = next(iter(real_dataloader)).to(device)

    print("Starting Training Loop...")
    # For each epoch.
    for epoch in range(epochs):
        print("training epoch ", epoch)
        # For each batch in the dataloader.
        for i, (cartoon_edge_data, real_data) in enumerate(zip(cartoon_dataloader, real_dataloader)):

            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ###########################
            
            # Reset Discriminator gradient.
            netD.zero_grad()
            for param in netD.parameters():
                param.requires_grad = True

            # Format batch.
            cartoon_data   = cartoon_edge_data[:, :, :, :image_size].to(device)
            edge_data      = cartoon_edge_data[:, :, :, image_size:].to(device)
            real_data      = real_data.to(device)

            with torch.cuda.amp.autocast():
                # Generate image
                generated_data = netG(real_data)

                # Forward pass all batches through D.
                cartoon_pred   = netD(cartoon_data)      #.view(-1)
                edge_pred      = netD(edge_data)         #.view(-1)
                generated_pred = netD(generated_data.detach())    #.view(-1)

                # Calculate discriminator loss on all batches.
                errD = adv_loss(cartoon_pred, generated_pred, edge_pred)
            
            # Calculate gradients for D in backward pass
            scaler.scale(errD).backward()
            D_x = cartoon_pred.mean().item() # Should be close to 1

            # Update D
            scaler.step(optimizerD)


            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ###########################
            
            # Reset Generator gradient.
            netG.zero_grad()
            for param in netD.parameters():
                param.requires_grad = False

            with torch.cuda.amp.autocast():
                # Since we just updated D, perform another forward pass of all-fake batch through D
                generated_pred = netD(generated_data) #.view(-1)

                # Calculate G's loss based on this output
                errG = BCE_loss(generated_pred, cartoon_labels) + content_loss(generated_data, real_data)

            # Calculate gradients for G
            scaler.scale(errG).backward()

            D_G_z2 = generated_pred.mean().item() # Should be close to 1
            
            # Update G
            scaler.step(optimizerG)

            scaler.update()
            
            # ---------------------------------------------------------------------------------------- #

            # Save Losses for plotting later
            G_losses.append(errG.item())
            D_losses.append(errD.item())

            # Check how the generator is doing by saving G's output on tracked_images
            if iters % 200 == 0:
                with torch.no_grad():
                    fake = netG(tracked_images)
                vutils.save_image(unnormalize(fake), f"images/{epoch}_{i}.png", padding=2)
                with open("images/log.txt", "a+") as f:
                    f.write(f"{datetime.now().isoformat(' ', 'seconds')}\tD: {np.mean(D_losses)}\tG: {np.mean(G_losses)}\n")
                D_losses = []
                G_losses = []

            if iters % 1000 == 0:
                torch.save(netG.state_dict(), f"checkpoints/netG_e{epoch}_i{iters}_l{errG.item()}.pth")
                torch.save(netD.state_dict(), f"checkpoints/netD_e{epoch}_i{iters}_l{errG.item()}.pth")

            iters += 1
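
Both example #1 and example #3 construct an `AdversialLoss(cartoon_labels, fake_labels)` module whose definition is not shown. Below is a minimal sketch of what such a module could look like, assuming a BCE-with-logits term for each of the cartoon, generated, and edge-smoothed predictions; the class body and the equal weighting of the three terms are assumptions, not the repositories' actual code.

import torch.nn as nn

class AdversialLoss(nn.Module):
    """Sketch of a discriminator adversarial loss over three prediction maps (assumed implementation)."""

    def __init__(self, real_labels, fake_labels):
        super().__init__()
        self.real_labels = real_labels
        self.fake_labels = fake_labels
        self.bce = nn.BCEWithLogitsLoss()

    def forward(self, cartoon_pred, generated_pred, edge_pred):
        # Real cartoons should score as real; generated images and
        # edge-smoothed cartoons should score as fake.
        return (self.bce(cartoon_pred, self.real_labels)
                + self.bce(generated_pred, self.fake_labels)
                + self.bce(edge_pred, self.fake_labels))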
Example #2
    errDiscriminator = []
    errSegnet = []
    errGenerator = []
    for (x, y_bin) in data_wi_gt:
        # Convert NHWC numpy batches to NCHW tensors.
        x = x.transpose([0, 3, 1, 2])
        y = y_bin.transpose([0, 3, 1, 2])
        # Flatten the ground truth into a (batch, features, 1, 1) conditioning tensor.
        y = np.reshape(y, (batch_size, y[0].size, 1, 1)) / 7.0
        y = torch.from_numpy(y).cuda().float()
        counter += 1


        real_rgb = torch.from_numpy(x).cuda().float()
        
        # DISCRIMINATOR update
        netD.zero_grad()
        # Forward pass with the real (target) distribution.
        disc_output = netD(real_rgb).view(-1)
        errD_real = criterion(disc_output, ones)
        errD_real.backward()
        D_x = disc_output.mean().item()


        # Generated distribution: conditioning vector concatenated with noise.
        noise = torch.FloatTensor(batch_size, 4096, 1, 1).normal_(0, 1).cuda()
        rgb_generator_input = torch.cat((y, noise), 1).cuda()
        fake_rgb = netG(rgb_generator_input)

        # Forward pass with the generated distribution (detached so only D is updated).
        disc_output = netD(fake_rgb.detach()).view(-1)
        errD_fake = criterion(disc_output, zero)
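
The fragment above relies on `criterion`, `ones`, and `zero` defined outside the excerpt. One plausible setup, assuming a standard real/fake BCE labeling scheme (the shapes and the concrete loss are assumptions):

import torch
import torch.nn as nn

batch_size = 16                          # illustrative value; matches the excerpt's batch dimension
criterion = nn.BCELoss()                 # real/fake classification loss
ones = torch.ones(batch_size).cuda()     # targets for the real distribution
zero = torch.zeros(batch_size).cuda()    # targets for the generated distribution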
Example #3
import torch
import torch.nn as nn
import torchvision.utils as vutils
from torch.optim import AdamW

# Project-specific pieces (Discriminator, Generator, ContentLoss, AdversialLoss,
# get_dataloader) are assumed to be importable from the surrounding repository.


def train():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Config
    batch_size = 9
    image_size = 256
    learning_rate = 1e-3
    beta1, beta2 = (.5, .99)
    weight_decay = 1e-3
    epochs = 10

    # Models
    netD = Discriminator().to(device)
    netG = Generator().to(device)

    optimizerD = AdamW(netD.parameters(),
                       lr=learning_rate,
                       betas=(beta1, beta2),
                       weight_decay=weight_decay)
    optimizerG = AdamW(netG.parameters(),
                       lr=learning_rate,
                       betas=(beta1, beta2),
                       weight_decay=weight_decay)

    # Labels
    cartoon_labels = torch.ones(batch_size, 1, image_size // 4,
                                image_size // 4).to(device)
    fake_labels = torch.zeros(batch_size, 1, image_size // 4,
                              image_size // 4).to(device)

    # Loss functions
    content_loss = ContentLoss(device)
    adv_loss = AdversialLoss(cartoon_labels, fake_labels)
    BCE_loss = nn.BCELoss().to(device)

    # Dataloaders
    real_dataloader = get_dataloader("./datasets/real_images",
                                     size=image_size,
                                     bs=batch_size)
    cartoon_dataloader = get_dataloader("./datasets/cartoon_images",
                                        size=image_size,
                                        bs=batch_size)
    edge_dataloader = get_dataloader("./datasets/cartoon_images_smooth",
                                     size=image_size,
                                     bs=batch_size)

    # --------------------------------------------------------------------------------------------- #
    # Training Loop

    # Lists to keep track of progress
    img_list = []
    G_losses = []
    D_losses = []
    iters = 0

    tracked_images = next(iter(real_dataloader))[0].to(device)

    print("Starting Training Loop...")
    # For each epoch.
    for epoch in range(epochs):
        # For each batch in the dataloader.
        for i, ((cartoon_data, _), (edge_data, _),
                (real_data, _)) in enumerate(
                    zip(cartoon_dataloader, edge_dataloader, real_dataloader)):

            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ###########################

            # Reset Discriminator gradient.
            netD.zero_grad()

            # Format batch.
            cartoon_data = cartoon_data.to(device)
            edge_data = edge_data.to(device)
            real_data = real_data.to(device)

            # Generate image
            generated_data = netG(real_data)

            # Forward pass all batches through D.
            cartoon_pred = netD(cartoon_data)  #.view(-1)
            edge_pred = netD(edge_data)  #.view(-1)
            # Detach so the discriminator update does not backpropagate into the generator.
            generated_pred = netD(generated_data.detach())  #.view(-1)

            # Calculate discriminator loss on all batches.
            errD = adv_loss(cartoon_pred, generated_pred, edge_pred)

            # Calculate gradients for D in backward pass
            errD.backward()
            D_x = cartoon_pred.mean().item()  # Should be close to 1
            D_G_z1 = generated_pred.mean().item()  # D's score on fakes before the G update; should be close to 0

            # Update D
            optimizerD.step()

            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ###########################

            # Reset Generator gradient.
            netG.zero_grad()

            # Since we just updated D, perform another forward pass of all-fake batch through D
            generated_pred = netD(generated_data)  #.view(-1)

            # Calculate G's loss based on this output
            errG = BCE_loss(generated_pred, cartoon_labels) + content_loss(
                generated_data, real_data)

            # Calculate gradients for G
            errG.backward()

            D_G_z2 = generated_pred.mean().item()  # Should be close to 1

            # Update G
            optimizerG.step()

            # ---------------------------------------------------------------------------------------- #

            # Output training stats
            if i % 50 == 0:
                print(
                    '[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                    % (epoch, epochs, i, len(real_dataloader), errD.item(),
                       errG.item(), D_x, D_G_z1, D_G_z2))

            # Save Losses for plotting later
            G_losses.append(errG.item())
            D_losses.append(errD.item())

            # Check how the generator is doing by saving G's output on tracked_images
            if (iters % 500 == 0) or ((epoch == epochs - 1) and
                                      (i == len(real_dataloader) - 1)):
                with torch.no_grad():
                    fake = netG(tracked_images).detach().cpu()
                img_list.append(
                    vutils.make_grid(fake, padding=2, normalize=True))

            iters += 1
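
All three examples call a project helper `get_dataloader(path, size, bs, ...)` whose implementation is not shown. A minimal sketch of one way it could be written for example #3, where each batch is an (images, labels) tuple, built on torchvision's ImageFolder; the transform choices and the helper's internals are assumptions:

import torch
from torchvision import datasets, transforms

def get_dataloader(path, size, bs, trfs=None):
    # Default transform: resize, convert to tensor, normalize to [-1, 1].
    if trfs is None:
        trfs = transforms.Compose([
            transforms.Resize((size, size)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
    dataset = datasets.ImageFolder(path, transform=trfs)
    # drop_last keeps every batch at exactly `bs` samples, which matters because
    # the label tensors in the training loop are built with a fixed batch size.
    return torch.utils.data.DataLoader(dataset, batch_size=bs,
                                       shuffle=True, drop_last=True)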