Code Example #1
def train(netG,
          netD,
          GANLoss,
          ReconLoss,
          DLoss,
          optG,
          optD,
          dataloader,
          epoch,
          device=device,
          val_datas=None):
    """
    Train Phase, for training and spectral normalization patch gan in
    Free-Form Image Inpainting with Gated Convolution (snpgan)

    """
    netG.to(device)
    netD.to(device)
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = {
        "g_loss": AverageMeter(),
        "r_loss": AverageMeter(),
        "whole_loss": AverageMeter(),
        'd_loss': AverageMeter()
    }

    netG.train()
    netD.train()
    end = time.time()
    for i, (imgs, masks) in enumerate(dataloader):
        data_time.update(time.time() - end)
        #masks = masks['random_free_form']

        # Optimize Discriminator
        optD.zero_grad(), netD.zero_grad(), netG.zero_grad(), optG.zero_grad()

        imgs, masks = imgs.to(device), masks.to(device)
        imgs = imgs / 250  # NOTE: maps to roughly [0, 1]; the other variants below use imgs / 127.5 - 1, so 255 may have been intended
        # mask is 1 on masked region

        coarse_imgs, recon_imgs = netG(imgs, masks)
        #print(attention.size(), )
        complete_imgs = recon_imgs * masks + imgs * (1 - masks)

        pos_imgs = torch.cat([imgs, masks, torch.full_like(masks, 1.)], dim=1)
        neg_imgs = torch.cat(
            [complete_imgs, masks,
             torch.full_like(masks, 1.)], dim=1)
        pos_neg_imgs = torch.cat([pos_imgs, neg_imgs], dim=0)

        pred_pos_neg = netD(pos_neg_imgs)
        pred_pos, pred_neg = torch.chunk(pred_pos_neg, 2, dim=0)
        d_loss = DLoss(pred_pos, pred_neg)
        losses['d_loss'].update(d_loss.item(), imgs.size(0))
        d_loss.backward(retain_graph=True)

        optD.step()

        # Optimize Generator
        optD.zero_grad(), netD.zero_grad(), optG.zero_grad(), netG.zero_grad()
        pred_neg = netD(neg_imgs)
        #pred_pos, pred_neg = torch.chunk(pred_pos_neg,  2, dim=0)
        g_loss = GANLoss(pred_neg)
        r_loss = ReconLoss(imgs, coarse_imgs, recon_imgs, masks)

        whole_loss = g_loss + r_loss

        # Update the recorder for losses
        losses['g_loss'].update(g_loss.item(), imgs.size(0))
        losses['r_loss'].update(r_loss.item(), imgs.size(0))
        losses['whole_loss'].update(whole_loss.item(), imgs.size(0))
        whole_loss.backward()

        optG.step()

        # Update time recorder
        batch_time.update(time.time() - end)
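AverageMeter is used by every example here but never defined; a minimal sketch consistent with the .update(value, n) / .val / .avg usage above (the project's own helper may differ):

class AverageMeter:
    """Track the latest value and a weighted running average of a metric."""

    def __init__(self):
        self.val = 0.0   # most recent value
        self.sum = 0.0   # weighted sum of all values seen
        self.count = 0   # total weight
        self.avg = 0.0   # running average

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count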
Code Example #2
def main():
    logger_init()
    dataset_type = config.DATASET
    batch_size = config.BATCH_SIZE

    # Dataset setting
    logger.info("Initialize the dataset...")
    val_dataset = InpaintDataset(config.DATA_FLIST[dataset_type][1], \
                                 {mask_type: config.DATA_FLIST[config.MASKDATASET][mask_type][1] for mask_type in
                                  ('val',)}, \
                                 resize_shape=tuple(config.IMG_SHAPES), random_bbox_shape=config.RANDOM_BBOX_SHAPE, \
                                 random_bbox_margin=config.RANDOM_BBOX_MARGIN,
                                 random_ff_setting=config.RANDOM_FF_SETTING)
    val_loader = val_dataset.loader(batch_size=1, shuffle=False, num_workers=1)
    # print(len(val_loader))

    ### Generate a new val data

    logger.info("Finish the dataset initialization.")

    # Define the Network Structure
    logger.info("Define the Network Structure and Losses")
    whole_model_path = 'model_logs/{}'.format(config.MODEL_RESTORE)
    nets = torch.load(whole_model_path)
    netG_state_dict, netD_state_dict = nets['netG_state_dict'], nets[
        'netD_state_dict']
    if config.NETWORK_TYPE == "l2h_unet":
        netG = InpaintRUNNet(n_in_channel=config.N_CHANNEL)
        netG.load_state_dict(netG_state_dict)

    elif config.NETWORK_TYPE == 'sa_gated':
        netG = InpaintSANet()
        load_consistent_state_dict(netG_state_dict, netG)
        # netG.load_state_dict(netG_state_dict)

    netD = InpaintSADirciminator()
    netVGG = vgg16_bn(pretrained=True)

    # netD.load_state_dict(netD_state_dict)
    logger.info("Loading pretrained models from {} ...".format(
        config.MODEL_RESTORE))

    # Define loss
    recon_loss = ReconLoss(*(config.L1_LOSS_ALPHA))
    gan_loss = SNGenLoss(config.GAN_LOSS_ALPHA)
    perc_loss = PerceptualLoss(weight=config.PERC_LOSS_ALPHA,
                               feat_extractors=netVGG.to(cuda1))
    style_loss = StyleLoss(weight=config.STYLE_LOSS_ALPHA,
                           feat_extractors=netVGG.to(cuda1))
    dis_loss = SNDisLoss()
    lr, decay = config.LEARNING_RATE, config.WEIGHT_DECAY
    optG = torch.optim.Adam(netG.parameters(), lr=lr, weight_decay=decay)
    optD = torch.optim.Adam(netD.parameters(), lr=4 * lr, weight_decay=decay)
    nets = {"netG": netG, "netD": netD, "vgg": netVGG}

    losses = {
        "GANLoss": gan_loss,
        "ReconLoss": recon_loss,
        "StyleLoss": style_loss,
        "DLoss": dis_loss,
        "PercLoss": perc_loss
    }
    opts = {
        "optG": optG,
        "optD": optD,
    }
    logger.info("Finish Define the Network Structure and Losses")

    # Start Training
    logger.info("Start Validation")

    validate(nets,
             losses,
             opts,
             val_loader,
             0,
             config.NETWORK_TYPE,
             devices=(cuda0, cuda1))
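load_consistent_state_dict, used above in place of a strict load_state_dict, is not shown in these examples. A plausible sketch that copies only parameters whose names and shapes match, which is the usual purpose of such a helper (an assumption, not the project's actual code):

def load_consistent_state_dict(pretrained_state_dict, model):
    # keep only entries present in the target model with matching shapes
    model_state = model.state_dict()
    matched = {k: v for k, v in pretrained_state_dict.items()
               if k in model_state and v.size() == model_state[k].size()}
    model_state.update(matched)
    model.load_state_dict(model_state)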
Code Example #3
def validate(nets,
             loss_terms,
             opts,
             dataloader,
             epoch,
             network_type,
             devices=(cuda0, cuda1),
             batch_n="whole_test_show"):
    """
    validate phase
    """
    netD, netG = nets["netD"], nets["netG"]
    ReconLoss, DLoss, PercLoss, GANLoss, StyleLoss = loss_terms[
        'ReconLoss'], loss_terms['DLoss'], loss_terms["PercLoss"], loss_terms[
            "GANLoss"], loss_terms["StyleLoss"]
    optG, optD = opts['optG'], opts['optD']
    device0, device1 = devices
    netG.to(device0)
    netD.to(device0)
    netG.eval()
    netD.eval()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = {
        "g_loss": AverageMeter(),
        "p_loss": AverageMeter(),
        "s_loss": AverageMeter(),
        "r_loss": AverageMeter(),
        "whole_loss": AverageMeter(),
        "d_loss": AverageMeter()
    }

    # NOTE: these train() calls override the eval() calls above, so BatchNorm
    # and Dropout run in training mode during this validation pass
    netG.train()
    netD.train()
    end = time.time()
    val_save_dir = os.path.join(
        result_dir, "val_{}_{}".format(
            epoch, batch_n if isinstance(batch_n, str) else batch_n + 1))
    val_save_real_dir = os.path.join(val_save_dir, "real")
    val_save_gen_dir = os.path.join(val_save_dir, "gen")
    val_save_comp_dir = os.path.join(val_save_dir, "comp")
    for size in SIZES_TAGS:
        if not os.path.exists(os.path.join(val_save_real_dir, size)):
            os.makedirs(os.path.join(val_save_real_dir, size))
        if not os.path.exists(os.path.join(val_save_gen_dir, size)):
            os.makedirs(os.path.join(val_save_gen_dir, size))
        if not os.path.exists(os.path.join(val_save_comp_dir, size)):
            os.makedirs(os.path.join(val_save_comp_dir, size))
    info = {}
    t = 0
    for i, (ori_imgs, ori_masks) in enumerate(dataloader):
        data_time.update(time.time() - end)
        pre_imgs = ori_imgs
        pre_complete_imgs = (pre_imgs / 127.5 - 1)

        for s_i, size in enumerate(TRAIN_SIZES):

            masks = ori_masks['val']
            masks = F.interpolate(masks, size)
            masks = (masks > 0).type(torch.FloatTensor)
            imgs = F.interpolate(ori_imgs, size)
            if imgs.size(1) != 3:
                print(t, imgs.size())
            pre_inter_imgs = F.interpolate(pre_complete_imgs, size)

            imgs, masks, pre_complete_imgs, pre_inter_imgs = imgs.to(
                device0), masks.to(device0), pre_complete_imgs.to(
                    device0), pre_inter_imgs.to(device0)
            # masks = (masks > 0).type(torch.FloatTensor)

            # imgs, masks = imgs.to(device), masks.to(device)
            imgs = (imgs / 127.5 - 1)
            # mask is 1 on masked region
            # forward
            if network_type == 'l2h_unet':
                recon_imgs = netG(imgs, masks, pre_complete_imgs,
                                  pre_inter_imgs, size)
            elif network_type == 'l2h_gated':
                recon_imgs = netG(imgs, masks, pre_inter_imgs)
            elif network_type == 'sa_gated':
                recon_imgs, _ = netG(imgs, masks)
            complete_imgs = recon_imgs * masks + imgs * (1 - masks)

            pos_imgs = torch.cat(
                [imgs, masks, torch.full_like(masks, 1.)], dim=1)
            neg_imgs = torch.cat(
                [recon_imgs, masks,
                 torch.full_like(masks, 1.)], dim=1)
            pos_neg_imgs = torch.cat([pos_imgs, neg_imgs], dim=0)

            pred_pos_neg = netD(pos_neg_imgs)
            pred_pos, pred_neg = torch.chunk(pred_pos_neg, 2, dim=0)

            g_loss = GANLoss(pred_neg)

            # this variant has no separate coarse stage, so recon_imgs fills
            # both the coarse and refined slots of ReconLoss
            r_loss = ReconLoss(imgs, recon_imgs, recon_imgs, masks)

            imgs, recon_imgs, complete_imgs = imgs.to(device1), recon_imgs.to(
                device1), complete_imgs.to(device1)
            p_loss = PercLoss(imgs, recon_imgs) + PercLoss(imgs, complete_imgs)
            s_loss = StyleLoss(imgs, recon_imgs) + StyleLoss(
                imgs, complete_imgs)
            p_loss, s_loss = p_loss.to(device0), s_loss.to(device0)
            imgs, recon_imgs, complete_imgs = imgs.to(device0), recon_imgs.to(
                device0), complete_imgs.to(device0)

            whole_loss = r_loss + p_loss  # g_loss + r_loss

            # Update the recorder for losses
            losses['g_loss'].update(g_loss.item(), imgs.size(0))
            losses['r_loss'].update(r_loss.item(), imgs.size(0))
            losses['p_loss'].update(p_loss.item(), imgs.size(0))
            losses['s_loss'].update(s_loss.item(), imgs.size(0))
            losses['whole_loss'].update(whole_loss.item(), imgs.size(0))

            d_loss = DLoss(pred_pos, pred_neg)
            losses['d_loss'].update(d_loss.item(), imgs.size(0))
            pre_complete_imgs = complete_imgs
            # Update time recorder
            batch_time.update(time.time() - end)

            # Logger logging

            # if t < config.STATIC_VIEW_SIZE:
            print(i, size)
            real_img = img2photo(imgs)
            gen_img = img2photo(recon_imgs)
            comp_img = img2photo(complete_imgs)

            real_img = Image.fromarray(real_img[0].astype(np.uint8))
            gen_img = Image.fromarray(gen_img[0].astype(np.uint8))
            comp_img = Image.fromarray(comp_img[0].astype(np.uint8))
            real_img.save(
                os.path.join(val_save_real_dir, SIZES_TAGS[s_i],
                             "{}.png".format(i)))
            gen_img.save(
                os.path.join(val_save_gen_dir, SIZES_TAGS[s_i],
                             "{}.png".format(i)))
            comp_img.save(
                os.path.join(val_save_comp_dir, SIZES_TAGS[s_i],
                             "{}.png".format(i)))

            end = time.time()
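img2photo is called above but defined only inside the later examples; for reference, the definition used there, which assumes inputs normalized to [-1, 1]:

def img2photo(imgs):
    # map a [-1, 1] NCHW tensor to a [0, 255] NHWC numpy array
    return ((imgs + 1) * 127.5).transpose(1, 2).transpose(
        2, 3).detach().cpu().numpy()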
Code Example #4
def validate(nets, loss_terms, opts, dataloader, epoch, devices=(cuda0, cuda1), batch_n="whole"):
    """
    validate phase
    """
    netG, netD  = nets["netG"], nets["netD"]
    GANLoss, ReconLoss, L1ReconLoss, DLoss = loss_terms["GANLoss"], loss_terms["ReconLoss"], loss_terms["L1ReconLoss"], loss_terms["DLoss"]
    optG, optD = opts["optG"], opts["optD"]
    device0, device1 = devices[0], devices[1]
    netG.to(device0)
    netD.to(device0)
    # maskNetD.to(device1)

    netG.eval()
    netD.eval()
    # maskNetD.eval()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = {"g_loss":AverageMeter(), "r_loss":AverageMeter(), "r_ex_loss":AverageMeter(), "whole_loss":AverageMeter(), 'd_loss':AverageMeter(),
              'mask_d_loss':AverageMeter(), 'mask_rec_loss':AverageMeter(),'mask_whole_loss':AverageMeter()}

    end = time.time()
    val_save_dir = os.path.join(
        result_dir, "val_{}_{}".format(
            epoch, batch_n if isinstance(batch_n, str) else batch_n + 1))
    val_save_real_dir = os.path.join(val_save_dir, "real")
    val_save_gen_dir = os.path.join(val_save_dir, "gen")
    val_save_inf_dir = os.path.join(val_save_dir, "inf")
    if not os.path.exists(val_save_real_dir):
        os.makedirs(val_save_real_dir)
        os.makedirs(val_save_gen_dir)
        os.makedirs(val_save_inf_dir)
    info = {}

    for i, data in enumerate(dataloader):

        data_time.update(time.time() - end, 1)
        imgs, img_exs, masks = data
        masks = masks['val']
        #masks = (masks > 0).type(torch.FloatTensor)

        imgs, img_exs, masks = imgs.to(device0), img_exs.to(device0), masks.to(device0)
        imgs = (imgs / 127.5 - 1)
        img_exs = (img_exs / 127.5 - 1)
        # mask is 1 on masked region
        # forward
        coarse_imgs, recon_imgs, recon_ex_imgs = netG(imgs, img_exs, masks)

        complete_imgs = recon_imgs * masks + imgs * (1 - masks)

        pos_imgs = torch.cat([imgs, masks, torch.full_like(masks, 1.)], dim=1)
        neg_imgs = torch.cat([complete_imgs, masks, torch.full_like(masks, 1.)], dim=1)
        pos_neg_imgs = torch.cat([pos_imgs, neg_imgs], dim=0)
        #mask_pos_neg_imgs = torch.cat([imgs, complete_imgs], dim=0)

        pred_pos_neg = netD(pos_neg_imgs)
        pred_pos, pred_neg = torch.chunk(pred_pos_neg,  2, dim=0)

        # # Mask Gan
        # mask_pos_neg_imgs = mask_pos_neg_imgs.to(device1)
        # mask_pred_pos_neg = maskNetD(mask_pos_neg_imgs)
        # mask_pred_pos, mask_pred_neg = torch.chunk(mask_pred_pos_neg, 2, dim=0)

        g_loss = GANLoss(pred_neg)

        r_loss = ReconLoss(imgs, coarse_imgs, recon_imgs, masks)

        r_ex_loss = L1ReconLoss(img_exs, recon_ex_imgs)

        whole_loss = g_loss + r_loss + r_ex_loss

        # Update the recorder for losses
        losses['g_loss'].update(g_loss.item(), imgs.size(0))
        losses['r_loss'].update(r_loss.item(), imgs.size(0))
        losses['r_ex_loss'].update(r_ex_loss.item(), imgs.size(0))
        losses['whole_loss'].update(whole_loss.item(), imgs.size(0))

        d_loss = DLoss(pred_pos, pred_neg)
        losses['d_loss'].update(d_loss.item(), imgs.size(0))

        # masks = masks.to(device1)
        # mask_d_loss = DLoss(mask_pred_pos*masks + (1-masks), mask_pred_neg*masks + (1-masks))
        # mask_rec_loss = L1ReconLoss(mask_pred_neg, masks)
        # mask_whole_loss = mask_rec_loss

        # masks = masks.to(device0)
        # losses['mask_d_loss'].update(mask_d_loss.item(), imgs.size(0))
        # losses['mask_rec_loss'].update(mask_rec_loss.item(), imgs.size(0))
        # losses['mask_whole_loss'].update(mask_whole_loss.item(), imgs.size(0))

        # Update time recorder
        batch_time.update(time.time() - end, 1)


        # Logger logging

        if (i+1) < config.STATIC_VIEW_SIZE:

            def img2photo(imgs):
                return ((imgs+1)*127.5).transpose(1,2).transpose(2,3).detach().cpu().numpy()
            # info = { 'val/ori_imgs':img2photo(imgs),
            #          'val/coarse_imgs':img2photo(coarse_imgs),
            #          'val/recon_imgs':img2photo(recon_imgs),
            #          'val/comp_imgs':img2photo(complete_imgs),
            info['val/whole_imgs/{}'.format(i)] = {"img":img2photo(torch.cat([imgs * (1 - masks), coarse_imgs, recon_imgs, imgs, complete_imgs], dim=3)),
                                                   }

        else:
            logger.info("Validation Epoch {0}, [{1}/{2}]: Batch Time:{batch_time.val:.4f},\t Data Time:{data_time.val:.4f},\t Whole Gen Loss:{whole_loss.val:.4f}\t,"
                        "Recon Loss:{r_loss.val:.4f},\t Ex Recon Loss:{r_ex_loss.val:.4f},\t GAN Loss:{g_loss.val:.4f},\t D Loss:{d_loss.val:.4f}"
                        .format(epoch, i+1, len(dataloader), batch_time=batch_time, data_time=data_time, whole_loss=losses['whole_loss'], \
                                r_loss=losses['r_loss'], r_ex_loss=losses['r_ex_loss'], g_loss=losses['g_loss'], d_loss=losses['d_loss']))

            for tag, value in losses.items():
                tensorboardlogger.scalar_summary('val/avg_'+tag, value.avg, epoch*len(dataloader)+i)
            j = 0
            for tag, datas in info.items():
                images = datas["img"]
                h, w = images.shape[1], images.shape[2] // 5
                for kv, val_img in enumerate(images):
                    real_img = val_img[:,(3*w):(4*w),:]
                    gen_img = val_img[:,(4*w):(5*w),:]
                    real_img = Image.fromarray(real_img.astype(np.uint8))
                    gen_img = Image.fromarray(gen_img.astype(np.uint8))
                    #pkl.dump({datas[term][kv] for term in datas if term != "img"}, open(os.path.join(val_save_inf_dir, "{}.png".format(j)), 'wb'))
                    real_img.save(os.path.join(val_save_real_dir, "{}.png".format(j)))
                    gen_img.save(os.path.join(val_save_gen_dir, "{}.png".format(j)))
                    j += 1
                tensorboardlogger.image_summary(tag, images, epoch)
            path1, path2 = val_save_real_dir, val_save_gen_dir
            fid_score = metrics['fid']([path1, path2], cuda=False)
            ssim_score = metrics['ssim']([path1, path2])
            tensorboardlogger.scalar_summary('val/fid', fid_score.item(), epoch*len(dataloader)+i)
            tensorboardlogger.scalar_summary('val/ssim', ssim_score.item(), epoch*len(dataloader)+i)
            break
            
        end = time.time()
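The SNDisLoss / SNGenLoss terms these functions receive follow the SN-PatchGAN hinge formulation; a sketch of what they typically compute (the project's classes may differ in detail):

import torch
import torch.nn as nn
import torch.nn.functional as F

class SNDisLoss(nn.Module):
    # hinge loss: push real scores above 1 and fake scores below -1
    def forward(self, pos, neg):
        return torch.mean(F.relu(1. - pos)) + torch.mean(F.relu(1. + neg))

class SNGenLoss(nn.Module):
    # generator adversarial term: raise the discriminator's score on fakes
    def __init__(self, weight=1.0):
        super().__init__()
        self.weight = weight

    def forward(self, neg):
        return -self.weight * torch.mean(neg)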
Code Example #5
File: train_sagan.py  Project: GothicAi/DIP
def train(netG,
          netD,
          GANLoss,
          ReconLoss,
          DLoss,
          optG,
          optD,
          dataloader,
          epoch,
          device=cuda0,
          val_datas=None):
    """
    Train Phase, for training and spectral normalization patch gan in
    Free-Form Image Inpainting with Gated Convolution (snpgan)

    """
    netG.to(device)
    netD.to(device)
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = {
        "g_loss": AverageMeter(),
        "r_loss": AverageMeter(),
        "whole_loss": AverageMeter(),
        'd_loss': AverageMeter()
    }

    netG.train()
    netD.train()
    end = time.time()

    for i, (imgs, masks) in enumerate(dataloader):
        data_time.update(time.time() - end)
        masks = masks['random_free_form']

        #traditional inpainting
        for item in range(len(imgs)):
            img = np.array(transforms.ToPILImage()(imgs[item]))
            mask = np.array(transforms.ToPILImage()(masks[item]))
            res = cv2.inpaint(img, mask, 3, cv2.INPAINT_TELEA)
            res = transforms.ToTensor()(res)
            res = (res * 255) / 127.5 - 1
            if item:
                traditional_inpaint = torch.cat((traditional_inpaint, res))
            else:
                traditional_inpaint = res
        # cat() above joined the per-item results along the channel axis; this
        # reshape recovers (B, 3, 256, 256) and so assumes 256x256 RGB inputs
        traditional_inpaint = torch.reshape(traditional_inpaint,
                                            (config.BATCH_SIZE, 3, 256, 256))
        traditional_inpaint = traditional_inpaint.to(device)

        # Optimize Discriminator
        optD.zero_grad(), netD.zero_grad(), netG.zero_grad(), optG.zero_grad()

        imgs, masks = imgs.to(device), masks.to(device)
        imgs = (imgs / 127.5 - 1)
        # mask is 1 on masked region
        #print(type(masks))
        #print(type(guidence))
        #exit()

        coarse_imgs, recon_imgs_with_weight = netG(imgs, masks)
        recon_imgs = recon_imgs_with_weight[:, 0:3, :, :]
        weight_layer = (recon_imgs_with_weight[:, 3:, :, :] + 1.0) / 2
        recon_imgs = weight_layer * recon_imgs + (
            1 - weight_layer) * traditional_inpaint
        #print(attention.size(), )
        complete_imgs = recon_imgs * masks + imgs * (1 - masks)

        pos_imgs = torch.cat([imgs, masks, torch.full_like(masks, 1.)], dim=1)
        neg_imgs = torch.cat(
            [complete_imgs, masks,
             torch.full_like(masks, 1.)], dim=1)
        pos_neg_imgs = torch.cat([pos_imgs, neg_imgs], dim=0)

        pred_pos_neg = netD(pos_neg_imgs)
        pred_pos, pred_neg = torch.chunk(pred_pos_neg, 2, dim=0)
        d_loss = DLoss(pred_pos, pred_neg)
        losses['d_loss'].update(d_loss.item(), imgs.size(0))
        d_loss.backward(retain_graph=True)

        optD.step()

        # Optimize Generator
        optD.zero_grad(), netD.zero_grad(), optG.zero_grad(), netG.zero_grad()
        pred_neg = netD(neg_imgs)
        #pred_pos, pred_neg = torch.chunk(pred_pos_neg,  2, dim=0)
        g_loss = GANLoss(pred_neg)
        r_loss = ReconLoss(imgs, coarse_imgs, recon_imgs, masks)

        whole_loss = g_loss + r_loss

        # Update the recorder for losses
        losses['g_loss'].update(g_loss.item(), imgs.size(0))
        losses['r_loss'].update(r_loss.item(), imgs.size(0))
        losses['whole_loss'].update(whole_loss.item(), imgs.size(0))
        whole_loss.backward()

        optG.step()

        # Update time recorder
        batch_time.update(time.time() - end)

        if (i + 1) % config.SUMMARY_FREQ == 0:
            # Logger logging
            logger.info("Epoch {0}, [{1}/{2}]: Batch Time:{batch_time.val:.4f},\t Data Time:{data_time.val:.4f}, Whole Gen Loss:{whole_loss.val:.4f}\t,"
                        "Recon Loss:{r_loss.val:.4f},\t GAN Loss:{g_loss.val:.4f},\t D Loss:{d_loss.val:.4f}" \
                        .format(epoch, i+1, len(dataloader), batch_time=batch_time, data_time=data_time, whole_loss=losses['whole_loss'], r_loss=losses['r_loss'] \
                        ,g_loss=losses['g_loss'], d_loss=losses['d_loss']))
            # Tensorboard logger for scaler and images
            info_terms = {
                'WGLoss': whole_loss.item(),
                'ReconLoss': r_loss.item(),
                "GANLoss": g_loss.item(),
                "DLoss": d_loss.item()
            }

            for tag, value in info_terms.items():
                tensorboardlogger.scalar_summary(tag, value,
                                                 epoch * len(dataloader) + i)

            for tag, value in losses.items():
                tensorboardlogger.scalar_summary('avg_' + tag, value.avg,
                                                 epoch * len(dataloader) + i)

            def img2photo(imgs):
                return ((imgs + 1) * 127.5).transpose(1, 2).transpose(
                    2, 3).detach().cpu().numpy()

            # info = { 'train/ori_imgs':img2photo(imgs),
            #          'train/coarse_imgs':img2photo(coarse_imgs),
            #          'train/recon_imgs':img2photo(recon_imgs),
            #          'train/comp_imgs':img2photo(complete_imgs),
            info = {
                'train/whole_imgs':
                img2photo(
                    torch.cat([
                        imgs * (1 - masks), coarse_imgs, recon_imgs, imgs,
                        complete_imgs
                    ],
                              dim=3))
            }

            for tag, images in info.items():
                tensorboardlogger.image_summary(tag, images,
                                                epoch * len(dataloader) + i)
        if (i + 1) % config.VAL_SUMMARY_FREQ == 0 and val_datas is not None:

            validate(netG,
                     netD,
                     GANLoss,
                     ReconLoss,
                     DLoss,
                     optG,
                     optD,
                     val_datas,
                     epoch,
                     device,
                     batch_n=i)
            netG.train()
            netD.train()
        end = time.time()
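The cv2.inpaint warm start above builds the batch with cat + reshape and hard-coded dimensions; an equivalent sketch using torch.stack that works for any batch size and resolution (a hypothetical refactor, not the project's code):

import cv2
import numpy as np
import torch
from torchvision import transforms

def telea_inpaint_batch(imgs, masks, radius=3):
    """imgs: (B, 3, H, W) tensor in [0, 1]; masks: (B, 1, H, W) in {0, 1}."""
    out = []
    for img, mask in zip(imgs, masks):
        img_np = np.array(transforms.ToPILImage()(img))    # (H, W, 3) uint8
        mask_np = np.array(transforms.ToPILImage()(mask))  # (H, W) uint8
        res = cv2.inpaint(img_np, mask_np, radius, cv2.INPAINT_TELEA)
        out.append(transforms.ToTensor()(res) * 255 / 127.5 - 1)  # -> [-1, 1]
    return torch.stack(out)  # (B, 3, H, W)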
Code Example #6
def main():
    logger_init()
    dataset_type = config.DATASET
    batch_size = config.BATCH_SIZE

    # Dataset setting
    logger.info("Initialize the dataset...")
    train_dataset = InpaintDataset(config.DATA_FLIST[dataset_type][0], \
                                   {mask_type: config.DATA_FLIST[config.MASKDATASET][mask_type][0] for mask_type in
                                    config.MASK_TYPES}, \
                                   resize_shape=tuple(config.IMG_SHAPES), random_bbox_shape=config.RANDOM_BBOX_SHAPE, \
                                   random_bbox_margin=config.RANDOM_BBOX_MARGIN,
                                   random_ff_setting=config.RANDOM_FF_SETTING)
    train_loader = train_dataset.loader(batch_size=batch_size,
                                        shuffle=True,
                                        num_workers=16,
                                        pin_memory=True)

    val_dataset = InpaintDataset(config.DATA_FLIST[dataset_type][1], \
                                 {mask_type: config.DATA_FLIST[config.MASKDATASET][mask_type][1] for mask_type in
                                  ('val',)}, \
                                 resize_shape=tuple(config.IMG_SHAPES), random_bbox_shape=config.RANDOM_BBOX_SHAPE, \
                                 random_bbox_margin=config.RANDOM_BBOX_MARGIN,
                                 random_ff_setting=config.RANDOM_FF_SETTING)
    val_loader = val_dataset.loader(batch_size=1, shuffle=False, num_workers=1)
    # print(len(val_loader))

    ### Generate a new val data
    val_datas = []
    j = 0
    for i, data in enumerate(val_loader):
        if j < config.STATIC_VIEW_SIZE:
            imgs = data[0]
            if imgs.size(1) == 3:
                val_datas.append(data)
                j += 1
        else:
            break
    # val_datas = [(imgs, masks) for imgs, masks in val_loader]

    val_loader = val_dataset.loader(batch_size=1, shuffle=False, num_workers=1)
    logger.info("Finish the dataset initialization.")

    # Define the Network Structure
    logger.info("Define the Network Structure and Losses")
    netG = InpaintSANet()
    netD = InpaintSADirciminator()

    if config.MODEL_RESTORE != '':
        whole_model_path = 'model_logs/{}'.format(config.MODEL_RESTORE)
        nets = torch.load(whole_model_path)
        netG_state_dict, netD_state_dict = nets['netG_state_dict'], nets[
            'netD_state_dict']
        netG.load_state_dict(netG_state_dict)
        netD.load_state_dict(netD_state_dict)
        logger.info("Loading pretrained models from {} ...".format(
            config.MODEL_RESTORE))

    # Define loss
    recon_loss = ReconLoss(*(config.L1_LOSS_ALPHA))
    gan_loss = SNGenLoss(config.GAN_LOSS_ALPHA)
    dis_loss = SNDisLoss()
    lr, decay = config.LEARNING_RATE, config.WEIGHT_DECAY
    optG = torch.optim.Adam(netG.parameters(), lr=lr, weight_decay=decay)
    optD = torch.optim.Adam(netD.parameters(), lr=4 * lr, weight_decay=decay)

    logger.info("Finish Define the Network Structure and Losses")

    # Start Training
    logger.info("Start Training...")
    epoch = 50

    for i in range(epoch):
        # validate(netG, netD, gan_loss, recon_loss, dis_loss, optG, optD, val_loader, i, device=cuda0)

        # train data
        train(netG,
              netD,
              gan_loss,
              recon_loss,
              dis_loss,
              optG,
              optD,
              train_loader,
              i,
              device=cuda0,
              val_datas=val_datas)

        # validate
        validate(netG,
                 netD,
                 gan_loss,
                 recon_loss,
                 dis_loss,
                 optG,
                 optD,
                 val_datas,
                 i,
                 device=cuda0)

        saved_model = {
            'epoch': i + 1,
            'netG_state_dict': netG.to(cpu0).state_dict(),
            'netD_state_dict': netD.to(cpu0).state_dict(),
            # 'optG' : optG.state_dict(),
            # 'optD' : optD.state_dict()
        }
        torch.save(saved_model,
                   '{}/epoch_{}_ckpt.pth.tar'.format(log_dir, i + 1))
        torch.save(saved_model,
                   '{}/latest_ckpt.pth.tar'.format(log_dir))
Code Example #7
def main(args):
    if not os.path.exists(args.logdir):
        os.makedirs(args.logdir)

    dataset_type = args.dataset

    # Dataset setting
    train_dataset = InpaintDataset(args.train_image_list,\
                                      {'val':args.train_mask_list},
                                      mode='train', img_size=args.img_shape)
    train_loader = train_dataset.loader(batch_size=args.batch_size,
                                        shuffle=True,
                                        num_workers=4,
                                        pin_memory=True)

    val_dataset = InpaintDataset(args.val_image_list,\
                                      {'val':args.val_mask_list},
                                      # {'val':args.val_mask_list},
                                      mode='val', img_size=args.img_shape)
    val_loader = val_dataset.loader(batch_size=1, shuffle=False, num_workers=1)

    # Define the Network Structure
    netG = InpaintSANet()
    netD = InpaintSADirciminator()
    netG.cuda()
    netD.cuda()

    if args.load_weights != '':
        whole_model_path = args.load_weights
        nets = torch.load(whole_model_path)
        netG_state_dict, netD_state_dict = nets['netG_state_dict'], nets[
            'netD_state_dict']
        # netG.load_state_dict(netG_state_dict)
        load_consistent_state_dict(netG_state_dict, netG)
        netD.load_state_dict(netD_state_dict)

    # Define loss
    recon_loss = ReconLoss(*([1.2, 1.2, 1.2, 1.2]))
    gan_loss = SNGenLoss(0.005)
    dis_loss = SNDisLoss()
    lr, decay = args.learning_rate, 0.0
    optG = torch.optim.Adam(netG.parameters(), lr=lr, weight_decay=decay)
    optD = torch.optim.Adam(netD.parameters(), lr=4 * lr, weight_decay=decay)

    best_score = 0

    # Create loss and acc file
    loss_writer = csv.writer(open(os.path.join(args.logdir, 'loss.csv'), 'w'),
                             delimiter=',')
    acc_writer = csv.writer(open(os.path.join(args.logdir, 'acc.csv'), 'w'),
                            delimiter=',')

    # Start Training
    for i in range(args.epochs):
        #train data
        train(netG, netD, gan_loss, recon_loss, dis_loss, optG, optD,
              train_loader, i + 1, args.img_shape, loss_writer)

        # validate
        output_dir = os.path.join(args.result_dir, str(i + 1))
        mse, ssim = validate(netG, val_loader, args.img_shape, output_dir,
                             args.gt_dir)
        score = 1 - mse / 100 + ssim
        print('MSE: ', mse, '     SSIM:', ssim, '     SCORE:', score)
        acc_writer.writerow([i + 1, mse, ssim, score])

        saved_model = {
            'epoch': i + 1,
            'netG_state_dict': netG.state_dict(),
            'netD_state_dict': netD.state_dict(),
            # 'optG' : optG.state_dict(),
            # 'optD' : optD.state_dict()
        }
        torch.save(saved_model,
                   '{}/epoch_{}_ckpt.pth.tar'.format(args.logdir, i + 1))
        if score > best_score:
            torch.save(saved_model,
                       '{}/best_ckpt.pth.tar'.format(args.logdir))
            best_score = score
            print('New best score at epoch', i + 1)
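This main(args) expects an argparse-style namespace; a minimal front end covering exactly the attributes read above (the flag names and defaults are inferred, so treat them as hypothetical):

import argparse

def parse_args():
    p = argparse.ArgumentParser()
    p.add_argument('--logdir', default='logs')
    p.add_argument('--dataset', default='inpaint')
    p.add_argument('--train_image_list')
    p.add_argument('--train_mask_list')
    p.add_argument('--val_image_list')
    p.add_argument('--val_mask_list')
    p.add_argument('--img_shape', type=int, default=256)
    p.add_argument('--batch_size', type=int, default=16)
    p.add_argument('--learning_rate', type=float, default=1e-4)
    p.add_argument('--epochs', type=int, default=50)
    p.add_argument('--result_dir', default='results')
    p.add_argument('--gt_dir', default='')
    p.add_argument('--load_weights', default='')
    return p.parse_args()

if __name__ == '__main__':
    main(parse_args())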
Code Example #8
def train(netG,
          netD,
          GANLoss,
          ReconLoss,
          DLoss,
          optG,
          optD,
          dataloader,
          epoch,
          device=cuda0):
    """
    Train Phase, for training and spectral normalization patch gan in
    Free-Form Image Inpainting with Gated Convolution (snpgan)

    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = {
        "g_loss": AverageMeter(),
        "r_loss": AverageMeter(),
        "whole_loss": AverageMeter(),
        'd_loss': AverageMeter()
    }

    netG.train()
    netD.train()
    end = time.time()
    for i, (imgs, masks) in enumerate(dataloader):
        data_time.update(time.time() - end)
        masks = masks['random_free_form']
        #masks = (masks > 0).type(torch.FloatTensor)#
        #print(len([i for i in masks.numpy().flatten() if i != 0]))

        # Optimize Discriminator
        optD.zero_grad(), netD.zero_grad(), netG.zero_grad(), optG.zero_grad()

        imgs, masks = imgs.to(device), masks.to(device)
        imgs = (imgs / 127.5 - 1)
        # mask is 1 on masked region
        coarse_imgs, recon_imgs = netG(imgs, masks)

        complete_imgs = recon_imgs * masks + imgs * (1 - masks)

        pos_imgs = torch.cat([imgs, masks, torch.full_like(masks, 1.)], dim=1)
        neg_imgs = torch.cat(
            [complete_imgs, masks,
             torch.full_like(masks, 1.)], dim=1)
        pos_neg_imgs = torch.cat([pos_imgs, neg_imgs], dim=0)

        pred_pos_neg = netD(pos_neg_imgs)
        pred_pos, pred_neg = torch.chunk(pred_pos_neg, 2, dim=0)
        #print(pred_pos.size())
        d_loss = DLoss(pred_pos, pred_neg)
        losses['d_loss'].update(d_loss.item(), imgs.size(0))
        d_loss.backward(retain_graph=True)

        optD.step()

        # Optimize Generator
        optD.zero_grad(), netD.zero_grad(), optG.zero_grad(), netG.zero_grad()
        pred_neg = netD(neg_imgs)
        #pred_pos, pred_neg = torch.chunk(pred_pos_neg,  2, dim=0)
        g_loss = GANLoss(pred_neg)
        r_loss = ReconLoss(imgs, coarse_imgs, recon_imgs, masks)

        whole_loss = g_loss + r_loss

        # Update the recorder for losses
        losses['g_loss'].update(g_loss.item(), imgs.size(0))
        losses['r_loss'].update(r_loss.item(), imgs.size(0))
        losses['whole_loss'].update(whole_loss.item(), imgs.size(0))
        whole_loss.backward()

        optG.step()

        # Update time recorder
        batch_time.update(time.time() - end)

        if i % config.SUMMARY_FREQ == 0:
            # Logger logging
            logger.info("Epoch {0}, [{1}/{2}]: Batch Time:{batch_time.val:.4f},\t Data Time:{data_time.val:.4f}, Whole Gen Loss:{whole_loss.val:.4f}\t,"
                        "Recon Loss:{r_loss.val:.4f},\t GAN Loss:{g_loss.val:.4f},\t D Loss:{d_loss.val:.4f}" \
                        .format(epoch, i, len(dataloader), batch_time=batch_time, data_time=data_time, whole_loss=losses['whole_loss'], r_loss=losses['r_loss'] \
                        ,g_loss=losses['g_loss'], d_loss=losses['d_loss']))
            # Tensorboard logger for scaler and images
            info_terms = {
                'WGLoss': whole_loss.item(),
                'ReconLoss': r_loss.item(),
                "GANLoss": g_loss.item(),
                "DLoss": d_loss.item()
            }

            for tag, value in info_terms.items():
                tensorboardlogger.scalar_summary(tag, value,
                                                 epoch * len(dataloader) + i)

            for tag, value in losses.items():
                tensorboardlogger.scalar_summary('avg_' + tag, value.avg,
                                                 epoch * len(dataloader) + i)

            def img2photo(imgs):
                return ((imgs + 1) * 127.5).transpose(1, 2).transpose(
                    2, 3).detach().cpu().numpy()

            info = {
                'train/ori_imgs':
                img2photo(imgs),
                'train/coarse_imgs':
                img2photo(coarse_imgs),
                'train/recon_imgs':
                img2photo(recon_imgs),
                'train/comp_imgs':
                img2photo(complete_imgs),
                'train/whole_imgs':
                img2photo(
                    torch.cat([imgs, coarse_imgs, recon_imgs, complete_imgs],
                              dim=3))
            }

            for tag, images in info.items():
                tensorboardlogger.image_summary(tag, images, i)  # note: step is i here, not epoch * len(dataloader) + i as for the scalars
        end = time.time()
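Every variant composites with recon_imgs * masks + imgs * (1 - masks): generator pixels are kept only inside the hole (mask == 1) and the original pixels elsewhere. A tiny self-contained check of that rule:

import torch

imgs = torch.zeros(1, 3, 4, 4)           # "real" image: all zeros
recon = torch.ones(1, 3, 4, 4)           # generator output: all ones
masks = torch.zeros(1, 1, 4, 4)
masks[..., :2, :] = 1.                   # top half is the hole
complete = recon * masks + imgs * (1 - masks)
assert complete[..., :2, :].eq(1).all()  # hole filled by the generator
assert complete[..., 2:, :].eq(0).all()  # the rest is untouched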
Code Example #9
def main():
    logger_init()
    dataset_type = config.DATASET
    batch_size = config.BATCH_SIZE

    # Dataset setting
    logger.info("Initialize the dataset...")
    train_dataset = InpaintDataset(config.DATA_FLIST[dataset_type][0],\
                                      {mask_type:config.DATA_FLIST[config.MASKDATASET][mask_type][0] for mask_type in config.MASK_TYPES}, \
                                      resize_shape=tuple(config.IMG_SHAPES), random_bbox_shape=config.RANDOM_BBOX_SHAPE, \
                                      random_bbox_margin=config.RANDOM_BBOX_MARGIN,
                                      random_ff_setting=config.RANDOM_FF_SETTING)
    train_loader = train_dataset.loader(batch_size=batch_size,
                                        shuffle=True,
                                        num_workers=16,
                                        pin_memory=True)

    val_dataset = InpaintDataset(config.DATA_FLIST[dataset_type][1],\
                                    {mask_type:config.DATA_FLIST[config.MASKDATASET][mask_type][1] for mask_type in config.MASK_TYPES}, \
                                    resize_shape=tuple(config.IMG_SHAPES), random_bbox_shape=config.RANDOM_BBOX_SHAPE, \
                                    random_bbox_margin=config.RANDOM_BBOX_MARGIN,
                                    random_ff_setting=config.RANDOM_FF_SETTING)
    val_loader = val_dataset.loader(batch_size=batch_size,
                                    shuffle=False,
                                    num_workers=16,
                                    pin_memory=True)
    logger.info("Finish the dataset initialization.")

    # Define the Network Structure
    logger.info("Define the Network Structure and Losses")
    netG = InpaintGCNet()
    netD = InpaintDirciminator()
    netG, netD = netG.to(cuda0), netD.to(cuda0)

    # Define loss
    recon_loss = ReconLoss(*(config.L1_LOSS_ALPHA))
    gan_loss = SNGenLoss(config.GAN_LOSS_ALPHA)
    dis_loss = SNDisLoss()
    lr, decay = config.LEARNING_RATE, config.WEIGHT_DECAY
    optG = torch.optim.Adam(netG.parameters(), lr=lr, weight_decay=decay)
    optD = torch.optim.Adam(netD.parameters(), lr=4 * lr, weight_decay=decay)

    logger.info("Finish Define the Network Structure and Losses")

    # Start Training
    logger.info("Start Training...")
    epoch = 50

    for i in range(epoch):
        #train data
        train(netG,
              netD,
              gan_loss,
              recon_loss,
              dis_loss,
              optG,
              optD,
              train_loader,
              i,
              device=cuda0)

        # validate
        validate(netG,
                 netD,
                 gan_loss,
                 recon_loss,
                 dis_loss,
                 optG,
                 optD,
                 val_loader,
                 i,
                 device=cuda0)

        torch.save(
            {
                'epoch': i + 1,
                'netG_state_dict': netG.state_dict(),
                'netD_state_dict': netD.state_dict(),
                'optG': optG.state_dict(),
                'optD': optD.state_dict()
            }, '{}/epoch_{}_ckpt.pth.tar'.format(log_dir, i + 1))
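Unlike the other variants, this checkpoint also stores the optimizer state dicts, so training can resume exactly where it stopped; a sketch of the corresponding restore (the path is the one written above, with epoch 50 as an example):

ckpt = torch.load('{}/epoch_{}_ckpt.pth.tar'.format(log_dir, 50),
                  map_location='cpu')
netG.load_state_dict(ckpt['netG_state_dict'])
netD.load_state_dict(ckpt['netD_state_dict'])
optG.load_state_dict(ckpt['optG'])
optD.load_state_dict(ckpt['optD'])
start_epoch = ckpt['epoch']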
Code Example #10
def main():
    logger_init()
    dataset_type = config.DATASET
    batch_size = config.BATCH_SIZE

    # Dataset setting
    logger.info("Initialize the dataset...")
    train_dataset = InpaintDataset(config.DATA_FLIST[dataset_type][0],\
                                      {mask_type:config.DATA_FLIST[config.MASKDATASET][mask_type][0] for mask_type in config.MASK_TYPES}, \
                                      resize_shape=tuple(config.IMG_SHAPES), random_bbox_shape=config.RANDOM_BBOX_SHAPE, \
                                      random_bbox_margin=config.RANDOM_BBOX_MARGIN,
                                      random_ff_setting=config.RANDOM_FF_SETTING)
    train_loader = train_dataset.loader(batch_size=batch_size,
                                        shuffle=True,
                                        num_workers=16)

    val_dataset = InpaintDataset(config.DATA_FLIST[dataset_type][1],\
                                    {mask_type:config.DATA_FLIST[config.MASKDATASET][mask_type][1] for mask_type in ('val',)}, \
                                    resize_shape=tuple(config.IMG_SHAPES), random_bbox_shape=config.RANDOM_BBOX_SHAPE, \
                                    random_bbox_margin=config.RANDOM_BBOX_MARGIN,
                                    random_ff_setting=config.RANDOM_FF_SETTING)
    val_loader = val_dataset.loader(batch_size=1, shuffle=False, num_workers=1)
    ### Generate a new val data
    val_datas = []
    j = 0
    for i, data in enumerate(val_loader):
        if j < config.STATIC_VIEW_SIZE:
            imgs = data[0]
            if imgs.size(1) == 3:
                val_datas.append(data)
                j += 1
        else:
            break
    #val_datas = [(imgs, masks) for imgs, masks in val_loader]

    val_loader = val_dataset.loader(batch_size=1, shuffle=False, num_workers=1)
    logger.info("Finish the dataset initialization.")

    # Define the Network Structure
    logger.info("Define the Network Structure and Losses")
    netG = InpaintRUNNet(cuda0, n_in_channel=config.N_CHANNEL)
    netD = InpaintSADirciminator()
    netVGG = vgg16_bn(pretrained=True)
    sr_args = SRArgs(config.GPU_IDS[0])
    netSR = sr_model.Model(sr_args, sr_util.checkpoint(sr_args))

    if config.MODEL_RESTORE != '':
        whole_model_path = 'model_logs/{}'.format(config.MODEL_RESTORE)
        nets = torch.load(whole_model_path)
        netG_state_dict, netD_state_dict = nets['netG_state_dict'], nets[
            'netD_state_dict']
        netG.load_state_dict(netG_state_dict)
        netD.load_state_dict(netD_state_dict)
        logger.info("Loading pretrained models from {} ...".format(
            config.MODEL_RESTORE))

    # Define loss
    recon_loss = ReconLoss(*(config.L1_LOSS_ALPHA))
    gan_loss = SNGenLoss(config.GAN_LOSS_ALPHA)
    perc_loss = PerceptualLoss(weight=config.PERC_LOSS_ALPHA,
                               feat_extractors=netVGG.to(cuda1))
    style_loss = StyleLoss(weight=config.STYLE_LOSS_ALPHA,
                           feat_extractors=netVGG.to(cuda1))
    dis_loss = SNDisLoss()
    lr, decay = config.LEARNING_RATE, config.WEIGHT_DECAY
    optG = torch.optim.Adam(netG.parameters(), lr=lr, weight_decay=decay)
    optD = torch.optim.Adam(netD.parameters(), lr=4 * lr, weight_decay=decay)

    nets = {"netG": netG, "netD": netD, "vgg": netVGG, "netSR": netSR}

    losses = {
        "GANLoss": gan_loss,
        "ReconLoss": recon_loss,
        "StyleLoss": style_loss,
        "DLoss": dis_loss,
        "PercLoss": perc_loss
    }

    opts = {
        "optG": optG,
        "optD": optD,
    }

    logger.info("Finish Define the Network Structure and Losses")

    # Start Training
    logger.info("Start Training...")
    epoch = 50

    for i in range(epoch):
        #validate(netG, netD, gan_loss, recon_loss, dis_loss, optG, optD, val_loader, i, device=cuda0)

        #train data
        train(nets,
              losses,
              opts,
              train_loader,
              i,
              devices=(cuda0, cuda1),
              val_datas=val_datas)

        # validate
        validate(nets, losses, opts, val_datas, i, devices=(cuda0, cuda1))

        saved_model = {
            'epoch': i + 1,
            'netG_state_dict': netG.to(cpu0).state_dict(),
            'netD_state_dict': netD.to(cpu0).state_dict(),
            # 'optG' : optG.state_dict(),
            # 'optD' : optD.state_dict()
        }
        torch.save(saved_model,
                   '{}/epoch_{}_ckpt.pth.tar'.format(log_dir, i + 1))
        torch.save(saved_model,
                   '{}/latest_ckpt.pth.tar'.format(log_dir))
Code Example #11
def validate(nets,
             loss_terms,
             opts,
             dataloader,
             epoch,
             devices=(cuda0, cuda1),
             batch_n="whole"):
    """
    validate phase
    """
    netD, netG, netSR = nets["netD"], nets["netG"], nets["netSR"]
    ReconLoss, DLoss, PercLoss, GANLoss, StyleLoss = loss_terms[
        'ReconLoss'], loss_terms['DLoss'], loss_terms["PercLoss"], loss_terms[
            "GANLoss"], loss_terms["StyleLoss"]
    optG, optD = opts['optG'], opts['optD']
    device0, device1 = devices
    netG.to(device0)
    netD.to(device0)
    netSR.to(device0)
    netG.eval()
    netD.eval()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = {
        "g_loss": AverageMeter(),
        "p_loss": AverageMeter(),
        "s_loss": AverageMeter(),
        "r_loss": AverageMeter(),
        "whole_loss": AverageMeter(),
        "d_loss": AverageMeter()
    }

    # NOTE: as in the earlier validate(), these train() calls override the
    # eval() calls above, so validation runs with training-mode layers
    netG.train()
    netD.train()

    end = time.time()
    val_save_dir = os.path.join(
        result_dir, "val_{}_{}".format(
            epoch, batch_n if isinstance(batch_n, str) else batch_n + 1))
    val_save_real_dir = os.path.join(val_save_dir, "real")
    val_save_gen_dir = os.path.join(val_save_dir, "gen")
    val_save_inf_dir = os.path.join(val_save_dir, "inf")
    if not os.path.exists(val_save_real_dir):
        os.makedirs(val_save_real_dir)
        os.makedirs(val_save_gen_dir)
        os.makedirs(val_save_inf_dir)
    info = {}

    for i, (ori_imgs, ori_masks) in enumerate(dataloader):
        data_time.update(time.time() - end)
        pre_imgs = ori_imgs
        pre_complete_imgs = (pre_imgs / 127.5 - 1)
        pre_complete_imgs = pre_complete_imgs * (
            1 - ori_masks['val']) + ori_masks['val']
        pre_inter_imgs = F.interpolate(pre_complete_imgs, TRAIN_SIZES[0])
        for s_j, size in enumerate(TRAIN_SIZES):

            masks = ori_masks['val']
            masks = F.interpolate(masks, size)
            masks = (masks > 0).type(torch.FloatTensor)
            imgs = F.interpolate(ori_imgs, size)
            if s_j == 0:
                pre_inter_imgs = F.interpolate(pre_complete_imgs, size)
            else:
                pre_complete_imgs = (pre_complete_imgs + 1) * 127.5
                pre_inter_imgs = netSR(pre_complete_imgs, 2)
                pre_inter_imgs = (pre_inter_imgs / 127.5 - 1)
            #upsampled_imgs = pre_inter_imgs
            imgs, masks, pre_complete_imgs, pre_inter_imgs = imgs.to(
                device0), masks.to(device0), pre_complete_imgs.to(
                    device0), pre_inter_imgs.to(device0)
            #masks = (masks > 0).type(torch.FloatTensor)
            upsampled_imgs = pre_inter_imgs
            #imgs, masks = imgs.to(device), masks.to(device)
            imgs = (imgs / 127.5 - 1)
            # mask is 1 on masked region
            # forward
            recon_imgs = netG(imgs, masks, pre_complete_imgs, pre_inter_imgs,
                              size)

            complete_imgs = recon_imgs * masks + imgs * (1 - masks)

            pos_imgs = torch.cat(
                [imgs, masks, torch.full_like(masks, 1.)], dim=1)
            neg_imgs = torch.cat(
                [complete_imgs, masks,
                 torch.full_like(masks, 1.)], dim=1)
            pos_neg_imgs = torch.cat([pos_imgs, neg_imgs], dim=0)

            pred_pos_neg = netD(pos_neg_imgs)
            pred_pos, pred_neg = torch.chunk(pred_pos_neg, 2, dim=0)

            g_loss = GANLoss(pred_neg)

            # as in Example #3, recon_imgs fills both ReconLoss slots
            r_loss = ReconLoss(imgs, recon_imgs, recon_imgs, masks)

            imgs, recon_imgs, complete_imgs = imgs.to(device1), recon_imgs.to(
                device1), complete_imgs.to(device1)
            p_loss = PercLoss(imgs, recon_imgs) + PercLoss(imgs, complete_imgs)
            #s_loss = StyleLoss(imgs, recon_imgs) + StyleLoss(imgs, complete_imgs)
            #p_loss, s_loss = p_loss.to(device0), s_loss.to(device0)
            p_loss = p_loss.to(device0)
            imgs, recon_imgs, complete_imgs = imgs.to(device0), recon_imgs.to(
                device0), complete_imgs.to(device0)

            whole_loss = r_loss + p_loss + g_loss  # + s_loss (disabled)

            # Update the recorder for losses
            losses['g_loss'].update(g_loss.item(), imgs.size(0))
            losses['r_loss'].update(r_loss.item(), imgs.size(0))
            losses['p_loss'].update(p_loss.item(), imgs.size(0))
            losses['s_loss'].update(0, imgs.size(0))
            losses['whole_loss'].update(whole_loss.item(), imgs.size(0))

            d_loss = DLoss(pred_pos, pred_neg)
            losses['d_loss'].update(d_loss.item(), imgs.size(0))
            pre_complete_imgs = complete_imgs
            # Update time recorder
            batch_time.update(time.time() - end)

            # Logger logging

            if i + 1 < config.STATIC_VIEW_SIZE:

                def img2photo(imgs):
                    return ((imgs + 1) * 127.5).transpose(1, 2).transpose(
                        2, 3).detach().cpu().numpy()

                # info = { 'val/ori_imgs':img2photo(imgs),
                #          'val/coarse_imgs':img2photo(coarse_imgs),
                #          'val/recon_imgs':img2photo(recon_imgs),
                #          'val/comp_imgs':img2photo(complete_imgs),
                info['val/{}whole_imgs/{}'.format(size, i)] = img2photo(
                    torch.cat([
                        imgs * (1 - masks), upsampled_imgs, recon_imgs, imgs,
                        complete_imgs
                    ],
                              dim=3))

            else:
                logger.info("Validation Epoch {0}, [{1}/{2}]: Size:{size}, Batch Time:{batch_time.val:.4f},\t Data Time:{data_time.val:.4f},\t Whole Gen Loss:{whole_loss.val:.4f}\t,"
                            "Recon Loss:{r_loss.val:.4f},\t GAN Loss:{g_loss.val:.4f},\t D Loss:{d_loss.val:.4f},\t Perc Loss:{p_loss.val:.4f},\tStyle Loss:{s_loss.val:.4f}"
                            .format(epoch, i+1, len(dataloader),size=size, batch_time=batch_time, data_time=data_time, whole_loss=losses['whole_loss'], r_loss=losses['r_loss'] \
                            ,g_loss=losses['g_loss'], d_loss=losses['d_loss'], p_loss=losses['p_loss'], s_loss=losses['s_loss']))
                j = 0
                for size in SIZES_TAGS:
                    if not os.path.exists(os.path.join(val_save_real_dir,
                                                       size)):
                        os.makedirs(os.path.join(val_save_real_dir, size))
                        os.makedirs(os.path.join(val_save_gen_dir, size))

                for tag, images in info.items():
                    h, w = images.shape[1], images.shape[2] // 5
                    s_i = 0
                    for i_, s in enumerate(TRAIN_SIZES):
                        if "{}".format(s) in tag:
                            size_tag = "{}".format(s)
                            s_i = i_
                            break

                    for val_img in images:
                        real_img = val_img[:, (3 * w):(4 * w), :]
                        gen_img = val_img[:, (4 * w):, :]
                        real_img = Image.fromarray(real_img.astype(np.uint8))
                        gen_img = Image.fromarray(gen_img.astype(np.uint8))
                        real_img.save(
                            os.path.join(val_save_real_dir, SIZES_TAGS[s_i],
                                         "{}_{}.png".format(size_tag, j)))
                        gen_img.save(
                            os.path.join(val_save_gen_dir, SIZES_TAGS[s_i],
                                         "{}_{}.png".format(size_tag, j)))
                        j += 1
                    tensorboardlogger.image_summary(tag, images, epoch)
                path1, path2 = os.path.join(
                    val_save_real_dir,
                    SIZES_TAGS[len(SIZES_TAGS) - 1]), os.path.join(
                        val_save_gen_dir, SIZES_TAGS[len(SIZES_TAGS) - 1])
                fid_score = metrics['fid']([path1, path2], cuda=False)
                ssim_score = metrics['ssim']([path1, path2])
                tensorboardlogger.scalar_summary('val/fid', fid_score.item(),
                                                 epoch * len(dataloader) + i)
                tensorboardlogger.scalar_summary('val/ssim', ssim_score.item(),
                                                 epoch * len(dataloader) + i)
                break

            end = time.time()
    saved_model = {
        'epoch': epoch + 1,
        'netG_state_dict': netG.to(cpu0).state_dict(),
        'netD_state_dict': netD.to(cpu0).state_dict(),
        # 'optG' : optG.state_dict(),
        # 'optD' : optD.state_dict()
    }
    torch.save(saved_model,
               '{}/latest_ckpt.pth.tar'.format(log_dir))
Code Example #12
def train(nets,
          loss_terms,
          opts,
          dataloader,
          epoch,
          devices=(cuda0, cuda1),
          val_datas=None):
    """
    Train Phase, for training and spectral normalization patch gan in
    Free-Form Image Inpainting with Gated Convolution (snpgan)
    """
    netD, netG, netSR = nets["netD"], nets["netG"], nets["netSR"]
    ReconLoss, DLoss, GANLoss, PercLoss, StyleLoss = loss_terms[
        'ReconLoss'], loss_terms['DLoss'], loss_terms['GANLoss'], loss_terms[
            "PercLoss"], loss_terms["StyleLoss"]
    optG, optD = opts['optG'], opts['optD']
    device0, device1 = devices
    netG.to(device0)
    netD.to(device0)
    netSR.to(device0)
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = {
        "g_loss": AverageMeter(),
        "r_loss": AverageMeter(),
        "s_loss": AverageMeter(),
        'p_loss': AverageMeter(),
        "whole_loss": AverageMeter(),
        'd_loss': AverageMeter()
    }

    netG.train()
    netD.train()
    end = time.time()
    for i, (ori_imgs, ori_masks) in enumerate(dataloader):
        ff_mask, rect_mask = ori_masks['random_free_form'], ori_masks[
            'random_bbox']
        if np.random.rand() < 0.3:
            ori_masks = rect_mask
        else:
            ori_masks = ff_mask
        # Optimize Discriminator

        # mask is 1 on masked region
        pre_complete_imgs = ori_imgs
        pre_complete_imgs = (pre_complete_imgs / 127.5 - 1)
        pre_complete_imgs = pre_complete_imgs * (1 - ori_masks) + ori_masks
        pre_complete_imgs = F.interpolate(pre_complete_imgs, TRAIN_SIZES[0])

        for s_j, size in enumerate(TRAIN_SIZES):
            data_time.update(time.time() - end)
            optD.zero_grad(), netD.zero_grad(), netG.zero_grad(
            ), optG.zero_grad()
            #Reshape
            masks = F.interpolate(ori_masks, size)
            masks = (masks > 0).type(torch.FloatTensor)
            imgs = F.interpolate(ori_imgs, size)
            if s_j == 0:
                pre_inter_imgs = F.interpolate(pre_complete_imgs, size)
            else:
                pre_complete_imgs = (pre_complete_imgs + 1) * 127.5
                pre_inter_imgs = netSR(pre_complete_imgs, 2)
                pre_inter_imgs = (pre_inter_imgs / 127.5 - 1)
            imgs, masks, pre_complete_imgs, pre_inter_imgs = imgs.to(
                device0), masks.to(device0), pre_complete_imgs.to(
                    device0), pre_inter_imgs.to(device0)
            imgs = (imgs / 127.5 - 1)
            upsampled_imgs = pre_inter_imgs

            recon_imgs = netG(imgs, masks, pre_complete_imgs, pre_inter_imgs,
                              size)
            #print(attention.size(), )
            complete_imgs = recon_imgs * masks + imgs * (1 - masks)

            pos_imgs = torch.cat(
                [imgs, masks, torch.full_like(masks, 1.)], dim=1)
            neg_imgs = torch.cat(
                [complete_imgs, masks,
                 torch.full_like(masks, 1.)], dim=1)
            pos_neg_imgs = torch.cat([pos_imgs, neg_imgs], dim=0)

            pred_pos_neg = netD(pos_neg_imgs)
            pred_pos, pred_neg = torch.chunk(pred_pos_neg, 2, dim=0)
            d_loss = DLoss(pred_pos, pred_neg)
            losses['d_loss'].update(d_loss.item(), imgs.size(0))
            #print(size)
            if i % 3:  # update D only when i % 3 != 0, i.e. on 2 of every 3 steps
                d_loss.backward(retain_graph=True)

                optD.step()

            # Optimize Generator
            optD.zero_grad(), netD.zero_grad(), optG.zero_grad(
            ), netG.zero_grad()
            pred_neg = netD(neg_imgs)
            #pred_pos, pred_neg = torch.chunk(pred_pos_neg,  2, dim=0)
            g_loss = GANLoss(pred_neg)
            r_loss = ReconLoss(imgs, recon_imgs, recon_imgs, masks)

            # The perceptual loss runs on a second GPU to save memory
            imgs, recon_imgs, complete_imgs = imgs.to(device1), recon_imgs.to(device1), complete_imgs.to(device1)
            p_loss = PercLoss(imgs, recon_imgs) + PercLoss(imgs, complete_imgs)
            #s_loss = StyleLoss(imgs, recon_imgs) + StyleLoss(imgs, complete_imgs)
            #p_loss, s_loss = p_loss.to(device0), s_loss.to(device0)
            p_loss = p_loss.to(device0)
            imgs, recon_imgs, complete_imgs = imgs.to(device0), recon_imgs.to(device0), complete_imgs.to(device0)

            whole_loss = r_loss + p_loss + g_loss

            # Update the recorder for losses
            losses['g_loss'].update(g_loss.item(), imgs.size(0))
            losses['p_loss'].update(p_loss.item(), imgs.size(0))
            losses['s_loss'].update(0, imgs.size(0))  # style loss is disabled above
            losses['r_loss'].update(r_loss.item(), imgs.size(0))
            losses['whole_loss'].update(whole_loss.item(), imgs.size(0))
            whole_loss.backward(retain_graph=True)

            optG.step()

            # Carry this scale's completion into the next (larger) scale
            pre_complete_imgs = complete_imgs

            # Update time recorder
            batch_time.update(time.time() - end)

            if (i + 1) % config.SUMMARY_FREQ == 0:
                # Logger logging
                logger.info("Epoch {0}, [{1}/{2}]:Size:{size} Batch Time:{batch_time.val:.4f},\t Data Time:{data_time.val:.4f}, Whole Gen Loss:{whole_loss.val:.4f}\t,"
                            "Recon Loss:{r_loss.val:.4f},\t GAN Loss:{g_loss.val:.4f},\t D Loss:{d_loss.val:.4f}, \t Perc Loss:{p_loss.val:.4f}, \t Style Loss:{s_loss.val:.4f}" \
                            .format(epoch, i+1, len(dataloader), size=size, batch_time=batch_time, data_time=data_time, whole_loss=losses['whole_loss'], r_loss=losses['r_loss'] \
                            ,g_loss=losses['g_loss'], d_loss=losses['d_loss'], p_loss=losses['p_loss'], s_loss=losses['s_loss']))
                # Tensorboard logger for scaler and images
                info_terms = {
                    '{}WGLoss'.format(size): whole_loss.item(),
                    '{}ReconLoss'.format(size): r_loss.item(),
                    "{}GANLoss".format(size): g_loss.item(),
                    "{}DLoss".format(size): d_loss.item(),
                    "{}PercLoss".format(size): p_loss.item()
                }

                for tag, value in info_terms.items():
                    tensorboardlogger.scalar_summary(
                        tag, value,
                        epoch * len(dataloader) + i)

                for tag, value in losses.items():
                    tensorboardlogger.scalar_summary(
                        'avg_' + tag, value.avg,
                        epoch * len(dataloader) + i)

                def img2photo(imgs):
                    return ((imgs + 1) * 127.5).transpose(1, 2).transpose(
                        2, 3).detach().cpu().numpy()

                # info = { 'train/ori_imgs':img2photo(imgs),
                #          'train/coarse_imgs':img2photo(coarse_imgs),
                #          'train/recon_imgs':img2photo(recon_imgs),
                #          'train/comp_imgs':img2photo(complete_imgs),
                info = {
                    'train/{}whole_imgs'.format(size): img2photo(
                        torch.cat([imgs * (1 - masks), upsampled_imgs, recon_imgs,
                                   imgs, complete_imgs], dim=3))
                }

                for tag, images in info.items():
                    tensorboardlogger.image_summary(
                        tag, images,
                        epoch * len(dataloader) + i)
            end = time.time()
        if (i + 1) % config.VAL_SUMMARY_FREQ == 0 and val_datas is not None:
            # nets, loss_terms, opts and devices are assumed to be defined in
            # the enclosing scope
            validate(nets,
                     loss_terms,
                     opts,
                     val_datas,
                     epoch,
                     devices,
                     batch_n=i)
            netG.train()
            netD.train()
            netG.to(device0)
            netD.to(device0)
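
# Every example in this file relies on an AverageMeter helper that the
# snippets never define. A minimal sketch, assuming .out() returns the running
# average and resets the meter (the repository's actual class may differ):
class AverageMeter(object):
    """Tracks the most recent value and a running weighted average."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0   # most recent value
        self.sum = 0.0   # weighted sum of updates
        self.count = 0   # total weight (e.g. number of images)
        self.avg = 0.0   # running average

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / max(self.count, 1)

    def out(self):
        avg = self.avg
        self.reset()
        return avg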
Code example #13
def validate(netG,
             netD,
             GANLoss,
             ReconLoss,
             DLoss,
             optG,
             optD,
             dataloader,
             epoch,
             device=cuda0,
             batch_n="whole"):
    """
    validate phase
    """
    netG.to(device)
    netD.to(device)
    netG.eval()
    netD.eval()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = {
        "g_loss": AverageMeter(),
        "r_loss": AverageMeter(),
        "whole_loss": AverageMeter(),
        "d_loss": AverageMeter()
    }

    end = time.time()
    val_save_dir = os.path.join(
        result_dir, "val_{}_{}".format(
            epoch, batch_n if isinstance(batch_n, str) else batch_n + 1))
    val_save_real_dir = os.path.join(val_save_dir, "real")
    val_save_gen_dir = os.path.join(val_save_dir, "gen")
    val_save_raw_dir = os.path.join(val_save_dir, 'raw')
    # val_save_inf_dir = os.path.join(val_save_dir, "inf")
    if not os.path.exists(val_save_real_dir):
        os.makedirs(val_save_real_dir)
        os.makedirs(val_save_gen_dir)
        os.makedirs(val_save_raw_dir)
    info = {}

    for i, (imgs, masks, mean, std) in enumerate(dataloader):

        data_time.update(time.time() - end)
        masks = masks['random_free_form']
        #masks = (masks > 0).type(torch.FloatTensor)

        imgs, masks = imgs.to(device), masks.to(device)
        # imgs = (imgs / 127.5 - 1)
        # mask is 1 on masked region
        # forward
        coarse_imgs, recon_imgs = netG(imgs, masks)

        complete_imgs = recon_imgs * masks + imgs * (1 - masks)

        pos_imgs = torch.cat([imgs, masks, torch.full_like(masks, 1.)], dim=1)
        neg_imgs = torch.cat(
            [complete_imgs, masks,
             torch.full_like(masks, 1.)], dim=1)
        pos_neg_imgs = torch.cat([pos_imgs, neg_imgs], dim=0)

        pred_pos_neg = netD(pos_neg_imgs)
        pred_pos, pred_neg = torch.chunk(pred_pos_neg, 2, dim=0)

        g_loss = GANLoss(pred_neg)

        r_loss = ReconLoss(imgs, coarse_imgs, recon_imgs, masks)

        whole_loss = g_loss + r_loss

        # Update the recorder for losses
        losses['g_loss'].update(g_loss.item(), imgs.size(0))
        losses['r_loss'].update(r_loss.item(), imgs.size(0))
        losses['whole_loss'].update(whole_loss.item(), imgs.size(0))

        d_loss = DLoss(pred_pos, pred_neg)
        losses['d_loss'].update(d_loss.item(), imgs.size(0))
        # Update time recorder
        batch_time.update(time.time() - end)

        # Logger logging

        if i + 1 < config.STATIC_VIEW_SIZE:

            def img2photo(imgs):
                return (imgs * 255).transpose(1, 2).transpose(
                    2, 3).detach().cpu().numpy()
                # return ((imgs+1)*127.5).transpose(1,2).transpose(2,3).detach().cpu().numpy()

            # info = { 'val/ori_imgs':img2photo(imgs),
            #          'val/coarse_imgs':img2photo(coarse_imgs),
            #          'val/recon_imgs':img2photo(recon_imgs),
            #          'val/comp_imgs':img2photo(complete_imgs),
            def denorm(t):
                return (t * std.cpu().numpy()[0]) + mean.cpu().numpy()[0]

            info['val/whole_imgs/{}'.format(i)] = img2photo(
                torch.cat([denorm(imgs) * (1 - masks) + masks,
                           denorm(coarse_imgs), denorm(recon_imgs),
                           denorm(imgs), denorm(complete_imgs)], dim=3))

        else:
            logger.info("Validation Epoch {0}, [{1}/{2}]: Batch Time:{batch_time.val:.4f},\t Data Time:{data_time.val:.4f},\t Whole Gen Loss:{whole_loss.val:.4f}\t,"
                        "Recon Loss:{r_loss.val:.4f},\t GAN Loss:{g_loss.val:.4f},\t D Loss:{d_loss.val:.4f}"
                        .format(epoch, i+1, len(dataloader), batch_time=batch_time, data_time=data_time, whole_loss=losses['whole_loss'], r_loss=losses['r_loss'] \
                        ,g_loss=losses['g_loss'], d_loss=losses['d_loss']))

            # vis.line([[r_loss.item(), (1 / config.GAN_LOSS_ALPHA) * g_loss.item(), d_loss.item()]],
            #          epoch*len(dataloader)+i, win='validation_loss', update='append')
            j = 0
            for tag, images in info.items():
                h, w = images.shape[1], images.shape[2] // 5
                for val_img in images:
                    raw_img = val_img[:, 0:w, :]
                    # raw_img = ((raw_img - np.min(raw_img)) / np.max(raw_img)) * 255
                    real_img = val_img[:, (3 * w):(4 * w), :]
                    # real_img = ((real_img - np.min(real_img)) / np.max(real_img)) * 255
                    gen_img = val_img[:, (4 * w):, :]
                    # gen_img = ((gen_img - np.min(gen_img)) / np.max(gen_img)) * 255

                    cv2.imwrite(
                        os.path.join(val_save_real_dir, "{}.png".format(j)),
                        real_img)
                    cv2.imwrite(
                        os.path.join(val_save_gen_dir, "{}.png".format(j)),
                        gen_img)
                    cv2.imwrite(
                        os.path.join(val_save_raw_dir, "{}.png".format(j)),
                        raw_img)
                    j += 1
                # tensorboardlogger.image_summary(tag, images, epoch)
            path1, path2 = val_save_real_dir, val_save_gen_dir
            fid_score = metrics['fid']([path1, path2], cuda=False)
            ssim_score = metrics['ssim']([path1, path2])
            # vis.line([[fid_score.item(),ssim_score.item()]], [epoch*len(dataloader)+i], win='validation_metric', update='append')
            # tensorboardlogger.scalar_summary('val/fid', fid_score.item(), epoch*len(dataloader)+i)
            # tensorboardlogger.scalar_summary('val/ssim', ssim_score.item(), epoch*len(dataloader)+i)
            break

        end = time.time()
    # vis.line([[losses['r_loss'].out(), (1 / config.GAN_LOSS_ALPHA) * losses['g_loss'].out(), losses['d_loss'].out()]],
    #          [epoch], win='validation_loss', update='append')
    wandb.log({
        "val_r_loss": losses['r_loss'].out(),
        "val_g_loss": (1 / config.GAN_LOSS_ALPHA) * losses['g_loss'].out(),
        "val_d_loss": losses['d_loss'].out()
    })
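
# DLoss(pred_pos, pred_neg) and GANLoss(pred_neg) are called throughout these
# examples but never defined. A minimal sketch of SN-PatchGAN-style hinge
# losses in the spirit of "Free-Form Image Inpainting with Gated Convolution";
# the repository's own classes may weight or reduce them differently:
import torch
import torch.nn as nn


class SNDisLoss(nn.Module):
    """Hinge loss for the spectral-normalized patch discriminator."""

    def forward(self, pos, neg):
        # Push real patch scores above +1 and fake patch scores below -1
        return torch.mean(torch.relu(1. - pos)) + torch.mean(torch.relu(1. + neg))


class SNGenLoss(nn.Module):
    """Generator adversarial loss: raise the discriminator score of fakes."""

    def __init__(self, weight=1.0):
        super().__init__()
        self.weight = weight

    def forward(self, neg):
        return -self.weight * torch.mean(neg)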
Code example #14
def train(netG,
          netD,
          GANLoss,
          ReconLoss,
          DLoss,
          optG,
          optD,
          dataloader,
          epoch,
          device=cuda0,
          val_datas=None):
    """
    Train Phase, for training and spectral normalization patch gan in
    Free-Form Image Inpainting with Gated Convolution (snpgan)

    """
    # wandb.watch(netG, netD)
    netG.to(device)
    netD.to(device)
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = {
        "g_loss": AverageMeter(),
        "r_loss": AverageMeter(),
        "whole_loss": AverageMeter(),
        'd_loss': AverageMeter()
    }

    netG.train()
    netD.train()
    end = time.time()

    for i, (imgs, masks, mean, std) in enumerate(dataloader):
        data_time.update(time.time() - end)
        masks = masks['random_free_form']

        # Optimize Discriminator
        optD.zero_grad(), netD.zero_grad(), netG.zero_grad(), optG.zero_grad()

        imgs, masks = imgs.to(device), masks.to(device)
        # imgs = (imgs / 127.5 - 1)
        # mask is 1 on masked region

        coarse_imgs, recon_imgs = netG(imgs, masks)
        #print(attention.size(), )
        complete_imgs = recon_imgs * masks + imgs * (1 - masks)

        pos_imgs = torch.cat([imgs, masks, torch.full_like(masks, 1.)], dim=1)
        neg_imgs = torch.cat(
            [complete_imgs, masks,
             torch.full_like(masks, 1.)], dim=1)
        pos_neg_imgs = torch.cat([pos_imgs, neg_imgs], dim=0)

        pred_pos_neg = netD(pos_neg_imgs)
        pred_pos, pred_neg = torch.chunk(pred_pos_neg, 2, dim=0)
        d_loss = DLoss(pred_pos, pred_neg)
        losses['d_loss'].update(d_loss.item(), imgs.size(0))
        d_loss.backward(retain_graph=True)

        optD.step()

        # Optimize Generator
        optD.zero_grad(), netD.zero_grad(), optG.zero_grad(), netG.zero_grad()
        pred_neg = netD(neg_imgs)
        #pred_pos, pred_neg = torch.chunk(pred_pos_neg,  2, dim=0)
        g_loss = GANLoss(pred_neg)
        r_loss = ReconLoss(imgs, coarse_imgs, recon_imgs, masks)

        whole_loss = g_loss + r_loss

        # Update the recorder for losses
        losses['g_loss'].update(g_loss.item(), imgs.size(0))
        losses['r_loss'].update(r_loss.item(), imgs.size(0))
        losses['whole_loss'].update(whole_loss.item(), imgs.size(0))
        whole_loss.backward()

        optG.step()

        # Update time recorder
        batch_time.update(time.time() - end)

        if (i + 1) % config.SUMMARY_FREQ == 0:
            # Logger logging
            logger.info("Epoch {0}, [{1}/{2}]: Batch Time:{batch_time.val:.4f},\t Data Time:{data_time.val:.4f}, Whole Gen Loss:{whole_loss.val:.4f}\t,"
                        "Recon Loss:{r_loss.val:.4f},\t GAN Loss:{g_loss.val:.4f},\t D Loss:{d_loss.val:.4f}" \
                        .format(epoch, i+1, len(dataloader), batch_time=batch_time, data_time=data_time, whole_loss=losses['whole_loss'], r_loss=losses['r_loss'] \
                        ,g_loss=losses['g_loss'], d_loss=losses['d_loss']))
            # Tensorboard logger for scaler and images
            info_terms = {
                'ReconLoss': losses['r_loss'].val,
                "GANLoss": losses['g_loss'].val,
                "DLoss": d_loss.item()
            }

            # vis.line([[r_loss.item(), (1 / config.GAN_LOSS_ALPHA) * g_loss.item(), d_loss.item()]],
            #          [epoch*len(dataloader)+i], win='train_loss', update='append')

            # for tag, value in info_terms.items():
            #     tensorboardlogger.scalar_summary(tag, value, epoch*len(dataloader)+i)
            #
            # for tag, value in losses.items():
            #     tensorboardlogger.scalar_summary('avg_'+tag, value.avg, epoch*len(dataloader)+i)

            def img2photo(imgs):
                # return (((imgs*0.263)+0.472)*255).transpose(1,2).transpose(2,3).detach().cpu().numpy()
                return ((imgs + 1) * 127.5).transpose(1, 2).transpose(
                    2, 3).detach().cpu().numpy()

            # info = { 'train/ori_imgs':img2photo(imgs),
            #          'train/coarse_imgs':img2photo(coarse_imgs),
            #          'train/recon_imgs':img2photo(recon_imgs),
            #          'train/comp_imgs':img2photo(complete_imgs),
            info = {
                'train/whole_imgs': img2photo(
                    torch.cat([imgs * (1 - masks) + masks, coarse_imgs, recon_imgs,
                               imgs, complete_imgs], dim=3))
            }

            # for tag, images in info.items():
            #     tensorboardlogger.image_summary(tag, images, epoch*len(dataloader)+i)
        if (i + 1) % config.VAL_SUMMARY_FREQ == 0 and val_datas is not None:

            validate(netG,
                     netD,
                     GANLoss,
                     ReconLoss,
                     DLoss,
                     optG,
                     optD,
                     val_datas,
                     epoch,
                     device,
                     batch_n=i)
            netG.train()
            netD.train()
        end = time.time()
    # vis.line([[losses['r_loss'].out(), (1 / config.GAN_LOSS_ALPHA) * losses['g_loss'].out(), losses['d_loss'].out()]],
    #          [epoch], win='train_loss', update='append')
    wandb.log({
        "train_r_loss": losses['r_loss'].out(),
        "train_g_loss": (1 / config.GAN_LOSS_ALPHA) * losses['g_loss'].out(),
        "train_d_loss": losses['d_loss'].out()
    })
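
# ReconLoss(imgs, coarse_imgs, recon_imgs, masks) is the other loss these
# snippets assume. A minimal sketch of a masked/unmasked L1 reconstruction
# loss over both generator stages; the weighting constants are illustrative
# assumptions, not the repository's actual values:
import torch
import torch.nn as nn


class ReconLoss(nn.Module):

    def __init__(self, chole_alpha=1.2, cunhole_alpha=1.2,
                 rhole_alpha=1.2, runhole_alpha=1.2):
        super().__init__()
        self.chole_alpha = chole_alpha      # coarse output, hole region
        self.cunhole_alpha = cunhole_alpha  # coarse output, valid region
        self.rhole_alpha = rhole_alpha      # refined output, hole region
        self.runhole_alpha = runhole_alpha  # refined output, valid region

    def forward(self, imgs, coarse_imgs, recon_imgs, masks):
        return (self.rhole_alpha * torch.mean(torch.abs(imgs - recon_imgs) * masks)
                + self.runhole_alpha * torch.mean(torch.abs(imgs - recon_imgs) * (1 - masks))
                + self.chole_alpha * torch.mean(torch.abs(imgs - coarse_imgs) * masks)
                + self.cunhole_alpha * torch.mean(torch.abs(imgs - coarse_imgs) * (1 - masks)))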
Code example #15
def train(netG, netD, GANLoss, ReconLoss, DLoss, NLoss, optG, optD, dataloader, epoch, device=cuda0, val_datas=None):
    """
    Train Phase, for training and spectral normalization patch gan in
    Free-Form Image Inpainting with Gated Convolution (snpgan)

    """
    netG.to(device)
    netD.to(device)
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = {"g_loss":AverageMeter(), "r_loss":AverageMeter(), "whole_loss":AverageMeter(), "d_loss":AverageMeter(), 'n_loss': AverageMeter()}

    netG.train()
    netD.train()
    end = time.time()
    for i, (imgs, masks, gray) in enumerate(dataloader):
        data_time.update(time.time() - end)
        masks = masks['mine']
        # masks = masks['random_free_form']

        # Optimize Discriminator
        optD.zero_grad(), netD.zero_grad(), netG.zero_grad(), optG.zero_grad()

        imgs, masks, gray = imgs.to(device), masks.to(device), gray.to(device)
        # print(imgs.shape)
        masks = 1 - masks / 255.0 
        # masks = masks / 255.0 
        # 1 for masks, areas with holes
        # print(masks.min(), masks.max())
        imgs = (imgs / 127.5 - 1)
        gray = (gray / 127.5 - 1)
        # mask is 1 on masked region

        coarse_imgs, refined, mixed = netG(gray, masks)
        # coarse_imgs, mixed = netG(imgs, masks)
        # coarse_imgs, mixed, attention = netG(imgs, masks)
        #print(attention.size(), )
        # complete_imgs = mixed * masks + imgs * (1 - masks)
        complete_imgs = mixed # * masks + imgs * (1 - masks)

        pos_imgs = imgs
        neg_imgs = complete_imgs
        # pos_imgs = torch.cat([imgs, masks, torch.full_like(masks, 1.)], dim=1)
        # neg_imgs = torch.cat([complete_imgs, masks, torch.full_like(masks, 1.)], dim=1)
        pos_neg_imgs = torch.cat([pos_imgs, neg_imgs], dim=0)

        pred_pos_neg = netD(pos_neg_imgs)
        pred_pos, pred_neg = torch.chunk(pred_pos_neg, 2, dim=0)
        d_loss = DLoss(pred_pos, pred_neg)
        losses['d_loss'].update(d_loss.item(), imgs.size(0))
        d_loss.backward(retain_graph=True)

        optD.step()


        # Optimize Generator
        optD.zero_grad(), netD.zero_grad()
        optG.zero_grad(), netG.zero_grad()
        pred_neg = netD(neg_imgs)
        # pred_pos, pred_neg = torch.chunk(pred_pos_neg, 2, dim=0)  # would overwrite the fresh pred_neg above
        g_loss = GANLoss(pred_neg)
        r_loss = ReconLoss(imgs, coarse_imgs, mixed, masks)
        n_loss = NLoss(coarse_imgs, refined, mixed, imgs)

        # whole_loss = r_loss + n_loss
        whole_loss = g_loss + r_loss + n_loss

        # Update the recorder for losses
        losses['g_loss'].update(g_loss.item(), imgs.size(0))
        losses['r_loss'].update(r_loss.item(), imgs.size(0))
        losses['n_loss'].update(n_loss.item(), imgs.size(0))
        losses['whole_loss'].update(whole_loss.item(), imgs.size(0))

        whole_loss.backward()

        optG.step()

        # print('w?', imgs.min(), imgs.max())

        # Update time recorder
        batch_time.update(time.time() - end)

        # print(((imgs+1)*127.5).min(), ((imgs+1)*127.5).max())
        if (i+1) % config.SUMMARY_FREQ == 0:
            # Logger logging
            logger.info("Epoch {0}, [{1}/{2}]: Batch Time:{batch_time.val:.4f},\t Data Time:{data_time.val:.4f}, Whole Gen Loss:{whole_loss.val:.4f}\t,"
                        "Recon Loss:{r_loss.val:.4f}, \t New Loss: {n_loss.val:.4f}, \t D Loss: {d_loss.val:.4f}"
                        .format(epoch, i+1, len(dataloader), batch_time=batch_time, data_time=data_time, whole_loss=losses['whole_loss'],
                                r_loss=losses['r_loss'], g_loss=losses['g_loss'], d_loss=losses['d_loss'], n_loss=losses['n_loss']))
            # Tensorboard logger for scaler and images
            # info_terms = {'WGLoss':whole_loss.item(), 'ReconLoss':r_loss.item()}
            info_terms = {'WGLoss':whole_loss.item(), 'ReconLoss':r_loss.item(), "GANLoss":g_loss.item(), "DLoss":d_loss.item()}

            for tag, value in info_terms.items():
                tensorboardlogger.scalar_summary(tag, value, epoch*len(dataloader)+i)

            for tag, value in losses.items():
                tensorboardlogger.scalar_summary('avg_'+tag, value.avg, epoch*len(dataloader)+i)

            def img2photo(imgs):
                # return ((imgs+1)*127.5).detach().cpu().numpy()
                return ((imgs+1)*127.5).transpose(1,2).transpose(2,3).detach().cpu().numpy()
            # info = { 'train/ori_imgs':img2photo(imgs),
            #          'train/coarse_imgs':img2photo(coarse_imgs),
            #          'train/mixed':img2photo(mixed),
            #          'train/comp_imgs':img2photo(complete_imgs),
            info = {
                     'train/whole_imgs':img2photo(torch.cat([imgs * (1 - masks) + masks, refined, imgs * masks, complete_imgs, imgs], dim=3))
                     }

            for tag, images in info.items():
                tensorboardlogger.image_summary(tag, images, epoch*len(dataloader)+i)
        if (i+1) % config.VAL_SUMMARY_FREQ == 0 and val_datas is not None:

            with torch.no_grad():
                validate(netG, netD, GANLoss, ReconLoss, DLoss, optG, optD, val_datas , epoch, device, batch_n=i)
            netG.train()
            # netD.train()
        end = time.time()
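
# The dataloaders above yield masks['random_free_form'] (and 'random_bbox'),
# generated inside the dataset class. A minimal sketch of a free-form stroke
# mask generator in the spirit of the gated-convolution paper; every parameter
# below is an illustrative assumption, not the dataset's actual setting:
import numpy as np
import cv2


def random_free_form_mask(h, w, max_strokes=8, max_vertex=12,
                          max_length=40, max_brush_width=20):
    """Returns a (1, h, w) float32 mask that is 1 on the masked region."""
    mask = np.zeros((h, w), np.float32)
    for _ in range(np.random.randint(1, max_strokes + 1)):
        x, y = int(np.random.randint(0, w)), int(np.random.randint(0, h))
        for _ in range(np.random.randint(1, max_vertex + 1)):
            angle = np.random.uniform(0, 2 * np.pi)
            length = np.random.uniform(5, max_length)
            brush = int(np.random.randint(5, max_brush_width + 1))
            nx = int(np.clip(x + length * np.cos(angle), 0, w - 1))
            ny = int(np.clip(y + length * np.sin(angle), 0, h - 1))
            cv2.line(mask, (x, y), (nx, ny), 1.0, brush)
            x, y = nx, ny
    return mask[None, :, :]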
Code example #16
def validate(netG,
             netD,
             GANLoss,
             ReconLoss,
             DLoss,
             optG,
             optD,
             dataloader,
             epoch,
             device=cuda0):
    """
    validate phase
    """
    netG.eval()
    netD.eval()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = {
        "g_loss": AverageMeter(),
        "r_loss": AverageMeter(),
        "whole_loss": AverageMeter(),
        "d_loss": AverageMeter()
    }

    end = time.time()
    for i, (imgs, masks) in enumerate(dataloader):
        data_time.update(time.time() - end)
        masks = masks['random_free_form']
        #masks = (masks > 0).type(torch.FloatTensor)

        imgs, masks = imgs.to(device), masks.to(device)
        imgs = (imgs / 127.5 - 1)
        # mask is 1 on masked region
        # forward
        coarse_imgs, recon_imgs = netG(imgs, masks)

        complete_imgs = recon_imgs * masks + imgs * (1 - masks)

        pos_imgs = torch.cat([imgs, masks, torch.full_like(masks, 1.)], dim=1)
        neg_imgs = torch.cat(
            [complete_imgs, masks,
             torch.full_like(masks, 1.)], dim=1)
        pos_neg_imgs = torch.cat([pos_imgs, neg_imgs], dim=0)

        pred_pos_neg = netD(pos_neg_imgs)
        pred_pos, pred_neg = torch.chunk(pred_pos_neg, 2, dim=0)

        g_loss = GANLoss(pred_neg)

        r_loss = ReconLoss(imgs, coarse_imgs, recon_imgs, masks)

        whole_loss = g_loss + r_loss

        # Update the recorder for losses
        losses['g_loss'].update(g_loss.item(), imgs.size(0))
        losses['r_loss'].update(r_loss.item(), imgs.size(0))
        losses['whole_loss'].update(whole_loss.item(), imgs.size(0))

        d_loss = DLoss(pred_pos, pred_neg)
        losses['d_loss'].update(d_loss.item(), imgs.size(0))
        # Update time recorder
        batch_time.update(time.time() - end)

        # Logger logging
        logger.info("Validation Epoch {0}, [{1}/{2}]: Batch Time:{batch_time.val:.4f},\t Data Time:{data_time.val:.4f},\t Whole Gen Loss:{whole_loss.val:.4f}\t,"
                    "Recon Loss:{r_loss.val:.4f},\t GAN Loss:{g_loss.val:.4f},\t D Loss:{d_loss.val:.4f}"
                    .format(epoch, i+1, len(dataloader), batch_time=batch_time, data_time=data_time, whole_loss=losses['whole_loss'], r_loss=losses['r_loss'] \
                    ,g_loss=losses['g_loss'], d_loss=losses['d_loss']))

        if i * config.BATCH_SIZE < config.STATIC_VIEW_SIZE:

            def img2photo(imgs):
                return ((imgs + 1) * 127.5).transpose(1, 2).transpose(
                    2, 3).detach().cpu().numpy()

            info = {
                'val/ori_imgs':
                img2photo(imgs),
                'val/coarse_imgs':
                img2photo(coarse_imgs),
                'val/recon_imgs':
                img2photo(recon_imgs),
                'val/comp_imgs':
                img2photo(complete_imgs),
                'val/whole_imgs':
                img2photo(
                    torch.cat([imgs, coarse_imgs, recon_imgs, complete_imgs],
                              dim=3))
            }

            for tag, images in info.items():
                tensorboardlogger.image_summary(tag, images, i)
        end = time.time()
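
# img2photo above converts (N, C, H, W) tensors in [-1, 1] into (N, H, W, C)
# arrays in [0, 255] through two transpose calls. An equivalent sketch using
# permute, which states the layout change in one step (torch is assumed to be
# imported as in the examples):
def img2photo_permute(imgs):
    # (N, C, H, W) in [-1, 1]  ->  (N, H, W, C) in [0, 255]
    return ((imgs + 1) * 127.5).permute(0, 2, 3, 1).detach().cpu().numpy()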
Code example #17
def train(netG,
          netD,
          GANLoss,
          ReconLoss,
          DLoss,
          optG,
          optD,
          dataloader,
          epoch,
          device=cuda0,
          val_datas=None):
    """
    Train Phase, for training and spectral normalization patch gan in
    Free-Form Image Inpainting with Gated Convolution (snpgan)

    """
    netG.to(device)
    netD.to(device)
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = {
        "g_loss": AverageMeter(),
        "r_loss": AverageMeter(),
        "whole_loss": AverageMeter(),
        'd_loss': AverageMeter()
    }

    netG.train()
    netD.train()
    end = time.time()
    for i, (imgs, masks) in enumerate(dataloader):
        data_time.update(time.time() - end)
        masks = masks['random_free_form']

        # Optimize Discriminator
        optD.zero_grad(), netD.zero_grad(), netG.zero_grad(), optG.zero_grad()
        # Build a Canny edge map for each image to guide the inpainting
        guide = []
        transform = transforms.Compose([transforms.ToPILImage()])
        for k in range(imgs.shape[0]):
            im = np.array(transform(imgs[k]))
            im = cv2.Canny(image=im, threshold1=20, threshold2=220)
            guide.append(im)
        guide = torch.from_numpy(np.stack(guide)).float()
        guide = guide[:, None, :, :]  # add a channel dim: (N, 1, H, W)
        imgs, masks, guide = imgs.to(device), masks.to(device), guide.to(device)

        imgs = (imgs / 127.5 - 1)
        # mask is 1 on masked region
        guide = guide / 255.0

        coarse_imgs, recon_imgs, attention = netG(imgs, masks, guide)
        # print(attention.size(), )
        complete_imgs = recon_imgs * masks + imgs * (1 - masks)

        pos_imgs = torch.cat([imgs, masks, torch.full_like(masks, 1.)], dim=1)
        neg_imgs = torch.cat(
            [complete_imgs, masks,
             torch.full_like(masks, 1.)], dim=1)
        pos_neg_imgs = torch.cat([pos_imgs, neg_imgs], dim=0)

        pred_pos_neg = netD(pos_neg_imgs)
        pred_pos, pred_neg = torch.chunk(pred_pos_neg, 2, dim=0)
        d_loss = DLoss(pred_pos, pred_neg)
        losses['d_loss'].update(d_loss.item(), imgs.size(0))
        d_loss.backward(retain_graph=True)

        optD.step()

        # Optimize Generator
        optD.zero_grad(), netD.zero_grad(), optG.zero_grad(), netG.zero_grad()
        pred_neg = netD(neg_imgs)
        # pred_pos, pred_neg = torch.chunk(pred_pos_neg,  2, dim=0)
        g_loss = GANLoss(pred_neg)
        r_loss = ReconLoss(imgs, coarse_imgs, recon_imgs, masks)

        whole_loss = g_loss + r_loss

        # Update the recorder for losses
        losses['g_loss'].update(g_loss.item(), imgs.size(0))
        losses['r_loss'].update(r_loss.item(), imgs.size(0))
        losses['whole_loss'].update(whole_loss.item(), imgs.size(0))
        whole_loss.backward()

        optG.step()

        # Update time recorder
        batch_time.update(time.time() - end)

        if (i + 1) % config.SUMMARY_FREQ == 0:
            # Logger logging
            logger.info(
                "Epoch {0}, [{1}/{2}]: Batch Time:{batch_time.val:.4f},\t Data Time:{data_time.val:.4f}, Whole Gen Loss:{whole_loss.val:.4f}\t,"
                "Recon Loss:{r_loss.val:.4f},\t GAN Loss:{g_loss.val:.4f},\t D Loss:{d_loss.val:.4f}" \
                .format(epoch, i + 1, len(dataloader), batch_time=batch_time, data_time=data_time,
                        whole_loss=losses['whole_loss'], r_loss=losses['r_loss'] \
                        , g_loss=losses['g_loss'], d_loss=losses['d_loss']))
            # Tensorboard logger for scaler and images
            info_terms = {
                'WGLoss': whole_loss.item(),
                'ReconLoss': r_loss.item(),
                "GANLoss": g_loss.item(),
                "DLoss": d_loss.item()
            }

            for tag, value in info_terms.items():
                tensorboardlogger.scalar_summary(tag, value,
                                                 epoch * len(dataloader) + i)

            for tag, value in losses.items():
                tensorboardlogger.scalar_summary('avg_' + tag, value.avg,
                                                 epoch * len(dataloader) + i)

            def img2photo(imgs):
                return ((imgs + 1) * 127.5).transpose(1, 2).transpose(
                    2, 3).detach().cpu().numpy()

            # info = { 'train/ori_imgs':img2photo(imgs),
            #          'train/coarse_imgs':img2photo(coarse_imgs),
            #          'train/recon_imgs':img2photo(recon_imgs),
            #          'train/comp_imgs':img2photo(complete_imgs),
            info = {
                'train/whole_imgs': img2photo(
                    torch.cat([imgs * (1 - masks), coarse_imgs, recon_imgs,
                               imgs, complete_imgs], dim=3))
            }

            for tag, images in info.items():
                tensorboardlogger.image_summary(tag, images,
                                                epoch * len(dataloader) + i)
        if (i + 1) % config.VAL_SUMMARY_FREQ == 0 and val_datas is not None:
            validate(netG,
                     netD,
                     GANLoss,
                     ReconLoss,
                     DLoss,
                     optG,
                     optD,
                     val_datas,
                     epoch,
                     device,
                     batch_n=i)
            netG.train()
            netD.train()
        end = time.time()
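
# The SN-PatchGAN discriminator input built above concatenates the image, the
# mask, and a constant ones channel along dim=1. A quick self-contained shape
# check (a sketch assuming 256x256 RGB batches):
import torch

imgs = torch.randn(4, 3, 256, 256)          # normalized images
masks = torch.rand(4, 1, 256, 256).round()  # binary masks, 1 = hole
pos = torch.cat([imgs, masks, torch.full_like(masks, 1.)], dim=1)
assert pos.shape == (4, 5, 256, 256)        # 3 (RGB) + 1 (mask) + 1 (ones)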
Code example #18
def train(nets, loss_terms, opts, dataloader, epoch, devices=(cuda0, cuda1), val_datas=None):
    """
    Train Phase, for training and spectral normalization patch gan in
    Free-Form Image Inpainting with Gated Convolution (snpgan)

    """
    netG, netD = nets["netG"], nets["netD"]
    GANLoss, ReconLoss, L1ReconLoss, DLoss = loss_terms["GANLoss"], loss_terms["ReconLoss"], loss_terms["L1ReconLoss"], loss_terms["DLoss"]
    optG, optD = opts["optG"], opts["optD"]
    device0, device1 = devices[0], devices[1]
    netG.to(device0)
    netD.to(device0)
    # maskNetD.to(device1)
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = {"g_loss":AverageMeter(), "r_loss":AverageMeter(), "r_ex_loss":AverageMeter(), "whole_loss":AverageMeter(), 'd_loss':AverageMeter(),}
              # 'mask_d_loss':AverageMeter(), 'mask_rec_loss':AverageMeter(),'mask_whole_loss':AverageMeter()}

    netG.train()
    netD.train()
    # maskNetD.train()
    end = time.time()
    for i, data in enumerate(dataloader):
        data_time.update(time.time() - end)
        imgs, img_exs, masks = data
        masks = masks['random_free_form']

        # Optimize Discriminator
        optD.zero_grad(), netD.zero_grad(), netG.zero_grad(), optG.zero_grad()

        imgs, img_exs, masks = imgs.to(device0), img_exs.to(device0), masks.to(device0)
        imgs = (imgs / 127.5 - 1)
        img_exs = (img_exs / 127.5 - 1)
        # mask is 1 on masked region
        coarse_imgs, recon_imgs, recon_ex_imgs = netG(imgs, img_exs, masks)

        complete_imgs = recon_imgs * masks + imgs * (1 - masks)

        pos_imgs = torch.cat([imgs, masks, torch.full_like(masks, 1.)], dim=1)
        neg_imgs = torch.cat([complete_imgs, masks, torch.full_like(masks, 1.)], dim=1)
        pos_neg_imgs = torch.cat([pos_imgs, neg_imgs], dim=0)
        #mask_pos_neg_imgs = torch.cat([imgs, complete_imgs], dim=0)

        # Discriminator Loss
        pred_pos_neg = netD(pos_neg_imgs)
        pred_pos, pred_neg = torch.chunk(pred_pos_neg, 2, dim=0)
        d_loss = DLoss(pred_pos, pred_neg)
        d_loss.backward(retain_graph=True)
        optD.step()


        # Mask Discriminator Loss
        # mask_pos_neg_imgs = mask_pos_neg_imgs.to(device1)
        # masks = masks.to(device1)
        # mask_pred_pos_neg = maskNetD(mask_pos_neg_imgs)
        # mask_pred_pos, mask_pred_neg = torch.chunk(mask_pred_pos_neg, 2, dim=0)
        # mask_d_loss = DLoss(mask_pred_pos*masks , mask_pred_neg*masks )
        # mask_rec_loss = L1ReconLoss(mask_pred_neg, masks, masks=masks)

        losses['d_loss'].update(d_loss.item(), imgs.size(0))
        # losses['mask_d_loss'].update(mask_d_loss.item(), imgs.size(0))
        # losses['mask_rec_loss'].update(mask_rec_loss.item(), imgs.size(0))
        # mask_whole_loss = mask_rec_loss
        # losses['mask_whole_loss'].update(mask_whole_loss.item(), imgs.size(0))
        # mask_whole_loss.backward(retain_graph=True)
        # optMD.step()


        # Optimize Generator
        # masks = masks.to(device0)
        optD.zero_grad(), netD.zero_grad(), optG.zero_grad(), netG.zero_grad()  # optMD.zero_grad(), maskNetD.zero_grad()
        pred_neg = netD(neg_imgs)
        #pred_pos, pred_neg = torch.chunk(pred_pos_neg,  2, dim=0)
        g_loss = GANLoss(pred_neg)
        r_loss = ReconLoss(imgs, coarse_imgs, recon_imgs, masks)
        r_ex_loss = L1ReconLoss(img_exs, recon_ex_imgs)

        whole_loss = g_loss + r_loss + r_ex_loss

        # Update the recorder for losses
        losses['g_loss'].update(g_loss.item(), imgs.size(0))
        losses['r_loss'].update(r_loss.item(), imgs.size(0))
        losses['r_ex_loss'].update(r_ex_loss.item(), imgs.size(0))
        losses['whole_loss'].update(whole_loss.item(), imgs.size(0))
        whole_loss.backward()

        optG.step()

        # Update time recorder
        batch_time.update(time.time() - end)

        if (i+1) % config.SUMMARY_FREQ == 0:
            # Logger logging
            logger.info("Epoch {0}, [{1}/{2}]: Batch Time:{batch_time.val:.4f},\t Data Time:{data_time.val:.4f}, Whole Gen Loss:{whole_loss.val:.4f}\t,"
                        "Recon Loss:{r_loss.val:.4f}, \t Ex Recon Loss:{r_ex_loss.val:.4f},\t GAN Loss:{g_loss.val:.4f},\t D Loss:{d_loss.val:.4f}, " \
                        .format(epoch, i+1, len(dataloader), batch_time=batch_time, data_time=data_time \
                                ,whole_loss=losses['whole_loss'], r_loss=losses['r_loss'], r_ex_loss=losses['r_ex_loss'] \
                        ,g_loss=losses['g_loss'], d_loss=losses['d_loss']))
            # Tensorboard logger for scaler and images
            info_terms = {'WGLoss':whole_loss.item(), 'ReconLoss':r_loss.item(), "GANLoss":g_loss.item(), "DLoss":d_loss.item(), }

            for tag, value in info_terms.items():
                tensorboardlogger.scalar_summary(tag, value, epoch*len(dataloader)+i)

            for tag, value in losses.items():
                tensorboardlogger.scalar_summary('avg_'+tag, value.avg, epoch*len(dataloader)+i)

            def img2photo(imgs):
                return ((imgs+1)*127.5).transpose(1,2).transpose(2,3).detach().cpu().numpy()

            info = {
                     'train/whole_imgs':img2photo(torch.cat([imgs * (1 - masks), coarse_imgs, recon_imgs, imgs, complete_imgs], dim=3))
                     }

            for tag, images in info.items():
                tensorboardlogger.image_summary(tag, images, epoch*len(dataloader)+i)

        if (i+1) % config.VAL_SUMMARY_FREQ == 0 and val_datas is not None:

            validate(nets, loss_terms, opts, val_datas , epoch, devices, batch_n=i)
            netG.train()
            netD.train()
            #maskNetD.train()
        end = time.time()
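
# This variant receives its models, losses, and optimizers packed into dicts.
# A hedged sketch of the expected packing, with tiny stand-in modules so the
# snippet runs; the real netG/netD and loss objects come from the repository:
import torch
import torch.nn as nn

netG, netD = nn.Linear(1, 1), nn.Linear(1, 1)  # stand-ins for the real models
nets = {"netG": netG, "netD": netD}
loss_terms = {"GANLoss": None, "ReconLoss": None,  # the repository's loss modules
              "L1ReconLoss": None, "DLoss": None}
opts = {"optG": torch.optim.Adam(netG.parameters(), lr=1e-4),
        "optD": torch.optim.Adam(netD.parameters(), lr=4e-4)}
# train(nets, loss_terms, opts, dataloader, epoch, devices=(cuda0, cuda1))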
Code example #19
def validate(netG,
             netD,
             GANLoss,
             ReconLoss,
             DLoss,
             optG,
             optD,
             dataloader,
             epoch,
             device=cuda0,
             batch_n="whole"):
    """
    validate phase
    """
    netG.to(device)
    netD.to(device)
    netG.eval()
    netD.eval()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = {
        "g_loss": AverageMeter(),
        "r_loss": AverageMeter(),
        "whole_loss": AverageMeter(),
        "d_loss": AverageMeter()
    }

    end = time.time()
    val_save_dir = os.path.join(
        result_dir, "val_{}_{}".format(
            epoch, batch_n if isinstance(batch_n, str) else batch_n + 1))
    val_save_real_dir = os.path.join(val_save_dir, "real")
    val_save_gen_dir = os.path.join(val_save_dir, "gen")
    val_save_inf_dir = os.path.join(val_save_dir, "inf")
    if not os.path.exists(val_save_real_dir):
        os.makedirs(val_save_real_dir)
        os.makedirs(val_save_gen_dir)
        os.makedirs(val_save_inf_dir)
    info = {}
    for i, (imgs, masks) in enumerate(dataloader):

        data_time.update(time.time() - end)
        masks = masks['val']
        # masks = (masks > 0).type(torch.FloatTensor)

        imgs, masks = imgs.to(device), masks.to(device)
        imgs = (imgs / 127.5 - 1)
        # mask is 1 on masked region
        # forward
        coarse_imgs, recon_imgs, attention = netG.forward(imgs, masks)

        complete_imgs = recon_imgs * masks + imgs * (1 - masks)

        pos_imgs = torch.cat([imgs, masks, torch.full_like(masks, 1.)], dim=1)
        neg_imgs = torch.cat(
            [complete_imgs, masks,
             torch.full_like(masks, 1.)], dim=1)
        pos_neg_imgs = torch.cat([pos_imgs, neg_imgs], dim=0)

        pred_pos_neg = netD(pos_neg_imgs)
        pred_pos, pred_neg = torch.chunk(pred_pos_neg, 2, dim=0)

        g_loss = GANLoss(pred_neg)

        r_loss = ReconLoss(imgs, coarse_imgs, recon_imgs, masks)

        whole_loss = g_loss + r_loss

        # Update the recorder for losses
        losses['g_loss'].update(g_loss.item(), imgs.size(0))
        losses['r_loss'].update(r_loss.item(), imgs.size(0))
        losses['whole_loss'].update(whole_loss.item(), imgs.size(0))

        d_loss = DLoss(pred_pos, pred_neg)
        losses['d_loss'].update(d_loss.item(), imgs.size(0))
        # Update time recorder
        batch_time.update(time.time() - end)

        # Logger logging

        if i + 1 < config.STATIC_VIEW_SIZE:

            def img2photo(imgs):
                return ((imgs + 1) * 127.5).transpose(1, 2).transpose(
                    2, 3).detach().cpu().numpy()

            # info = { 'val/ori_imgs':img2photo(imgs),
            #          'val/coarse_imgs':img2photo(coarse_imgs),
            #          'val/recon_imgs':img2photo(recon_imgs),
            #          'val/comp_imgs':img2photo(complete_imgs),
            info['val/whole_imgs/{}'.format(i)] = img2photo(
                torch.cat([imgs * (1 - masks), coarse_imgs, recon_imgs,
                           imgs, complete_imgs], dim=3))

        else:
            logger.info(
                "Validation Epoch {0}, [{1}/{2}]: Batch Time:{batch_time.val:.4f},\t Data Time:{data_time.val:.4f},\t Whole Gen Loss:{whole_loss.val:.4f}\t,"
                "Recon Loss:{r_loss.val:.4f},\t GAN Loss:{g_loss.val:.4f},\t D Loss:{d_loss.val:.4f}"
                .format(epoch, i + 1, len(dataloader), batch_time=batch_time, data_time=data_time,
                        whole_loss=losses['whole_loss'], r_loss=losses['r_loss'] \
                        , g_loss=losses['g_loss'], d_loss=losses['d_loss']))
            j = 0
            for tag, images in info.items():
                h, w = images.shape[1], images.shape[2] // 5
                for val_img in images:
                    real_img = val_img[:, (3 * w):(4 * w), :]
                    gen_img = val_img[:, (4 * w):, :]
                    real_img = Image.fromarray(real_img.astype(np.uint8))
                    gen_img = Image.fromarray(gen_img.astype(np.uint8))
                    real_img.save(
                        os.path.join(val_save_real_dir, "{}.png".format(j)))
                    gen_img.save(
                        os.path.join(val_save_gen_dir, "{}.png".format(j)))
                    j += 1
                tensorboardlogger.image_summary(tag, images, epoch)
            path1, path2 = val_save_real_dir, val_save_gen_dir
            fid_score = metrics['fid']([path1, path2], cuda=False)
            ssim_score = metrics['ssim']([path1, path2])
            tensorboardlogger.scalar_summary('val/fid', fid_score.item(),
                                             epoch * len(dataloader) + i)
            tensorboardlogger.scalar_summary('val/ssim', ssim_score.item(),
                                             epoch * len(dataloader) + i)
            break

        end = time.time()
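
# metrics['ssim'] above is a callable over the two image directories written
# during validation. A minimal sketch of such a folder-level SSIM using
# scikit-image; the repository's own metric implementation may differ:
import os
import cv2
from skimage.metrics import structural_similarity


def folder_ssim(paths):
    real_dir, gen_dir = paths
    scores = []
    for name in sorted(os.listdir(real_dir)):
        real = cv2.imread(os.path.join(real_dir, name))
        gen = cv2.imread(os.path.join(gen_dir, name))
        if real is None or gen is None:
            continue
        scores.append(structural_similarity(real, gen, channel_axis=2))
    return sum(scores) / max(len(scores), 1)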
Code example #20
def train(netG, netD, GANLoss, ReconLoss, DLoss, optG, optD, dataloader, epoch,
          img_size, loss_writer):
    """
    Train Phase, for training and spectral normalization patch gan in
    Free-Form Image Inpainting with Gated Convolution (snpgan)

    """

    netG.train()
    netD.train()
    for i, (imgs, masks, _, _, _) in enumerate(dataloader):
        # masks = masks['val']

        # Optimize Discriminator
        optD.zero_grad(), netD.zero_grad(), netG.zero_grad(), optG.zero_grad()

        align_corners = True  # used only if the interpolation below is re-enabled
        # imgs = F.interpolate(imgs, img_size, mode='bicubic', align_corners=align_corners)
        # imgs = imgs.clamp(min=-1, max=1)
        # masks = F.interpolate(masks, img_size, mode='bicubic', align_corners=align_corners)
        # masks = (masks > 0).type(torch.FloatTensor)

        imgs, masks = imgs.cuda(), masks.cuda()

        coarse_imgs, recon_imgs = netG(imgs, masks)
        complete_imgs = recon_imgs * masks + imgs * (1 - masks)

        pos_imgs = torch.cat([imgs, masks, torch.full_like(masks, 1.)], dim=1)
        neg_imgs = torch.cat(
            [complete_imgs, masks,
             torch.full_like(masks, 1.)], dim=1)
        pos_neg_imgs = torch.cat([pos_imgs, neg_imgs], dim=0)

        pred_pos_neg = netD(pos_neg_imgs)
        pred_pos, pred_neg = torch.chunk(pred_pos_neg, 2, dim=0)
        d_loss = DLoss(pred_pos, pred_neg)
        # losses['d_loss'].update(d_loss.item(), imgs.size(0))
        d_loss_val = d_loss.item()
        d_loss.backward(retain_graph=True)

        optD.step()

        # Optimize Generator
        optD.zero_grad(), netD.zero_grad(), optG.zero_grad(), netG.zero_grad()
        pred_neg = netD(neg_imgs)
        #pred_pos, pred_neg = torch.chunk(pred_pos_neg,  2, dim=0)
        g_loss = GANLoss(pred_neg)
        r_loss = ReconLoss(imgs, coarse_imgs, recon_imgs, masks)

        whole_loss = g_loss + r_loss

        # Update the recorder for losses
        # losses['g_loss'].update(g_loss.item(), imgs.size(0))
        # losses['r_loss'].update(r_loss.item(), imgs.size(0))
        # losses['whole_loss'].update(whole_loss.item(), imgs.size(0))
        g_loss_val = g_loss.item()
        r_loss_val = r_loss.item()
        whole_loss_val = whole_loss.item()
        whole_loss.backward()

        optG.step()
        if (i + 1) % 25 == 0:
            print("Epoch {0} [{1}/{2}]:   Whole Loss:{whole_loss:.4f}   "
                        "Recon Loss:{r_loss:.4f}   GAN Loss:{g_loss:.4f}   D Loss:{d_loss:.4f}" \
                        .format(epoch, i+1, len(dataloader), whole_loss=whole_loss_val, r_loss=r_loss_val \
                        ,g_loss=g_loss_val, d_loss=d_loss_val))
        loss_writer.writerow(
            [epoch, whole_loss_val, r_loss_val, g_loss_val, d_loss_val])
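
# loss_writer above is expected to be a csv.writer. A minimal sketch of the
# setup, with a header row matching the writerow call above (the file name and
# header labels are assumptions):
import csv

with open('train_losses.csv', 'w', newline='') as f:
    loss_writer = csv.writer(f)
    loss_writer.writerow(['epoch', 'whole_loss', 'r_loss', 'g_loss', 'd_loss'])
    # train(netG, netD, GANLoss, ReconLoss, DLoss, optG, optD,
    #       dataloader, epoch, img_size, loss_writer)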