示例#1
0
def train():
    """Pre-train the generator as a reconstruction model using only a content loss.

    Each batch of real images is passed through ``Generator`` and the generator
    is optimised with ``ContentLoss(generated, real)`` using AdamW and
    mixed precision (AMP is enabled only when running on CUDA).

    Side effects:
        * writes the fixed tracked batch to ``images/org.png``;
        * every 200 iterations writes G's output on that batch to
          ``images/{epoch}_{i}.png`` and pickles the whole ``netG`` module to
          ``checkpoints/pretrained_netG_e{epoch}_i{iters}_l{loss}.pth``;
        * after the last epoch saves ``netG.state_dict()`` under the same
          naming scheme.

    NOTE: the periodic checkpoints store the full pickled module while the
    final one stores only a ``state_dict``; code loading these files must
    handle both formats.
    """
    import os

    torch.manual_seed(1337)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # torch.cuda.amp only does anything on CUDA; on CPU the scaler/autocast
    # would warn and disable themselves, so switch them off explicitly.
    use_amp = device.type == "cuda"

    # Config
    batch_size = 16
    image_size = 256
    learning_rate = 1e-3
    beta1, beta2 = (.5, .99)
    weight_decay = 1e-3
    epochs = 10

    # Output directories (saving into a missing directory would crash).
    os.makedirs("images", exist_ok=True)
    os.makedirs("checkpoints", exist_ok=True)

    # Dataloaders
    real_dataloader = get_dataloader(
        "./datasets/real_images/flickr_nuneaton/", size=image_size, bs=batch_size, trfs=get_no_aug_transform())

    # Lists to keep track of progress
    G_losses = []
    iters = 0

    # A fixed batch used to visualise the generator's progress.
    tracked_images = next(iter(real_dataloader)).to(device)
    vutils.save_image(unnormalize(tracked_images),
                      "images/org.png", padding=2, normalize=True)

    # Models
    netG = Generator().to(device)

    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    optimizerG = AdamW(netG.parameters(), lr=learning_rate,
                       betas=(beta1, beta2), weight_decay=weight_decay)

    # Loss functions
    content_loss = ContentLoss().to(device)

    print("Starting Training Loop...")
    for epoch in range(epochs):
        for i, real_data in enumerate(tqdm(real_dataloader, desc=f"Training epoch {epoch}")):

            ############################
            # (1) Pre-train G
            ###########################

            # Reset generator gradient.
            netG.zero_grad()

            # Format batch.
            real_data = real_data.to(device)

            with torch.cuda.amp.autocast(enabled=use_amp):
                # Reconstruct the input image.
                generated_data = netG(real_data)

                # Content (reconstruction) loss between G's output and its input.
                errG = content_loss(generated_data, real_data)

            # Backward on the scaled loss, then step and update the scaler.
            scaler.scale(errG).backward()
            scaler.step(optimizerG)
            scaler.update()

            # ---------------------------------------------------------------------------------------- #

            # Save losses for plotting later.
            G_losses.append(errG.item())

            # Check how the generator is doing by saving G's output on tracked_images.
            if iters % 200 == 0:
                # Evaluate in eval mode so the preview pass does not update
                # normalisation running statistics or apply dropout, then
                # return to training mode.
                netG.eval()
                with torch.no_grad():
                    fake = netG(tracked_images).detach().cpu()
                netG.train()
                vutils.save_image(unnormalize(
                    fake), f"images/{epoch}_{i}.png", padding=2)
                torch.save(netG, f"checkpoints/pretrained_netG_e{epoch}_i{iters}_l{errG.item()}.pth")

            iters += 1

    torch.save(netG.state_dict(), f"checkpoints/pretrained_netG_e{epoch}_i{iters}_l{errG.item()}.pth")
示例#2
0
def train():
    """Adversarially fine-tune a pretrained generator against a discriminator.

    Loads the pretrained generator from ``./checkpoints/pretrained_netG.pth``.
    Each iteration then:

    1. updates ``netD`` with ``adv_loss`` on cartoon, edge and (detached)
       generated images;
    2. updates ``netG`` with ``BCEWithLogitsLoss`` against the all-ones
       "cartoon" labels plus ``ContentLoss(generated, real)``, with
       ``netD``'s parameters frozen.

    Mixed precision is used only when running on CUDA.

    Side effects:
        * every 200 iterations writes G's output on a fixed real batch to
          ``images/{epoch}_{i}.png`` and appends the mean D/G losses since the
          previous entry to ``images/log.txt``;
        * every 1000 iterations saves both models' ``state_dict`` to
          ``checkpoints/``.
    """
    import os

    torch.manual_seed(1337)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # torch.cuda.amp only does anything on CUDA; disable it explicitly elsewhere.
    use_amp = device.type == "cuda"

    # Config
    batch_size = 32
    image_size = 256
    learning_rate = 1e-4
    beta1, beta2 = (.5, .99)
    weight_decay = 1e-4
    epochs = 1000

    # Output directories (saving into a missing directory would crash).
    os.makedirs("images", exist_ok=True)
    os.makedirs("checkpoints", exist_ok=True)

    # Models
    netD = Discriminator().to(device)
    netG = Generator().to(device)
    # Load the pretrained G. The checkpoint may be either a pickled module or a
    # plain state_dict; map_location allows a CUDA-saved checkpoint to be
    # loaded on a CPU-only machine.
    pretrained = torch.load("./checkpoints/pretrained_netG.pth", map_location=device)
    if isinstance(pretrained, nn.Module):
        pretrained = pretrained.state_dict()
    netG.load_state_dict(pretrained)

    optimizerD = AdamW(netD.parameters(), lr=learning_rate, betas=(beta1, beta2), weight_decay=weight_decay)
    optimizerG = AdamW(netG.parameters(), lr=learning_rate, betas=(beta1, beta2), weight_decay=weight_decay)

    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    # Labels, sized for full batches and D outputs of image_size // 4.
    # NOTE(review): a short final batch would not match these shapes —
    # presumably get_dataloader drops incomplete batches; confirm.
    cartoon_labels = torch.ones(batch_size, 1, image_size // 4, image_size // 4).to(device)
    fake_labels = torch.zeros(batch_size, 1, image_size // 4, image_size // 4).to(device)

    # Loss functions
    content_loss = ContentLoss().to(device)
    adv_loss = AdversialLoss(cartoon_labels, fake_labels).to(device)
    BCE_loss = nn.BCEWithLogitsLoss().to(device)

    # Dataloaders
    real_dataloader = get_dataloader("./datasets/real_images/flickr30k_images/", size=image_size, bs=batch_size)
    cartoon_dataloader = get_dataloader("./datasets/cartoon_images_smoothed/Studio Ghibli", size=image_size, bs=batch_size, trfs=get_pair_transforms(image_size))

    # --------------------------------------------------------------------------------------------- #
    # Training Loop

    # Running losses since the last log entry.
    G_losses = []
    D_losses = []
    iters = 0

    # A fixed batch used to visualise the generator's progress.
    tracked_images = next(iter(real_dataloader)).to(device)

    print("Starting Training Loop...")
    for epoch in range(epochs):
        print("training epoch ", epoch)
        for i, (cartoon_edge_data, real_data) in enumerate(zip(cartoon_dataloader, real_dataloader)):

            ############################
            # (1) Update D network
            ###########################

            # Reset discriminator gradient and make D trainable again.
            netD.zero_grad()
            for param in netD.parameters():
                param.requires_grad = True

            # The pair transform yields the cartoon and its edge-smoothed
            # version side by side along the last axis; split them apart.
            cartoon_data = cartoon_edge_data[:, :, :, :image_size].to(device)
            edge_data = cartoon_edge_data[:, :, :, image_size:].to(device)
            real_data = real_data.to(device)

            with torch.cuda.amp.autocast(enabled=use_amp):
                generated_data = netG(real_data)

                # Forward all three batches through D; G's output is detached
                # so this step produces no gradients for G.
                cartoon_pred = netD(cartoon_data)
                edge_pred = netD(edge_data)
                generated_pred = netD(generated_data.detach())

                errD = adv_loss(cartoon_pred, generated_pred, edge_pred)

            scaler.scale(errD).backward()
            scaler.step(optimizerD)

            ############################
            # (2) Update G network
            ###########################

            # Reset generator gradient and freeze D so G's backward pass does
            # not accumulate gradients into it.
            netG.zero_grad()
            for param in netD.parameters():
                param.requires_grad = False

            with torch.cuda.amp.autocast(enabled=use_amp):
                # D was just updated, so re-score the (non-detached) fakes.
                generated_pred = netD(generated_data)

                # Adversarial term (fool D into "cartoon") plus content term.
                errG = BCE_loss(generated_pred, cartoon_labels) + content_loss(generated_data, real_data)

            scaler.scale(errG).backward()
            scaler.step(optimizerG)

            # A single scaler.update() per iteration, after both optimizers stepped.
            scaler.update()

            # ---------------------------------------------------------------------------------------- #

            G_losses.append(errG.item())
            D_losses.append(errD.item())

            # Check how the generator is doing by saving G's output on tracked_images.
            if iters % 200 == 0:
                # Preview in eval mode so it does not update normalisation
                # running statistics or apply dropout, then resume training mode.
                netG.eval()
                with torch.no_grad():
                    fake = netG(tracked_images)
                netG.train()
                vutils.save_image(unnormalize(fake), f"images/{epoch}_{i}.png", padding=2)
                with open("images/log.txt", "a+") as f:
                    f.write(f"{datetime.now().isoformat(' ', 'seconds')}\tD: {np.mean(D_losses)}\tG: {np.mean(G_losses)}\n")
                D_losses = []
                G_losses = []

            if iters % 1000 == 0:
                # Both file names carry errG, so each netG/netD pair shares an identical suffix.
                torch.save(netG.state_dict(), f"checkpoints/netG_e{epoch}_i{iters}_l{errG.item()}.pth")
                torch.save(netD.state_dict(), f"checkpoints/netD_e{epoch}_i{iters}_l{errG.item()}.pth")

            iters += 1
                                     align_corners=False)
        input_m2 = torch.cat([morphed_flag * interest_mask, outline], dim=1)
    else:
        input_m2 = torch.cat([flag * interest_mask, outline], dim=1)

    # Second model
    if use_second_model:
        output = M2_model(input_m2)
        output = output * interest_mask
    else:
        output = input_m2[:, :-1]

    # Add to data structure
    for cnam, outp, outl, mask, fnam in zip(country_name, output, outline,
                                            ball_mask, file_name):
        image = TF.to_pil_image(unnormalize(outp))
        flag = Image.open(input_folder / "flags" /
                          f"{cnam}.png").convert("RGB")

        # Quantize
        if use_quantization:
            image = quantize_pil_image(image, flag, quantize_thresh)

        # Add outline
        image = Image.composite(Image.new("RGB", img_size, (0, 0, 0)), image,
                                TF.to_pil_image(outl))
        # Add transparent background
        image = Image.composite(Image.new("RGBA", img_size, (0, 0, 0, 0)),
                                image.convert("RGBA"),
                                TF.to_pil_image(1 - (mask + outl)))
示例#4
0
    # criterion
    criterion = nn.L1Loss().to(device)

    # optimizer
    optimizer = torch.optim.AdamW(model.parameters(),
                                  lr=lr,
                                  betas=betas,
                                  weight_decay=wd)

    run_id = '_'.join(
        ["GMM_P2", note,
         datetime.today().strftime('%Y-%m-%d-%H.%M.%S')])
    logger = SummaryWriter(f"./training/logs/{run_id}")

    logger.add_images("test/4_outline", test_outline)
    logger.add_images("test/3_flags", unnormalize(test_flag))
    logger.add_images("test/2_target", unnormalize(test_ball_flag))

    for e in count():
        running_loss = []
        for i, inputs in enumerate(tqdm(dataloader, desc=f"Epoch {e}")):
            global_step = e * len(dataloader) + i

            flag = inputs["flag"].to(device)
            outline = inputs["outline"].to(device)
            target = inputs["ball_flag"].to(device)
            target_mask = inputs["ball_flag_mask"].to(device)
            target = target * target_mask

            grid, theta = model(flag, outline)
                    input = torch.cat([GMM_morph, outline], dim=1)

                    output = G(input)
                    output = output * interest_mask

                    loss_G = critereon(GMM_morph * interest_mask, output,
                                       target)

                    running_loss["loss_G"].append(loss_G.item())

                for key, value in running_loss.items():
                    logger.add_scalar(f"test/{key}", np.mean(value),
                                      global_step)
                running_loss = {k: [] for k in running_loss}

                logger.add_images("test/0_output", unnormalize(output),
                                  global_step)
                logger.add_images("test/2_output_outline",
                                  unnormalize(output) * (1 - outline),
                                  global_step)
                # Log these only once
                if first_run:
                    logger.add_images("test/1_target", unnormalize(target),
                                      global_step)
                    logger.add_images("test/3_target_outline",
                                      unnormalize(target) * (1 - outline),
                                      global_step)
                    logger.add_images("test/4_input_morphed",
                                      unnormalize(GMM_morph), global_step)
                    logger.add_images("test/5_input_outline", outline,
                                      global_step)
def predict(
        outlines,
        flags,
        img_size=256,
        use_M1=True,
        use_M2=True,
        use_quantization=True,
        quantize_threshold=0.01,
        use_cuda=True,
    ):
    """Paint flags onto ball outlines with the GMM (M1) and U-Net (M2) models.

    Args:
        outlines: A PIL image, or a list/tuple of PIL images, with the ball
            outlines. The background is flood-filled from pixel (0, 0).
        flags: A PIL image, or a list/tuple of PIL images; one per outline.
        img_size: Working resolution, an int (square) or a ``(width, height)``
            pair.
        use_M1: Warp each flag with the GMM model before the second stage.
        use_M2: Refine the (optionally warped) masked flag with the U-Net;
            when False the masked flag is used directly.
        use_quantization: Snap output colours to the flag via
            ``quantize_pil_image``.
        quantize_threshold: Threshold forwarded to ``quantize_pil_image``.
        use_cuda: Run on CUDA when it is available.

    Returns:
        A list of RGBA PIL images, one per outline, each cropped back to that
        outline's original width and height. An empty input yields ``[]``.

    Raises:
        ValueError: If the numbers of outlines and flags differ.
    """
    # Normalise arguments. Fresh lists are built so the caller's containers
    # are never modified; tuples are accepted as well as lists.
    if isinstance(img_size, int):
        img_size = (img_size, img_size)
    outlines = list(outlines) if isinstance(outlines, (list, tuple)) else [outlines]
    flags = list(flags) if isinstance(flags, (list, tuple)) else [flags]

    if len(outlines) != len(flags):
        raise ValueError(
            f"predict() needs one flag per outline, got {len(outlines)} "
            f"outline(s) and {len(flags)} flag(s)")
    if not outlines:
        # Nothing to process; also avoids torch.stack() on an empty list.
        return []

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() and use_cuda else "cpu")
    on_cuda = device.type == "cuda"

    # Set up only the models that will run. map_location lets weights saved
    # from a GPU be loaded when running on CPU.
    if use_M1:
        # Pass the device actually chosen, not the raw request: use_cuda=True
        # on a host without CUDA would otherwise disagree with `device`.
        GMM_model = GMM(*img_size, use_cuda=on_cuda)
        GMM_model.load_state_dict(torch.load(r"../main_weights/GMM.pth", map_location=device))
        GMM_model = GMM_model.to(device).eval()
    if use_M2:
        M2_model = smp.Unet(encoder_name="resnet34", encoder_weights="imagenet", decoder_use_batchnorm=True,
                            decoder_attention_type="scse", in_channels=4, classes=3, activation=torch.nn.Tanh)
        M2_model.load_state_dict(torch.load(r"../main_weights/BSM.pth", map_location=device)["G"])
        M2_model = M2_model.to(device).eval()

    # Pre-process outlines
    masks = []
    uncrop_coords = []
    outline_tensors = []
    for outline in outlines:
        outline = outline.convert("RGBA")

        # (x offset, y offset, width, height) used at the end to crop the
        # centred result back to this outline's original size.
        uncrop_coords.append((img_size[0] // 2 - outline.width // 2,
                              img_size[1] // 2 - outline.height // 2,
                              outline.width, outline.height))

        # Flood-fill the background from the top-left pixel to transparent;
        # the last band (alpha for RGBA) is then used as the ball mask.
        # NOTE(review): `mask` is the same object as `outline`, so the fill
        # also changes the image passed to get_black_mask below. Kept as in
        # the original; confirm get_black_mask expects this.
        mask = outline
        ImageDraw.floodfill(mask, (0, 0), (0, 0, 0, 0), border=None, thresh=0)
        mask = resize_and_pad_ball(mask, img_size)
        masks.append(TF.to_tensor(mask.split()[-1]))

        outline = resize_and_pad_ball(outline, img_size)
        outline_tensors.append(TF.to_tensor(get_black_mask(outline).astype(int)))

    # Pre-process flags
    color_flags = []
    flag_tensors = []
    for flag in flags:
        flag = flag.convert("RGBA")
        # Full-resolution RGB copy used as the reference for quantization.
        color_flags.append(flag.convert("RGB"))
        flag = fit_resize(flag, (img_size[0] - 15, img_size[1] - 15))
        flag = TF.center_crop(flag, img_size[::-1])  # center_crop takes (h, w)
        # Composite onto white so transparent flag regions become white.
        flag = Image.composite(flag, Image.new("RGB", img_size, (255, 255, 255)), flag)
        flag_tensors.append(TF.to_tensor(flag))

    flags = torch.stack(flag_tensors).float().to(device)
    flags = TF.normalize(flags, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    outlines = torch.stack(outline_tensors).float().to(device)
    masks = torch.stack(masks).float().to(device)

    # Region to recolour: ball mask minus the outline pixels.
    interest_mask = masks - outlines

    # Inference only; no autograd graph is needed.
    with torch.no_grad():
        # First model: warp the flag to the ball shape.
        if use_M1:
            grid, _ = GMM_model(flags, outlines)
            morphed_flag = F.grid_sample(flags, grid, padding_mode="border", align_corners=False)
            input_m2 = torch.cat([morphed_flag * interest_mask, outlines], dim=1)
        else:
            input_m2 = torch.cat([flags * interest_mask, outlines], dim=1)

        # Second model: refine; without it, just drop the outline channel.
        if use_M2:
            output = M2_model(input_m2) * interest_mask
        else:
            output = input_m2[:, :-1]

    # Post-process into PIL images
    finished_images = []
    for outp, outl, mask, color_flag, (x, y, w, h) in zip(output, outlines, masks, color_flags, uncrop_coords):
        image = TF.to_pil_image(unnormalize(outp))

        if use_quantization:
            image = quantize_pil_image(image, color_flag, quantize_threshold)

        # Draw the outline in black.
        image = Image.composite(Image.new("RGB", img_size, (0, 0, 0)), image, TF.to_pil_image(outl))

        # Make everything outside the ball and outline transparent. mask + outl
        # exceeds 1 on outline pixels, so clamp before to_pil_image's uint8
        # conversion (negative floats do not reliably convert to 0).
        transparent_where = (1 - (mask + outl)).clamp(0, 1)
        image = Image.composite(
            Image.new("RGBA", img_size, (0, 0, 0, 0)),
            image.convert("RGBA"),
            TF.to_pil_image(transparent_where),
        )

        # Uncrop back to the outline's original size.
        background = Image.new("RGBA", (w, h), (0, 0, 0, 0))
        background.paste(image, (-x, -y))
        finished_images.append(background)

    return finished_images