def eval(self):
        """Evaluate the MPN/RIN pipeline on ``self.image_loader``.

        For every batch, a mask from ``self.mask_generator`` is smoothed with
        ``self.mask_smoother`` and used to blend a contaminant batch (taken
        from ``self.cont_image_loader``) into the linearly scaled inputs.
        The MPN predicts the contaminated region, the RIN restores the image,
        and per-batch BCE (mask prediction), SSIM and PSNR are logged.

        The dataset-level means are appended to
        ``<TEST.OUTPUT_DIR>/metrics.json`` as one JSON object per line.  A
        metric is written as ``null`` when no batch was evaluated.
        """
        psnr_lst, ssim_lst, bce_lst = list(), list(), list()
        with torch.no_grad():
            for batch_idx, (imgs, _) in enumerate(self.image_loader):
                imgs = linear_scaling(imgs.float().cuda())
                batch_size, channels, h, w = imgs.size()

                # A single generated mask is repeated for every image in the batch.
                masks = torch.from_numpy(self.mask_generator.generate(
                    h, w)).repeat([batch_size, 1, 1, 1]).float().cuda()
                smooth_masks = self.mask_smoother(1 - masks) + masks
                smooth_masks = torch.clamp(smooth_masks, min=0., max=1.)

                # NOTE(review): a fresh iterator is created for every batch, so
                # this always takes the first batch of a new pass over the
                # contaminant loader.
                cont_imgs, _ = next(iter(self.cont_image_loader))
                cont_imgs = linear_scaling(cont_imgs.float().cuda())
                if cont_imgs.size(0) != imgs.size(0):
                    cont_imgs = cont_imgs[:imgs.size(0)]

                masked_imgs = cont_imgs * smooth_masks + imgs * (1. -
                                                                 smooth_masks)
                pred_masks, neck = self.mpn(masked_imgs)
                masked_imgs_embraced = masked_imgs * (1. - pred_masks)
                output = self.rin(masked_imgs_embraced, pred_masks, neck)
                output = torch.clamp(output, max=1., min=0.)

                # Mean fraction of masked pixels; passed to self.BCE together
                # with its complement (presumably as class weights).
                unknown_pixel_ratio = torch.sum(masks.view(batch_size, -1),
                                                dim=1).mean() / (h * w)
                bce = self.BCE(
                    torch.sigmoid(pred_masks), masks,
                    torch.tensor(
                        [1 - unknown_pixel_ratio,
                         unknown_pixel_ratio])).item()
                bce_lst.append(bce)

                ssim = self.SSIM(255. * linear_unscaling(imgs),
                                 255. * output).item()
                ssim_lst.append(ssim)

                psnr = self.PSNR(linear_unscaling(imgs), output).item()
                psnr_lst.append(psnr)

                log.info("{}/{}\tBCE: {}\tSSIM: {}\tPSNR: {}".format(
                    batch_idx, len(self.image_loader), round(bce, 3),
                    round(ssim, 3), round(psnr, 3)))

        def _mean(values):
            # np.mean([]) warns and returns NaN, which json.dump would emit as
            # the non-standard token NaN; report "no data" as null instead.
            return float(np.mean(values)) if values else None

        results = {
            "Dataset": self.opt.DATASET.NAME,
            "PSNR": _mean(psnr_lst),
            "SSIM": _mean(ssim_lst),
            "BCE": _mean(bce_lst)
        }
        os.makedirs(self.opt.TEST.OUTPUT_DIR, exist_ok=True)
        with open(os.path.join(self.opt.TEST.OUTPUT_DIR, "metrics.json"),
                  "a") as f:
            json.dump(results, f)
            # Newline-terminate each record so repeated runs yield one JSON
            # object per line instead of back-to-back concatenated objects.
            f.write("\n")
Beispiel #2
0
    def run(self):
        """Train with ``train_D`` / ``train_G`` until ``TRAIN.NUM_TOTAL_STEP``.

        Each step draws one ``(imgs, y_imgs, labels)`` batch from
        ``self.image_loader`` (``imgs`` is linearly scaled). It then runs
        ``MODEL.D.NUM_CRITICS`` calls of ``self.train_D`` and one call of
        ``self.train_G``. Finally it logs, visualises to wandb and saves
        checkpoints at the configured ``TRAIN.*_INTERVAL`` steps.
        """
        while self.num_step < self.opt.TRAIN.NUM_TOTAL_STEP:
            self.num_step += 1
            info = " [Step: {}/{} ({}%)] ".format(self.num_step, self.opt.TRAIN.NUM_TOTAL_STEP, 100 * self.num_step / self.opt.TRAIN.NUM_TOTAL_STEP)

            # NOTE(review): a new iterator per step yields the first batch of a
            # fresh pass over the loader every time.
            imgs, y_imgs, labels = next(iter(self.image_loader))
            imgs = linear_scaling(imgs.float().cuda())
            y_imgs = y_imgs.float().cuda()
            labels = labels.cuda()

            # Bound up front so the log line below still works when
            # MODEL.D.NUM_CRITICS is 0 (previously an UnboundLocalError).
            d_loss = None
            for _ in range(self.opt.MODEL.D.NUM_CRITICS):
                d_loss = self.train_D(imgs, y_imgs)
            info += "D Loss: {} ".format(d_loss)

            g_loss, output = self.train_G(imgs, y_imgs, labels)
            info += "G Loss: {} ".format(g_loss)

            if self.num_step % self.opt.TRAIN.LOG_INTERVAL == 0:
                log.info(info)

            if self.num_step % self.opt.TRAIN.VISUALIZE_INTERVAL == 0:
                # WANDB.NUM_ROW is used as the batch index of the sample shown.
                idx = self.opt.WANDB.NUM_ROW
                self.wandb.log({"examples": [
                    self.wandb.Image(self.to_pil(y_imgs[idx].cpu()), caption="original_image"),
                    self.wandb.Image(self.to_pil(linear_unscaling(imgs[idx]).cpu()), caption="filtered_image"),
                    self.wandb.Image(self.to_pil(torch.clamp(output, min=0., max=1.)[idx].cpu()), caption="output")
                ]}, commit=False)
            # Empty log with the default commit flag flushes the commit=False entries.
            self.wandb.log({})
            # num_step is at least 1 here, so only the interval test is needed.
            if self.num_step % self.opt.TRAIN.SAVE_INTERVAL == 0:
                self.do_checkpoint(self.num_step)
def filter_removal(img):
    """Remove the filter from a single image using the module-level ``net``.

    ``img`` is assumed to be an H x W x C array: it is transposed to
    C x H x W, given a leading batch axis, divided by 255 and passed through
    ``linear_scaling``.  The scaled batch and its ``vgg16`` features are fed
    to ``net``; the first output is clamped to [0, 1] and returned as an
    H x W x C NumPy array.
    """
    chw = np.transpose(img, (2, 0, 1))
    batch = torch.tensor(chw[np.newaxis, ...]).float() / 255.
    batch = linear_scaling(batch)
    with torch.no_grad():
        restored, _ = net(batch, vgg16(batch))
        restored = restored.clamp(min=0., max=1.)
        return restored.squeeze(0).permute(1, 2, 0).numpy()
Beispiel #4
0
    def eval(self):
        """Evaluate the restoration network and its optional classifier head.

        Each loader item is a pair of sequences ``(imgs, y_imgs)``; their
        tensors are concatenated along dim 0.  Class targets are
        ``0 .. len(self.classes) - 1`` for every item.  When ``self.mlp`` is
        set, predicted classes are collected. The accuracy is then computed
        in percent, rounded to one decimal, and ``self.plot_cm`` draws a
        confusion matrix.

        PSNR/SSIM/LPIPS collection is currently commented out. These metrics
        are reported as ``None`` while their lists are empty (instead of
        NaN from ``np.mean([])``).  The results dict is logged via
        ``log.info``.
        """
        psnr_lst, ssim_lst, lpips_lst = list(), list(), list()
        with torch.no_grad():
            all_preds, all_targets = torch.tensor([]), torch.tensor([])
            for batch_idx, (imgs, y_imgs) in enumerate(self.image_loader):
                imgs = linear_scaling(torch.cat(imgs, dim=0).float().cuda())
                y_imgs = torch.cat(y_imgs, dim=0).float().cuda()
                # NOTE(review): assumes every item holds exactly one image per
                # class, ordered by class index — confirm against the dataset.
                y = torch.arange(0, len(self.classes)).cuda()
                all_targets = torch.cat((all_targets, y.float().cpu()), dim=0)

                vgg_feat = self.vgg16(imgs)
                output, aux = self.net(imgs, vgg_feat)
                output = torch.clamp(output, max=1., min=0.)
                y_pred = None
                if self.mlp is not None:
                    y_pred = torch.argmax(self.mlp(aux), dim=-1)
                    all_preds = torch.cat((all_preds, y_pred.float().cpu()),
                                          dim=0)

                # ssim = self.SSIM(255. * y_imgs, 255. * output).item()
                # ssim_lst.append(ssim)
                #
                # psnr = self.PSNR(y_imgs, output).item()
                # psnr_lst.append(psnr)
                #
                # lpps = lpips(y_imgs, output, net_type='alex', version='0.1').item() / len(y_imgs)  # TODO ?? not sure working
                # lpips_lst.append(lpps)

                # batch_accuracy = round(torch.mean(torch.tensor(y == y_pred.clone().detach()).float()).item() * 100., 2)
                # log.info("{}/{}\tLPIPS: {}\tSSIM: {}\tPSNR: {}\tImage Accuracy: {}".format(batch_idx+1, len(self.image_loader), round(lpps, 3),
                #                                                                            round(ssim, 3), round(psnr, 3), batch_accuracy))

                # os.makedirs(os.path.join(self.output_dir, "images"), exist_ok=True)
                # for i, (y_img, img, out) in enumerate(zip(y_imgs.cpu(), linear_unscaling(imgs).cpu(), output.cpu())):
                #     self.to_pil(y_img).save(os.path.join(self.output_dir, "images", "{}_{}_real_A.png".format(batch_idx, i)))
                #     self.to_pil(img).save(os.path.join(self.output_dir, "images", "{}_{}_fake_B.png".format(batch_idx, i)))
                #     self.to_pil(out).save(os.path.join(self.output_dir, "images", "{}_{}_real_B.png".format(batch_idx, i)))

        if len(all_preds) > 0:
            # Scale before rounding: rounding the fraction and then multiplying
            # by 100 reintroduces float noise such as 85.10000000000001.
            correct = torch.sum(all_preds == all_targets).float()
            acc = round((correct / len(all_preds)).item() * 100., 1)
            self.plot_cm(all_targets, all_preds,
                         list(range(len(self.classes))))
            # self.plot_confusion_matrix(all_targets, all_preds, self.classes)
        else:
            acc = "None"

        def _mean(values):
            # np.mean([]) emits a RuntimeWarning and returns NaN.
            return float(np.mean(values)) if values else None

        results = {
            "Dataset": self.opt.DATASET.NAME,
            "PSNR": _mean(psnr_lst),
            "SSIM": _mean(ssim_lst),
            "LPIPS": _mean(lpips_lst),
            "Accuracy": acc
        }
        log.info(results)
 def paste_facade(self, x, c_img_id):
     """Paste a downscaled contaminant image at a random location of ``x``.

     The contaminant ``c_img_id`` is resized to a ``DATASET.SIZE // 8``
     square, linearly scaled, and written into a copy of ``x`` at a random
     top-left corner.

     Returns:
         ``(x_scaled, smooth_masks)``: the modified copy of ``x`` and a
         1 x 1 x SIZE x SIZE mask that is 1 over the pasted patch, smoothed
         with ``self.mask_smoother`` and clamped to [0, 1].
     """
     size = self.opt.DATASET.SIZE
     patch = size // 8
     # Resize needs an (h, w) pair to force a square patch.  A bare int only
     # matches the shorter edge, so a non-square contaminant would produce a
     # patch that no longer fits the destination slice below.
     resizer = transforms.Resize((patch, patch))
     facade, _ = self.cont_image_loader.dataset.__getitem__(c_img_id)
     facade = linear_scaling(self.tensorize(resizer(self.to_pil(facade))))
     # randint's upper bound is exclusive; +1 allows the patch to touch the
     # bottom/right border (valid offsets are 0 .. size - patch).
     coord_x, coord_y = np.random.randint(size - patch + 1, size=(2, ))
     rows = slice(coord_x, coord_x + patch)
     cols = slice(coord_y, coord_y + patch)
     # clone() copies the tensor like deepcopy did, but also works on
     # non-leaf tensors that are part of an autograd graph.
     x_scaled = x.clone()
     x_scaled[:, :, rows, cols] = facade
     masks = torch.zeros((1, 1, size, size)).cuda()
     masks[:, :, rows, cols] = 1.
     smooth_masks = self.mask_smoother(1 - masks) + masks
     smooth_masks = torch.clamp(smooth_masks, min=0., max=1.)
     return x_scaled, smooth_masks
 def swap_faces(self, x, c_img_id):
     """Paste the centre half-size crop of a contaminant into the centre of ``x``.

     Returns:
         ``(x_scaled, smooth_masks)``. ``x_scaled`` is a copy of
         ``linear_unscaling(x)`` whose central ``DATASET.SIZE // 2`` square
         is replaced by the crop. ``smooth_masks`` is the smoothed,
         [0, 1]-clamped indicator of that square (shape 1 x 1 x SIZE x SIZE).

     NOTE(review): the contaminant is passed through ``linear_unscaling``
     before ``to_pil`` and through ``linear_scaling`` after tensorizing. By
     contrast, ``paste_facade`` converts the dataset tensor directly.
     Confirm which value range the contaminant dataset returns.
     """
     size = self.opt.DATASET.SIZE
     half = size // 2
     offset = (size - half) // 2
     window = slice(offset, offset + half)

     cropper = transforms.CenterCrop((half, half))
     contaminant, _ = self.cont_image_loader.dataset.__getitem__(c_img_id)
     pil_crop = cropper(self.to_pil(linear_unscaling(contaminant)))
     crop = linear_scaling(self.tensorize(pil_crop))

     x_scaled = copy.deepcopy(linear_unscaling(x))
     x_scaled[:, :, window, window] = crop

     masks = torch.zeros((1, 1, size, size)).cuda()
     masks[:, :, window, window] = torch.ones_like(crop[0])
     smooth_masks = torch.clamp(self.mask_smoother(1 - masks) + masks,
                                min=0.,
                                max=1.)
     return x_scaled, smooth_masks
    def do_ablation(self,
                    mode=None,
                    img_id=None,
                    c_img_id=None,
                    color=None,
                    output_dir=None):
        """Run one contamination ablation on a dataset image and save a strip.

        Args:
            mode: ablation mode in 1..8 (defaults to ``TEST.MODE``). It
                selects how the contaminant ``c_x`` is built (see the branch
                comments); modes 5-8 also replace the generated mask.
            img_id: index into ``self.image_loader.dataset`` (defaults to
                ``TEST.IMG_ID``).
            c_img_id: index into ``self.cont_image_loader.dataset`` (defaults
                to ``TEST.C_IMG_ID``); used by modes 1, 6 and 8.
            color: key into ``COLORS`` (defaults to ``TEST.BRUSH_COLOR``);
                used by modes 3 and 7 but validated for every mode.
            output_dir: destination directory; defaults to
                ``TEST.OUTPUT_DIR/<self.ablation_map[mode]>`` and is created
                if missing.

        Raises:
            AssertionError: when an argument is out of range (only while
                assertions are enabled).

        Saves ``output_{img_id}_{c_img_id}.png``. It concatenates along the
        width: the input, contaminant, smoothed mask, contaminated input,
        "embraced" input, predicted mask and clamped RIN output.
        """
        mode = self.opt.TEST.MODE if mode is None else mode
        assert mode in range(1, 9)
        img_id = self.opt.TEST.IMG_ID if img_id is None else img_id
        assert img_id < len(self.image_loader.dataset)
        c_img_id = self.opt.TEST.C_IMG_ID if c_img_id is None else c_img_id
        assert c_img_id < len(self.cont_image_loader.dataset)
        color = self.opt.TEST.BRUSH_COLOR if color is None else color
        assert str(color).upper() in list(COLORS.keys())
        output_dir = os.path.join(
            self.opt.TEST.OUTPUT_DIR,
            self.ablation_map[mode]) if output_dir is None else output_dir
        # output_dir = os.path.join(self.opt.TEST.OUTPUT_DIR, str(mode), "{}_{}".format(img_id, c_img_id)) if output_dir is None else output_dir
        os.makedirs(output_dir, exist_ok=True)

        # Single dataset item -> batch of one on the GPU, linearly scaled.
        x, _ = self.image_loader.dataset.__getitem__(img_id)
        x = linear_scaling(x.unsqueeze(0).cuda())
        batch_size, channels, h, w = x.size()
        with torch.no_grad():
            # One generated mask per batch element, then smoothed and clamped.
            # Modes 5-8 overwrite smooth_masks with their own mask.
            masks = torch.cat([
                torch.from_numpy(self.mask_generator.generate(h, w))
                for _ in range(batch_size)
            ],
                              dim=0).float().cuda()
            smooth_masks = self.mask_smoother(1 - masks) + masks
            smooth_masks = torch.clamp(smooth_masks, min=0., max=1.)

            if mode == 1:  # contaminant image
                c_x, _ = self.cont_image_loader.dataset.__getitem__(c_img_id)
                c_x = c_x.unsqueeze(0).cuda()
            elif mode == 2:  # random brush strokes with noise
                c_x = torch.rand_like(x)
            elif mode == 3:  # random brush strokes with different colors
                # COLORS entry -> 1 x C x 1 x 1 tensor broadcast over the image.
                brush = torch.tensor(list(
                    COLORS["{}".format(color).upper()])).unsqueeze(
                        0).unsqueeze(-1).unsqueeze(-1).cuda()
                c_x = torch.ones_like(x) * brush
            elif mode == 4:  # real occlusions
                # The image itself is the "contaminant": the blend below is a no-op.
                c_x = linear_unscaling(x)
            elif mode == 5:  # graffiti
                c_x, smooth_masks = self.put_graffiti()
            elif mode == 6:  # facades (i.e. resize whole c_img to 64x64, paste to a random location of img)
                # paste_facade returns a linearly scaled image, so undo that here
                # before the common linear_scaling below.
                c_x, smooth_masks = self.paste_facade(x, c_img_id)
                c_x = linear_unscaling(c_x)
            elif mode == 7:  # words (i.e. write text with particular font size and color)
                c_x, smooth_masks = self.put_text(x, color)
            else:  # face swap  (i.e. 64x64 center crop from c_img, paste to the center of img)
                c_x, smooth_masks = self.swap_faces(x, c_img_id)

            # Branches are expected to leave c_x unscaled; rescale to match x.
            # NOTE(review): modes 5 and 7 depend on put_graffiti/put_text
            # following the same convention — not visible here.
            c_x = linear_scaling(c_x)
            masked_imgs = c_x * smooth_masks + x * (1. - smooth_masks)

            pred_masks, neck = self.mpn(masked_imgs)
            # NOTE(review): predicted regions are filled with ones here, whereas
            # other inference paths only multiply by (1 - pred_masks) — confirm
            # which "embrace" variant is intended.
            masked_imgs_embraced = masked_imgs * (
                1. - pred_masks) + torch.ones_like(masked_imgs) * pred_masks
            output = self.rin(masked_imgs_embraced, pred_masks, neck)

            # Side-by-side strip (concatenated along width); single-channel
            # masks are repeated to 3 channels so they can be concatenated.
            vis_output = torch.cat([
                linear_unscaling(x).squeeze(0).cpu(),
                linear_unscaling(c_x).squeeze(0).cpu(),
                smooth_masks.squeeze(0).repeat(3, 1, 1).cpu(),
                linear_unscaling(masked_imgs).squeeze(0).cpu(),
                linear_unscaling(masked_imgs_embraced).squeeze(0).cpu(),
                pred_masks.squeeze(0).repeat(3, 1, 1).cpu(),
                torch.clamp(output.squeeze(0), max=1., min=0.).cpu()
            ],
                                   dim=-1)
            self.to_pil(vis_output).save(
                os.path.join(output_dir,
                             "output_{}_{}.png".format(img_id, c_img_id)))
    def infer(self,
              img_path,
              cont_path=None,
              mode=None,
              color=None,
              text=None,
              mask_path=None,
              gt_path=None,
              output_dir=None):
        """Contaminate one image file, run MPN + RIN, and save or show it.

        Args:
            img_path: image to process; resized to DATASET.SIZE squared.
            cont_path: optional contaminant image.  When given, ``mode`` must
                be 1, 5, 6, 7 or 8; when omitted, it must be 2, 3 or 4
                (asserted).
            mode: ablation mode; defaults to ``TEST.MODE``.
            color: ``COLORS`` key for modes 3 and 7; defaults to
                ``TEST.BRUSH_COLOR`` in those modes.  Also used in the output
                file name.
            text: text drawn in mode 7; defaults to ``TEST.TEXT``.
            mask_path: optional grayscale mask image; otherwise a mask is
                generated with ``self.mask_generator``.
            gt_path: optional ground-truth image, only shown when
                ``output_dir`` is None.
            output_dir: if given, a visualisation is saved there as
                ``out{mode}_{color}_{basename}``. Otherwise, intermediate
                images are opened with ``show()``.

        Raises:
            AssertionError: when ``mode`` does not fit the presence of
                ``cont_path`` (only while assertions are enabled).
        """
        mode = self.opt.TEST.MODE if mode is None else mode
        text = self.opt.TEST.TEXT if text is None else text
        size = self.opt.DATASET.SIZE

        with torch.no_grad():
            im = Image.open(img_path).convert("RGB")
            im = im.resize((size, size))
            im_t = linear_scaling(
                transforms.ToTensor()(im).unsqueeze(0).cuda())

            if gt_path is not None:
                gt = Image.open(gt_path).convert("RGB")
                gt = gt.resize((size, size))

            if mask_path is None:
                masks = torch.from_numpy(
                    self.mask_generator.generate(size, size)).float().cuda()
            else:
                masks = Image.open(mask_path).convert("L")
                masks = masks.resize((size, size))
                masks = self.tensorize(masks).unsqueeze(0).float().cuda()

            if cont_path is not None:
                assert mode in [1, 5, 6, 7, 8]
                c_im = Image.open(cont_path).convert("RGB")
                c_im = c_im.resize((size, size))
                if mode == 6:
                    # Facade: paste a SIZE // 8 thumbnail of the contaminant at
                    # a random location; the mask covers exactly that patch.
                    patch = size // 8
                    c_im = c_im.resize((patch, patch))
                    c_im_t = self.tensorize(c_im).unsqueeze(0).cuda()
                    r_c_im_t = torch.zeros((1, 3, size, size)).cuda()
                    masks = torch.zeros((1, 1, size, size)).cuda()
                    coord_x, coord_y = np.random.randint(size - patch,
                                                         size=(2, ))
                    rows = slice(coord_x, coord_x + c_im_t.size(2))
                    cols = slice(coord_y, coord_y + c_im_t.size(3))
                    r_c_im_t[:, :, rows, cols] = c_im_t
                    masks[:, :, rows, cols] = torch.ones_like(c_im_t[0, 0])
                    c_im_t = linear_scaling(r_c_im_t)
                elif mode == 7:
                    # Text: draw `text` centred on the contaminant and, at the
                    # same place, in white on a blank mask.
                    # Default the colour like mode 3 does; otherwise a missing
                    # colour looks up COLORS["NONE"] and raises KeyError.
                    color = self.opt.TEST.BRUSH_COLOR if color is None else color
                    mask = self.to_pil(torch.zeros((size, size)))
                    d = ImageDraw.Draw(c_im)
                    d_m = ImageDraw.Draw(mask)
                    font = ImageFont.truetype(self.opt.TEST.FONT,
                                              self.opt.TEST.FONT_SIZE)
                    if hasattr(d, "textbbox"):
                        # ImageDraw.textsize was removed in Pillow 10; the
                        # bbox right/bottom from the origin replaces it.
                        _, _, font_w, font_h = d.textbbox((0, 0), text,
                                                          font=font)
                    else:
                        font_w, font_h = d.textsize(text, font=font)
                    c_w = (size - font_w) // 2
                    c_h = (size - font_h) // 2
                    fill = tuple(
                        int(a * 255)
                        for a in COLORS["{}".format(color).upper()])
                    d.text((c_w, c_h), text, font=font, fill=fill)
                    d_m.text((c_w, c_h), text, font=font, fill=255)
                    masks = self.tensorize(mask).unsqueeze(0).float().cuda()
                    c_im_t = linear_scaling(
                        self.tensorize(c_im).unsqueeze(0).cuda())
                elif mode == 8:
                    # Face swap: centre crop of half the image size, pasted
                    # back at the centre.  The offset and paste size used to
                    # be hard-coded to 128, which only matched SIZE // 2 when
                    # DATASET.SIZE == 256.
                    half = size // 2
                    center_cropper = transforms.CenterCrop((half, half))
                    crop = self.tensorize(center_cropper(c_im))
                    coord_x = coord_y = (size - half) // 2
                    window_x = slice(coord_x, coord_x + half)
                    window_y = slice(coord_y, coord_y + half)
                    r_c_im_t = torch.zeros((1, 3, size, size)).cuda()
                    r_c_im_t[:, :, window_x, window_y] = crop
                    if mask_path is None:
                        # Shrink the generated mask into the crop window and
                        # mirror it horizontally so both halves are covered.
                        tmp = kornia.resize(masks, half)
                        masks = torch.zeros((1, 1, size, size)).cuda()
                        masks[:, :, window_x, window_y] = tmp
                        tmp = kornia.hflip(tmp)
                        masks[:, :, window_x, window_y] += tmp
                        masks = torch.clamp(masks, min=0., max=1.)
                    c_im_t = linear_scaling(r_c_im_t)
                else:
                    c_im_t = linear_scaling(
                        transforms.ToTensor()(c_im).unsqueeze(0).cuda())
            else:
                assert mode in [2, 3, 4]
                if mode == 2:
                    c_im_t = linear_scaling(torch.rand_like(im_t))
                elif mode == 3:
                    color = self.opt.TEST.BRUSH_COLOR if color is None else color
                    brush = torch.tensor(
                        list(COLORS["{}".format(color).upper()])).unsqueeze(
                            0).unsqueeze(-1).unsqueeze(-1).cuda()
                    c_im_t = linear_scaling(torch.ones_like(im_t) * brush)
                elif mode == 4:
                    c_im_t = im_t

            # Generated masks (and mode 5) are smoothed; user-supplied masks and
            # mode 8's hand-built mask are used as-is.
            if (mask_path is None or mode == 5) and mode != 8:
                smooth_masks = self.mask_smoother(1 - masks) + masks
                smooth_masks = torch.clamp(smooth_masks, min=0., max=1.)
            else:
                smooth_masks = masks

            masked_imgs = c_im_t * smooth_masks + im_t * (1. - smooth_masks)
            pred_masks, neck = self.mpn(masked_imgs)
            if mode == 8:
                # Restrict predictions to the known swap region.
                pred_masks = torch.clamp(pred_masks * smooth_masks,
                                         min=0.,
                                         max=1.)
            masked_imgs_embraced = masked_imgs * (1. - pred_masks)
            output = self.rin(masked_imgs_embraced, pred_masks, neck)
            output = torch.clamp(output, max=1., min=0.)

            if output_dir is not None:
                # output_dir = os.path.join(output_dir, self.ablation_map[mode])
                # os.makedirs(output_dir, exist_ok=True)
                out_path = os.path.join(
                    output_dir,
                    "out{}_{}_{}".format(mode, color,
                                         os.path.basename(img_path)))
                if mode == 8:
                    # input | contaminant | contaminated | output, side by side.
                    panels = [
                        linear_unscaling(im_t).squeeze().cpu(),
                        self.tensorize(c_im).squeeze().cpu(),
                        linear_unscaling(masked_imgs).squeeze().cpu(),
                        output.squeeze().cpu()
                    ]
                    cat_dim = 2
                else:
                    # contaminated input stacked above the output.
                    panels = [
                        linear_unscaling(masked_imgs).squeeze().cpu(),
                        output.squeeze().cpu()
                    ]
                    cat_dim = 1
                self.to_pil(torch.cat(panels, dim=cat_dim)).save(out_path)
            else:
                self.to_pil(output.squeeze().cpu()).show()
                self.to_pil(pred_masks.squeeze().cpu()).show()
                self.to_pil(
                    linear_unscaling(masked_imgs).squeeze().cpu()).show()
                self.to_pil(smooth_masks.squeeze().cpu()).show()
                self.to_pil(linear_unscaling(im_t).squeeze().cpu()).show()
                if gt_path is not None:
                    gt.show()
    def run(self):
        """Jointly train MPN/RIN against a gradient-penalty critic.

        Each step draws ``(imgs, y_imgs)`` from ``self.image_loader``:
        ``imgs`` (linearly scaled) is the generator input and ``y_imgs`` the
        target.  The critic is updated ``MODEL.D.NUM_CRITICS`` times with
        ``-real + fake + OPTIM.GP * gp``. Then the RIN is updated with
        weighted reconstruction, semantic, texture and adversarial losses.
        MPN outputs are detached in that step, so this loss sends no
        gradient into the MPN.  Metrics go to wandb; logging, visualisation
        and checkpointing follow the ``MODEL.RAINDROP_*`` intervals.
        """
        while self.num_step < self.opt.TRAIN.NUM_TOTAL_STEP:
            self.num_step += 1
            info = " [Step: {}/{} ({}%)] ".format(
                self.num_step, self.opt.TRAIN.NUM_TOTAL_STEP,
                100 * self.num_step / self.opt.TRAIN.NUM_TOTAL_STEP)

            imgs, y_imgs = next(iter(self.image_loader))
            imgs = linear_scaling(imgs.float().cuda())
            y_imgs = y_imgs.float().cuda()

            # Stays None (instead of being unbound) if NUM_CRITICS is 0.
            d_loss = None
            for _ in range(self.opt.MODEL.D.NUM_CRITICS):
                self.optimizer_discriminator.zero_grad()

                # The critic only consumes generator outputs as fixed inputs,
                # so build them without an autograd graph.  Feed the RIN the
                # same (optionally embraced) input as the generator step below,
                # so the critic judges the images the generator produces.
                with torch.no_grad():
                    pred_masks, neck = self.mpn(imgs)
                    if self.opt.MODEL.RIN.EMBRACE:
                        output = self.rin(imgs * (1 - pred_masks), pred_masks,
                                          neck)
                    else:
                        output = self.rin(imgs, pred_masks, neck)

                real_validity = self.discriminator(y_imgs).mean()
                fake_validity = self.discriminator(output).mean()
                gp = compute_gradient_penalty(self.discriminator, output.data,
                                              y_imgs.data)

                d_loss = -real_validity + fake_validity + self.opt.OPTIM.GP * gp
                d_loss.backward()
                self.optimizer_discriminator.step()

                self.wandb.log(
                    {
                        "real_validity": -real_validity.item(),
                        "fake_validity": fake_validity.item(),
                        "gp": gp.item()
                    },
                    commit=False)

            self.optimizer_joint.zero_grad()
            pred_masks, neck = self.mpn(imgs)
            if self.opt.MODEL.RIN.EMBRACE:
                x_embraced = imgs.detach() * (1 - pred_masks.detach())
                output = self.rin(x_embraced, pred_masks.detach(),
                                  neck.detach())
            else:
                output = self.rin(imgs, pred_masks.detach(), neck.detach())
            recon_loss = self.reconstruction_loss(output, y_imgs)
            sem_const_loss = self.semantic_consistency_loss(output, y_imgs)
            tex_const_loss = self.texture_consistency_loss(output, y_imgs)
            adv_loss = -self.discriminator(output).mean()

            g_loss = self.opt.OPTIM.RECON * recon_loss + \
                     self.opt.OPTIM.SEMANTIC * sem_const_loss + \
                     self.opt.OPTIM.TEXTURE * tex_const_loss + \
                     self.opt.OPTIM.ADVERSARIAL * adv_loss

            g_loss.backward()
            self.optimizer_joint.step()
            self.wandb.log(
                {
                    "recon_loss": recon_loss.item(),
                    "sem_const_loss": sem_const_loss.item(),
                    "tex_const_loss": tex_const_loss.item(),
                    "adv_loss": adv_loss.item()
                },
                commit=False)

            # Log plain numbers rather than full tensor reprs (device, grad_fn).
            info += "D Loss: {} ".format(
                None if d_loss is None else d_loss.item())
            info += "G Loss: {} ".format(g_loss.item())

            if self.num_step % self.opt.MODEL.RAINDROP_LOG_INTERVAL == 0:
                log.info(info)

            if self.num_step % self.opt.MODEL.RAINDROP_VISUALIZE_INTERVAL == 0:
                # WANDB.NUM_ROW is used as the batch index of the sample shown.
                idx = self.opt.WANDB.NUM_ROW
                self.wandb.log(
                    {
                        "examples": [
                            self.wandb.Image(self.to_pil(y_imgs[idx].cpu()),
                                             caption="original_image"),
                            self.wandb.Image(self.to_pil(
                                linear_unscaling(imgs[idx]).cpu()),
                                             caption="masked_image"),
                            self.wandb.Image(self.to_pil(
                                pred_masks[idx].cpu()),
                                             caption="predicted_masks"),
                            self.wandb.Image(self.to_pil(
                                torch.clamp(output, min=0.,
                                            max=1.)[idx].cpu()),
                                             caption="output")
                        ]
                    },
                    commit=False)
            # Empty log with the default commit flag flushes the commit=False entries.
            self.wandb.log({})
            # num_step is at least 1 here, so only the interval test is needed.
            if self.num_step % self.opt.MODEL.RAINDROP_SAVE_INTERVAL == 0:
                self.do_checkpoint(self.num_step)
    def run(self):
        """Train MPN/RIN on synthetically contaminated images.

        Each step takes a clean batch (``y_imgs`` in the loader's range,
        ``imgs`` linearly scaled) and builds one generated mask repeated over
        the batch. It smooths that mask and blends a contaminant batch into
        the clean images inside it.  The mean masked-pixel ratio is stored in
        ``self.unknown_pixel_ratio`` for the training helpers.
        ``self.train_D`` runs ``MODEL.D.NUM_CRITICS`` times, then
        ``self.train_G`` once. Logging, wandb visualisation and checkpointing
        follow the ``TRAIN.*_INTERVAL`` settings.
        """
        while self.num_step < self.opt.TRAIN.NUM_TOTAL_STEP:
            self.num_step += 1
            info = " [Step: {}/{} ({}%)] ".format(
                self.num_step, self.opt.TRAIN.NUM_TOTAL_STEP,
                100 * self.num_step / self.opt.TRAIN.NUM_TOTAL_STEP)

            # NOTE(review): a new iterator per step yields the first batch of a
            # fresh pass over each loader every time.
            imgs, _ = next(iter(self.image_loader))
            y_imgs = imgs.float().cuda()
            imgs = linear_scaling(imgs.float().cuda())
            batch_size, channels, h, w = imgs.size()

            masks = torch.from_numpy(self.mask_generator.generate(
                h, w)).repeat([batch_size, 1, 1, 1]).float().cuda()

            cont_imgs, _ = next(iter(self.cont_image_loader))
            cont_imgs = linear_scaling(cont_imgs.float().cuda())
            if cont_imgs.size(0) != imgs.size(0):
                cont_imgs = cont_imgs[:imgs.size(0)]

            smooth_masks = self.mask_smoother(1 - masks) + masks
            smooth_masks = torch.clamp(smooth_masks, min=0., max=1.)

            masked_imgs = cont_imgs * smooth_masks + imgs * (1. - smooth_masks)
            self.unknown_pixel_ratio = torch.sum(masks.view(batch_size, -1),
                                                 dim=1).mean() / (h * w)

            # Bound up front so the log line below still works when
            # MODEL.D.NUM_CRITICS is 0 (previously an UnboundLocalError).
            d_loss = None
            for _ in range(self.opt.MODEL.D.NUM_CRITICS):
                d_loss = self.train_D(masked_imgs, masks, y_imgs)
            info += "D Loss: {} ".format(d_loss)

            m_loss, g_loss, pred_masks, output = self.train_G(
                masked_imgs, masks, y_imgs)
            info += "M Loss: {} G Loss: {} ".format(m_loss, g_loss)

            if self.num_step % self.opt.TRAIN.LOG_INTERVAL == 0:
                log.info(info)

            if self.num_step % self.opt.TRAIN.VISUALIZE_INTERVAL == 0:
                # WANDB.NUM_ROW is used as the batch index of the sample shown.
                idx = self.opt.WANDB.NUM_ROW
                self.wandb.log(
                    {
                        "examples": [
                            self.wandb.Image(self.to_pil(y_imgs[idx].cpu()),
                                             caption="original_image"),
                            self.wandb.Image(self.to_pil(
                                linear_unscaling(cont_imgs[idx]).cpu()),
                                             caption="contaminant_image"),
                            self.wandb.Image(self.to_pil(
                                linear_unscaling(masked_imgs[idx]).cpu()),
                                             caption="masked_image"),
                            self.wandb.Image(self.to_pil(masks[idx].cpu()),
                                             caption="original_masks"),
                            self.wandb.Image(self.to_pil(
                                smooth_masks[idx].cpu()),
                                             caption="smoothed_masks"),
                            self.wandb.Image(self.to_pil(
                                pred_masks[idx].cpu()),
                                             caption="predicted_masks"),
                            self.wandb.Image(self.to_pil(
                                torch.clamp(output, min=0.,
                                            max=1.)[idx].cpu()),
                                             caption="output")
                        ]
                    },
                    commit=False)
            # Empty log with the default commit flag flushes the commit=False entries.
            self.wandb.log({})
            # num_step is at least 1 here, so only the interval test is needed.
            if self.num_step % self.opt.TRAIN.SAVE_INTERVAL == 0:
                self.do_checkpoint(self.num_step)