Example #1
0
    def after_batch_tfm(self, cap, cap_len, img):
        """Build a normalized multi-resolution pyramid from a channels-last batch.

        cap: (bs, 2), cap_len: (bs,), img: (bs, 64, 64, 3)
        Captions and lengths are passed through untouched. The images are
        moved to channels-first float, randomly flipped once, and then
        downscaled to 32/16/8/4 px. Every level is encoded with
        ``self.normalizer``. Returns ``(cap, cap_len, pyramid)``, with
        ``pyramid`` ordered from 64 px down to 4 px.
        """
        full = random_hflip(img.permute(0, 3, 1, 2).float())  # (bs, 3, 64, 64)
        # All smaller levels derive from the flipped but still un-normalized image.
        smaller = [resize(full, size=(side, side)) for side in (4, 8, 16, 32)]
        # Encode smallest-first, then the full-size image last.
        encoded = [self.normalizer.encode(level) for level in (*smaller, full)]
        return cap, cap_len, tuple(reversed(encoded))
Example #2
0
 def after_batch_tfm(self, cap, cap_len, img):
     """Turn a channels-last image batch into a flipped, resized, normalized batch.

     cap: (bs, max_seq_len), cap_len: (bs,), img: (bs, 64, 64, 3)
     Captions and lengths are returned unchanged.
     """
     chw = random_hflip(img.permute(0, 3, 1, 2).float())  # (bs, 3, 64, 64)
     # NOTE(review): 229 may be a typo for 299, the usual Inception-v3 input
     # size; confirm against the consuming encoder before changing it.
     resized = resize(chw, size=(229, 229))  # (bs, 3, 229, 229)
     return cap, cap_len, self.normalizer.encode(resized)
Example #3
0
 def after_batch_tfm(self, cap, cap_len, img):
     """Upscale, randomly crop, randomly flip and normalize an image batch.

     cap: (bs, max_seq_len), cap_len: (bs,), img: (bs, 256, 256, 3)
     Captions and lengths are returned unchanged.
     """
     upscaled_side = int(256 * 76 / 64)  # enlarge by 76/64 before cropping -> 304
     batch = img.permute(0, 3, 1, 2).float()  # (bs, 3, 256, 256)
     batch = resize(batch, size=(upscaled_side, upscaled_side))  # (bs, 3, 304, 304)
     # The crop size comes from random_crop's own defaults; the original
     # comment gives it as (bs, 3, 229, 229).
     batch = random_hflip(random_crop(batch))
     return cap, cap_len, self.normalizer.encode(batch)
Example #4
0
 def __getitem__(self, idx):
     """
     Return sample ``idx`` resized to 299x299 and normalized for Inception_v3

     The mean/std constants are the widely used ImageNet channel statistics.
     :param idx: ``int``, index
     :return: Tensor([C, H, W])
     """
     channel_mean = torch.tensor([0.485, 0.456, 0.406])
     channel_std = torch.tensor([0.229, 0.224, 0.225])
     sample = resize(self.data[idx], size=(299, 299))
     sample = normalize(sample, mean=channel_mean, std=channel_std)
     # Drop the leading singleton dimension, if present.
     return sample.squeeze(0)
Example #5
0
 def apply_transform(self,
                     input: Tensor,
                     params: Dict[str, Tensor],
                     transform: Optional[Tensor] = None) -> Tensor:
     """Crop every sample to its ``src`` box and resample it to ``output_size``.

     Box corners are read as (x, y) pairs from ``params["src"][i]``:
     - corner 0 gives the left and top edges;
     - corner 1 gives the right edge;
     - corner 3 gives the bottom edge.
     The right/bottom coordinates are treated as inclusive, hence the ``+ 1``.
     Every sample is resampled to ``params["output_size"][0]``.

     :param input: 4-D ``Tensor``; dims 2 and 3 are sliced by the y and x box coordinates
     :param params: ``Dict[str, Tensor]`` providing ``"src"`` and ``"output_size"``
     :param transform: not used by this implementation
     :return: ``Tensor`` of shape ``(B, C, *output_size)`` on ``input``'s device/dtype
     """
     n_samples, n_channels, _, _ = input.shape
     target_size = tuple(params["output_size"][0].tolist())
     result = input.new_empty((n_samples, n_channels, *target_size))
     for k in range(n_samples):
         box = params["src"][k]
         left, top = int(box[0, 0]), int(box[0, 1])
         right, bottom = int(box[1, 0]) + 1, int(box[3, 1]) + 1
         patch = input[k:k + 1, :, top:bottom, left:right]
         result[k] = resize(
             patch,
             target_size,
             interpolation=self.flags["resample"].name.lower(),
             align_corners=self.flags["align_corners"],
         )
     return result
Example #6
0
    def after_batch_tfm(self, cap, cap_len, img):
        """Augment a 256 px batch and build a normalized pyramid down to 4 px.

        cap: (bs, 25), img: (bs, 256, 256, 3); cap_len is passed through.
        The processing steps are:
        - upscale the images by 76/64;
        - apply random_crop, then random_hflip;
        - downscale to 128/64/32/16/8/4 px;
        - encode every level with ``self.normalizer``.
        Returns ``(cap, cap_len, pyramid)``, with ``pyramid`` ordered from
        the largest level to 4 px.
        """
        upscaled_side = int(256 * 76 / 64)
        full = img.permute(0, 3, 1, 2).float()  # (bs, 3, 256, 256)
        full = resize(full, size=(upscaled_side, upscaled_side))  # (bs, 3, 304, 304)
        full = random_hflip(random_crop(full))  # (bs, 3, 256, 256)

        # Smaller levels derive from the augmented but un-normalized image.
        sides = (4, 8, 16, 32, 64, 128)
        smaller = [resize(full, size=(side, side)) for side in sides]
        # Encode smallest-first, then the full-size image last.
        encoded = [self.normalizer.encode(level) for level in (*smaller, full)]
        return cap, cap_len, tuple(reversed(encoded))
Example #7
0
def _apply_black_border(x, border_size):
    """
    Shrink ``x`` and surround it with a zero-valued border of ``border_size`` pixels

    The output keeps the spatial size of ``x``: the image is resized to
    ``(H - 2 * border_size, W - 2 * border_size)`` and then padded with
    zeros by ``border_size`` on every side.
    :param x: ``Tensor``, 4-D batch; dim 2 is read as height, dim 3 as width
    :param border_size: ``int``, border thickness in pixels
    :return: ``Tensor`` with the same height and width as ``x``
    :raises ValueError: if the border leaves no room for the image
    """
    orig_height = x.shape[2]
    orig_width = x.shape[3]
    if 2 * border_size >= min(orig_height, orig_width):
        raise ValueError(
            f"border_size={border_size} is too large for a "
            f"{orig_height}x{orig_width} image")
    # ``resize`` expects ``(height, width)``. The previous (width, height)
    # order transposed non-square images, and the padded output then had
    # the wrong shape.
    x = resize(x,
               (orig_height - 2 * border_size, orig_width - 2 * border_size))
    return nn.ConstantPad2d(border_size, 0)(x)
Example #8
0
def train_one_epoch(gen, opt_gen, scaler_gen, dis, opt_dis, scaler_dis,
                    dataloader, metric_logger, device, fixed_noise, epoch,
                    fid_model, fid_score):
    """
    Train one epoch with a hinge adversarial loss, DiffAugment and mixed precision

    Each batch runs one discriminator step and then one generator step:
    - discriminator loss: hinge divergence plus two MSE reconstruction terms;
    - generator loss: the negative mean of the discriminator logits on
      augmented fakes.
    Both steps run under ``torch.cuda.amp.autocast`` with their own gradient
    scaler. When enabled, FID is computed once, after the last batch of the
    epoch.
    :param gen: ``Instance of models.model.Generator``, generator model
    :param opt_gen: ``Instance of torch.optim``, optimizer for generator
    :param scaler_gen: ``torch.cuda.amp.GradScaler()``, gradient scaler for generator
    :param dis: ``Instance of models.model.Discriminator``, discriminator model
    :param opt_dis: ``Instance of torch.optim``, optimizer for discriminator
    :param scaler_dis: ``torch.cuda.amp.GradScaler()``, gradient scaler for discriminator
    :param dataloader: ``Instance of torch.utils.data.DataLoader``, train dataloader
    :param metric_logger: ``Instance of metrics.MetricLogger``, logger
    :param device: ``Instance of torch.device``, cuda device
    :param fixed_noise: ``Tensor([N, 1, Z, 1, 1])``, fixed noise for display result
    :param epoch: ``int``, current epoch
    :param fid_model: model for calculate fid score
    :param fid_score: ``bool``, if True - calculate fid score otherwise no
    :return: None
    """
    loop = tqdm(dataloader, leave=True)
    for batch_idx, real in enumerate(loop):
        cur_batch_size = real.shape[0]
        real = real.to(device)
        # NOTE(review): nothing here differentiates w.r.t. ``real`` (there is
        # no gradient penalty), so this flag looks unnecessary; confirm the
        # loss helpers don't need it before removing.
        real.requires_grad_()
        # Reconstruction targets for the discriminator's decoded outputs:
        # a 128x128 center crop and the whole image downscaled to 128x128.
        real_cropped_128 = center_crop(real, size=(128, 128))
        real_128 = resize(real, size=(128, 128))
        noise = torch.randn(cur_batch_size, cfg.Z_DIMENSION, 1, 1).to(device)
        # Train discriminator
        with torch.cuda.amp.autocast():
            with torch.no_grad():
                fake = gen(noise)
            real_fake_logits_real_images, decoded_real_img_part, decoded_real_img = dis(
                DiffAugment(real, policy=cfg.DIFF_AUGMENT_POLICY))
            real_fake_logits_fake_images, _, _ = dis(
                DiffAugment(fake.detach(), policy=cfg.DIFF_AUGMENT_POLICY))
            # maximize divergence between real and fake data
            divergence = hinge_adv_loss(real_fake_logits_real_images,
                                        real_fake_logits_fake_images, device)
            i_recon_loss = reconstruction_loss_mse(real_128, decoded_real_img)
            i_part_recon_loss = reconstruction_loss_mse(
                real_cropped_128, decoded_real_img_part)
            d_loss = divergence + i_recon_loss + i_part_recon_loss

        opt_dis.zero_grad()
        scaler_dis.scale(d_loss).backward()
        scaler_dis.step(opt_dis)
        scaler_dis.update()

        opt_gen.zero_grad()
        noise = torch.randn(cur_batch_size, cfg.Z_DIMENSION, 1, 1).to(device)
        # Train generator
        with torch.cuda.amp.autocast():
            # maximize E[D(G(z))], also we can minimize the negative of that
            fake = gen(noise)
            fake_logits, _, _ = dis(
                DiffAugment(fake, policy=cfg.DIFF_AUGMENT_POLICY))
            g_loss = -torch.mean(fake_logits)

        scaler_gen.scale(g_loss).backward()
        scaler_gen.step(opt_gen)
        scaler_gen.update()
        # Metrics
        if batch_idx % cfg.LOG_FREQ == 0:
            metric_logger.log(g_loss, d_loss, divergence, i_recon_loss,
                              i_part_recon_loss)
        if batch_idx % cfg.LOG_IMAGE_FREQ == 0:
            with torch.no_grad():
                fixed_fakes = gen(fixed_noise)
                metric_logger.log_image(fixed_fakes,
                                        cfg.NUM_SAMPLES_IMAGES,
                                        epoch,
                                        batch_idx,
                                        normalize=True)

        loop.set_postfix(d_loss=d_loss.item(), g_loss=g_loss.item())

    # Eval: the FID condition depends only on ``epoch``. Inside the batch loop
    # it re-ran the full FID evaluation after every batch, so run it once per
    # epoch instead.
    if fid_score and epoch % cfg.FID_FREQ == 0:
        # TODO: test fid on GPU
        fid = evaluate(gen, fid_model, device)
        gen.train()
        metric_logger.log_fid(fid)
Example #9
0
def train_one_epoch_with_gp(gen, opt_gen, crt, opt_crt, dataloader,
                            metric_logger, device, fixed_noise, epoch,
                            fid_model, fid_score):
    """
    Train one epoch with several critic updates per generator update, using a gradient penalty

    The critic loss has three parts:
    - hinge divergence;
    - two MSE reconstruction terms;
    - ``cfg.LAMBDA_GP`` times ``gradient_penalty``.
    The critic takes ``cfg.CRITIC_ITERATIONS`` steps, then the generator takes
    one step on the negative mean critic logit. DiffAugment is applied to
    every image the critic sees, in both phases. When enabled, FID is computed
    once, after the last batch of the epoch.
    :param gen: ``Instance of models.model.Generator``, generator model
    :param opt_gen: ``Instance of torch.optim``, optimizer for generator
    :param crt: ``Instance of models.model.Discriminator``, discriminator (critic) model
    :param opt_crt: ``Instance of torch.optim``, optimizer for the critic
    :param dataloader: ``Instance of torch.utils.data.DataLoader``, train dataloader
    :param metric_logger: ``Instance of metrics.MetricLogger``, logger
    :param device: ``Instance of torch.device``, cuda device
    :param fixed_noise: ``Tensor([N, 1, Z, 1, 1])``, fixed noise for display result
    :param epoch: ``int``, current epoch
    :param fid_model: model for calculate fid score
    :param fid_score: ``bool``, if True - calculate fid score otherwise no
    :return: None
    """
    loop = tqdm(dataloader, leave=True)
    for batch_idx, real in enumerate(loop):
        cur_batch_size = real.shape[0]
        real = real.to(device)
        # Kept because gradient_penalty may differentiate w.r.t. ``real``.
        real.requires_grad_()
        # Reconstruction targets for the critic's decoded outputs.
        real_cropped_128 = center_crop(real, size=(128, 128))
        real_128 = resize(real, size=(128, 128))
        # Train critic
        for _ in range(cfg.CRITIC_ITERATIONS):
            noise = torch.randn(cur_batch_size, cfg.Z_DIMENSION, 1,
                                1).to(device)
            fake = gen(noise)
            real_fake_logits_real_images, decoded_real_img_part, decoded_real_img = crt(
                DiffAugment(real, policy=cfg.DIFF_AUGMENT_POLICY))
            real_fake_logits_fake_images, _, _ = crt(
                DiffAugment(fake, policy=cfg.DIFF_AUGMENT_POLICY))

            divergence = hinge_adv_loss(real_fake_logits_real_images,
                                        real_fake_logits_fake_images, device)
            i_recon_loss = reconstruction_loss_mse(real_128, decoded_real_img)
            i_part_recon_loss = reconstruction_loss_mse(
                real_cropped_128, decoded_real_img_part)
            gp = gradient_penalty(crt, real, fake, device)
            crt_loss = (divergence + i_recon_loss +
                        i_part_recon_loss) + cfg.LAMBDA_GP * gp
            crt.zero_grad()
            # retain_graph keeps the gen -> fake graph alive so the generator
            # step below can backpropagate through the last ``fake``.
            crt_loss.backward(retain_graph=True)
            opt_crt.step()

        # Train generator on the last critic iteration's ``fake``. Use the same
        # DiffAugment the critic was trained with (as train_one_epoch does);
        # otherwise G is scored on un-augmented images the critic never saw.
        fake_logits, _, _ = crt(
            DiffAugment(fake, policy=cfg.DIFF_AUGMENT_POLICY))
        g_loss = -torch.mean(fake_logits)
        # ``fake`` is not detached in the critic step, so clear the gradients
        # that the critic backward pushed into the generator.
        gen.zero_grad()
        g_loss.backward()
        opt_gen.step()

        # Metrics
        if batch_idx % cfg.LOG_FREQ == 0:
            metric_logger.log(g_loss, crt_loss, divergence, i_recon_loss,
                              i_part_recon_loss)
        if batch_idx % cfg.LOG_IMAGE_FREQ == 0:
            with torch.no_grad():
                fixed_fakes = gen(fixed_noise)
                metric_logger.log_image(fixed_fakes,
                                        cfg.NUM_SAMPLES_IMAGES,
                                        epoch,
                                        batch_idx,
                                        normalize=True)

        loop.set_postfix(d_loss=crt_loss.item(), g_loss=g_loss.item())

    # Eval: the FID condition depends only on ``epoch``. Inside the batch loop
    # it re-ran the full FID evaluation after every batch, so run it once per
    # epoch instead.
    if fid_score and epoch % cfg.FID_FREQ == 0:
        # TODO: test fid on GPU
        fid = evaluate(gen, fid_model, device)
        gen.train()
        metric_logger.log_fid(fid)