Ejemplo n.º 1
0
def validate(nets,
             loss_terms,
             opts,
             dataloader,
             epoch,
             network_type,
             devices=(cuda0, cuda1),
             batch_n="whole_test_show"):
    """Run one validation pass over ``dataloader`` at every size in TRAIN_SIZES.

    For each batch the generator inpaints the masked image at progressively
    larger sizes; the composited result of one size seeds the next
    (low-to-high refinement).  Per-size real/generated/composited images are
    written under ``result_dir`` and running loss averages are accumulated.

    Args:
        nets: dict holding "netD" (discriminator) and "netG" (generator).
        loss_terms: dict with "ReconLoss", "DLoss", "PercLoss", "GANLoss"
            and "StyleLoss" callables.
        opts: dict with "optG"/"optD" optimizers (unused here; kept so the
            signature matches the train phase).
        dataloader: yields (images, {"val": masks}); images in [0, 255].
        epoch: epoch index, used to name the output directory.
        network_type: one of "l2h_unet", "l2h_gated", "sa_gated"; selects
            the generator call convention.
        devices: (main device, device for the VGG-based perc/style losses).
        batch_n: label (or batch index) appended to the output dir name.

    Raises:
        ValueError: if ``network_type`` is not a recognized variant.
    """
    netD, netG = nets["netD"], nets["netG"]
    ReconLoss, DLoss, PercLoss, GANLoss, StyleLoss = loss_terms[
        'ReconLoss'], loss_terms['DLoss'], loss_terms["PercLoss"], loss_terms[
            "GANLoss"], loss_terms["StyleLoss"]
    optG, optD = opts['optG'], opts['optD']
    device0, device1 = devices
    netG.to(device0)
    netD.to(device0)
    # BUG FIX: the original flipped the nets back to train() mode right after
    # eval(); validation must stay in eval mode so BatchNorm/Dropout are
    # deterministic.
    netG.eval()
    netD.eval()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = {
        "g_loss": AverageMeter(),
        "p_loss": AverageMeter(),
        "s_loss": AverageMeter(),
        "r_loss": AverageMeter(),
        "whole_loss": AverageMeter(),
        "d_loss": AverageMeter()
    }

    end = time.time()
    val_save_dir = os.path.join(
        result_dir, "val_{}_{}".format(
            epoch, batch_n if isinstance(batch_n, str) else batch_n + 1))
    val_save_real_dir = os.path.join(val_save_dir, "real")
    val_save_gen_dir = os.path.join(val_save_dir, "gen")
    val_save_comp_dir = os.path.join(val_save_dir, "comp")
    for size in SIZES_TAGS:
        # exist_ok avoids the race between an exists() check and makedirs().
        os.makedirs(os.path.join(val_save_real_dir, size), exist_ok=True)
        os.makedirs(os.path.join(val_save_gen_dir, size), exist_ok=True)
        os.makedirs(os.path.join(val_save_comp_dir, size), exist_ok=True)
    info = {}
    t = 0
    # BUG FIX: no optimizer step happens here, so disable autograd to avoid
    # building (and keeping) computation graphs during validation.
    with torch.no_grad():
        for i, (ori_imgs, ori_masks) in enumerate(dataloader):
            data_time.update(time.time() - end)
            pre_imgs = ori_imgs
            # Inputs arrive in [0, 255]; the networks expect [-1, 1].
            pre_complete_imgs = (pre_imgs / 127.5 - 1)

            for s_i, size in enumerate(TRAIN_SIZES):
                masks = ori_masks['val']
                masks = F.interpolate(masks, size)
                # Re-binarize after interpolation; mask is 1 on the hole.
                masks = (masks > 0).type(torch.FloatTensor)
                imgs = F.interpolate(ori_imgs, size)
                if imgs.size(1) != 3:
                    print(t, imgs.size())
                pre_inter_imgs = F.interpolate(pre_complete_imgs, size)

                imgs, masks, pre_complete_imgs, pre_inter_imgs = imgs.to(
                    device0), masks.to(device0), pre_complete_imgs.to(
                        device0), pre_inter_imgs.to(device0)
                imgs = (imgs / 127.5 - 1)

                # Generator forward; the call convention depends on variant.
                if network_type == 'l2h_unet':
                    recon_imgs = netG(imgs, masks, pre_complete_imgs,
                                      pre_inter_imgs, size)
                elif network_type == 'l2h_gated':
                    recon_imgs = netG(imgs, masks, pre_inter_imgs)
                elif network_type == 'sa_gated':
                    recon_imgs, _ = netG(imgs, masks)
                else:
                    # BUG FIX: an unknown type previously fell through to a
                    # confusing NameError on recon_imgs below.
                    raise ValueError(
                        "unknown network_type: {}".format(network_type))
                # Keep the ground truth outside the hole, prediction inside.
                complete_imgs = recon_imgs * masks + imgs * (1 - masks)

                pos_imgs = torch.cat(
                    [imgs, masks, torch.full_like(masks, 1.)], dim=1)
                neg_imgs = torch.cat(
                    [recon_imgs, masks,
                     torch.full_like(masks, 1.)], dim=1)
                pos_neg_imgs = torch.cat([pos_imgs, neg_imgs], dim=0)

                # One discriminator pass over real+fake, then split back.
                pred_pos_neg = netD(pos_neg_imgs)
                pred_pos, pred_neg = torch.chunk(pred_pos_neg, 2, dim=0)

                g_loss = GANLoss(pred_neg)
                r_loss = ReconLoss(imgs, recon_imgs, recon_imgs, masks)

                # VGG-based losses run on device1 to spread memory load.
                imgs, recon_imgs, complete_imgs = imgs.to(
                    device1), recon_imgs.to(device1), complete_imgs.to(
                        device1)
                p_loss = PercLoss(imgs, recon_imgs) + PercLoss(
                    imgs, complete_imgs)
                s_loss = StyleLoss(imgs, recon_imgs) + StyleLoss(
                    imgs, complete_imgs)
                p_loss, s_loss = p_loss.to(device0), s_loss.to(device0)
                imgs, recon_imgs, complete_imgs = imgs.to(
                    device0), recon_imgs.to(device0), complete_imgs.to(
                        device0)

                whole_loss = r_loss + p_loss  # g_loss + r_loss

                # Update the recorders for losses.
                losses['g_loss'].update(g_loss.item(), imgs.size(0))
                losses['r_loss'].update(r_loss.item(), imgs.size(0))
                losses['p_loss'].update(p_loss.item(), imgs.size(0))
                losses['s_loss'].update(s_loss.item(), imgs.size(0))
                losses['whole_loss'].update(whole_loss.item(), imgs.size(0))

                d_loss = DLoss(pred_pos, pred_neg)
                losses['d_loss'].update(d_loss.item(), imgs.size(0))
                # The composited result seeds the next (larger) size.
                pre_complete_imgs = complete_imgs
                batch_time.update(time.time() - end)

                print(i, size)
                real_img = img2photo(imgs)
                gen_img = img2photo(recon_imgs)
                comp_img = img2photo(complete_imgs)

                real_img = Image.fromarray(real_img[0].astype(np.uint8))
                gen_img = Image.fromarray(gen_img[0].astype(np.uint8))
                comp_img = Image.fromarray(comp_img[0].astype(np.uint8))
                real_img.save(
                    os.path.join(val_save_real_dir, SIZES_TAGS[s_i],
                                 "{}.png".format(i)))
                gen_img.save(
                    os.path.join(val_save_gen_dir, SIZES_TAGS[s_i],
                                 "{}.png".format(i)))
                comp_img.save(
                    os.path.join(val_save_comp_dir, SIZES_TAGS[s_i],
                                 "{}.png".format(i)))

                end = time.time()
def main():
    """Train the inpainting GAN: build datasets, nets, losses and optimizers,
    then alternate train/validate for a fixed number of epochs, checkpointing
    after each epoch."""
    logger_init()
    dataset_type = config.DATASET
    batch_size = config.BATCH_SIZE

    # Dataset setting
    logger.info("Initialize the dataset...")
    train_dataset = InpaintDataset(config.DATA_FLIST[dataset_type][0],\
                                      {mask_type:config.DATA_FLIST[config.MASKDATASET][mask_type][0] for mask_type in config.MASK_TYPES}, \
                                      resize_shape=tuple(config.IMG_SHAPES), random_bbox_shape=config.RANDOM_BBOX_SHAPE, \
                                      random_bbox_margin=config.RANDOM_BBOX_MARGIN,
                                      random_ff_setting=config.RANDOM_FF_SETTING)
    train_loader = train_dataset.loader(batch_size=batch_size,
                                        shuffle=True,
                                        num_workers=16)

    val_dataset = InpaintDataset(config.DATA_FLIST[dataset_type][1],\
                                    {mask_type:config.DATA_FLIST[config.MASKDATASET][mask_type][1] for mask_type in ('val',)}, \
                                    resize_shape=tuple(config.IMG_SHAPES), random_bbox_shape=config.RANDOM_BBOX_SHAPE, \
                                    random_bbox_margin=config.RANDOM_BBOX_MARGIN,
                                    random_ff_setting=config.RANDOM_FF_SETTING)
    val_loader = val_dataset.loader(batch_size=1, shuffle=False, num_workers=1)

    # Collect a fixed, small set of validation samples (RGB only) so every
    # epoch validates on identical data.
    val_datas = []
    j = 0
    for i, data in enumerate(val_loader):
        if j < config.STATIC_VIEW_SIZE:
            imgs = data[0]
            if imgs.size(1) == 3:
                val_datas.append(data)
                j += 1
        else:
            break

    val_loader = val_dataset.loader(batch_size=1, shuffle=False, num_workers=1)
    logger.info("Finish the dataset initialization.")

    # Define the Network Structure
    logger.info("Define the Network Structure and Losses")
    netG = InpaintRUNNet(cuda0, n_in_channel=config.N_CHANNEL)
    netD = InpaintSADirciminator()
    netVGG = vgg16_bn(pretrained=True)
    sr_args = SRArgs(config.GPU_IDS[0])
    netSR = sr_model.Model(sr_args, sr_util.checkpoint(sr_args))

    if config.MODEL_RESTORE != '':
        whole_model_path = 'model_logs/{}'.format(config.MODEL_RESTORE)
        nets = torch.load(whole_model_path)
        netG_state_dict, netD_state_dict = nets['netG_state_dict'], nets[
            'netD_state_dict']
        netG.load_state_dict(netG_state_dict)
        netD.load_state_dict(netD_state_dict)
        logger.info("Loading pretrained models from {} ...".format(
            config.MODEL_RESTORE))

    # Define loss
    recon_loss = ReconLoss(*(config.L1_LOSS_ALPHA))
    gan_loss = SNGenLoss(config.GAN_LOSS_ALPHA)
    # VGG feature extractor lives on cuda1; see validate()'s device split.
    perc_loss = PerceptualLoss(weight=config.PERC_LOSS_ALPHA,
                               feat_extractors=netVGG.to(cuda1))
    style_loss = StyleLoss(weight=config.STYLE_LOSS_ALPHA,
                           feat_extractors=netVGG.to(cuda1))
    dis_loss = SNDisLoss()
    lr, decay = config.LEARNING_RATE, config.WEIGHT_DECAY
    optG = torch.optim.Adam(netG.parameters(), lr=lr, weight_decay=decay)
    # TTUR-style: discriminator learns at 4x the generator rate.
    optD = torch.optim.Adam(netD.parameters(), lr=4 * lr, weight_decay=decay)

    nets = {"netG": netG, "netD": netD, "vgg": netVGG, "netSR": netSR}

    losses = {
        "GANLoss": gan_loss,
        "ReconLoss": recon_loss,
        "StyleLoss": style_loss,
        "DLoss": dis_loss,
        "PercLoss": perc_loss
    }

    opts = {
        "optG": optG,
        "optD": optD,
    }

    logger.info("Finish Define the Network Structure and Losses")

    # Start Training
    logger.info("Start Training...")
    epoch = 50

    for i in range(epoch):
        # train data
        train(nets,
              losses,
              opts,
              train_loader,
              i,
              devices=(cuda0, cuda1),
              val_datas=val_datas)

        # validate
        # BUG FIX: validate() requires network_type as a positional argument;
        # the original call omitted it and would raise a TypeError.
        validate(nets,
                 losses,
                 opts,
                 val_datas,
                 i,
                 config.NETWORK_TYPE,
                 devices=(cuda0, cuda1))

        # NOTE(review): the nets are moved to cpu0 for serialization and not
        # moved back here — presumably train()/validate() re-.to() them.
        saved_model = {
            'epoch': i + 1,
            'netG_state_dict': netG.to(cpu0).state_dict(),
            'netD_state_dict': netD.to(cpu0).state_dict(),
        }
        torch.save(saved_model,
                   '{}/epoch_{}_ckpt.pth.tar'.format(log_dir, i + 1))
        # BUG FIX: dropped the unused second format() argument.
        torch.save(saved_model, '{}/latest_ckpt.pth.tar'.format(log_dir))
Ejemplo n.º 3
0
def main():
    """Validation-only entry point: restore a checkpointed generator and run
    a single validate() pass over the full validation set."""
    logger_init()
    dataset_type = config.DATASET
    batch_size = config.BATCH_SIZE

    # Dataset setting
    logger.info("Initialize the dataset...")
    val_dataset = InpaintDataset(config.DATA_FLIST[dataset_type][1], \
                                 {mask_type: config.DATA_FLIST[config.MASKDATASET][mask_type][1] for mask_type in
                                  ('val',)}, \
                                 resize_shape=tuple(config.IMG_SHAPES), random_bbox_shape=config.RANDOM_BBOX_SHAPE, \
                                 random_bbox_margin=config.RANDOM_BBOX_MARGIN,
                                 random_ff_setting=config.RANDOM_FF_SETTING)
    val_loader = val_dataset.loader(batch_size=1, shuffle=False, num_workers=1)

    logger.info("Finish the dataset initialization.")

    # Define the Network Structure
    logger.info("Define the Network Structure and Losses")
    whole_model_path = 'model_logs/{}'.format(config.MODEL_RESTORE)
    nets = torch.load(whole_model_path)
    netG_state_dict, netD_state_dict = nets['netG_state_dict'], nets[
        'netD_state_dict']
    if config.NETWORK_TYPE == "l2h_unet":
        netG = InpaintRUNNet(n_in_channel=config.N_CHANNEL)
        netG.load_state_dict(netG_state_dict)
    elif config.NETWORK_TYPE == 'sa_gated':
        netG = InpaintSANet()
        # Partial restore: only layers whose shapes match the checkpoint.
        load_consistent_state_dict(netG_state_dict, netG)
    else:
        # BUG FIX: an unsupported type previously left netG unbound and
        # crashed later with an unrelated NameError.
        raise ValueError(
            "unsupported NETWORK_TYPE: {}".format(config.NETWORK_TYPE))

    netD = InpaintSADirciminator()
    netVGG = vgg16_bn(pretrained=True)

    # NOTE(review): the discriminator weights are intentionally NOT restored
    # here (original had the load_state_dict call commented out) — confirm.
    logger.info("Loading pretrained models from {} ...".format(
        config.MODEL_RESTORE))

    # Define loss
    recon_loss = ReconLoss(*(config.L1_LOSS_ALPHA))
    gan_loss = SNGenLoss(config.GAN_LOSS_ALPHA)
    perc_loss = PerceptualLoss(weight=config.PERC_LOSS_ALPHA,
                               feat_extractors=netVGG.to(cuda1))
    style_loss = StyleLoss(weight=config.STYLE_LOSS_ALPHA,
                           feat_extractors=netVGG.to(cuda1))
    dis_loss = SNDisLoss()
    lr, decay = config.LEARNING_RATE, config.WEIGHT_DECAY
    optG = torch.optim.Adam(netG.parameters(), lr=lr, weight_decay=decay)
    optD = torch.optim.Adam(netD.parameters(), lr=4 * lr, weight_decay=decay)
    nets = {"netG": netG, "netD": netD, "vgg": netVGG}

    losses = {
        "GANLoss": gan_loss,
        "ReconLoss": recon_loss,
        "StyleLoss": style_loss,
        "DLoss": dis_loss,
        "PercLoss": perc_loss
    }
    opts = {
        "optG": optG,
        "optD": optD,
    }
    logger.info("Finish Define the Network Structure and Losses")

    # Start Validation
    logger.info("Start Validation")
    validate(nets,
             losses,
             opts,
             val_loader,
             0,
             config.NETWORK_TYPE,
             devices=(cuda0, cuda1))
Ejemplo n.º 4
0
# Frozen VGG16-BN backbone used as the feature extractor for the
# perceptual and style losses.
# NOTE(review): constructed WITHOUT pretrained weights here, unlike the
# vgg16_bn(pretrained=True) usage elsewhere in this project — confirm.
feature_extractor = vgg16_bn()
feature_extractor.eval()
feature_extractor.cuda()
if opt.multigpu:
    feature_extractor = nn.DataParallel(feature_extractor)
# Loss terms: weighted reconstruction, spectral-norm GAN losses, and
# VGG-feature perceptual/style losses plus total-variation smoothing.
recon_loss = TwoReconLoss(0.1, 0.05, 0.1, 0.05)
#recon_loss = ReconLoss(1, 1)
#ssim_loss = SSIM_Loss(1)
gan_loss = SNGenLoss(0.001)
dis_loss = SNDisLoss()
# Both feature losses share the same extractor and tap VGG layers
# 0, 5, 12 and 22.
perceptual_loss = PerceptualLoss(weight=0.1,
                                 layers=[0, 5, 12, 22],
                                 feat_extractors=feature_extractor)
style_loss = StyleLoss(weight=250,
                       layers=[0, 5, 12, 22],
                       feat_extractors=feature_extractor)
tv_loss = TVLoss(weight=0.1)
# Separate Adam optimizers for generator and discriminator; learning
# rates and betas come from the command-line options.
optG = torch.optim.Adam(netG.parameters(),
                        lr=opt.lr_g,
                        betas=(opt.b1, opt.b2),
                        weight_decay=opt.weight_decay)
optD = torch.optim.Adam(netD.parameters(),
                        lr=opt.lr_d,
                        betas=(opt.b1, opt.b2),
                        weight_decay=opt.weight_decay)


def img2photo(imgs):
    """Convert a network output tensor in [-1, 1] to a photo-range array.

    Rescales values to [0, 255] and moves the leading axis to the end via
    two successive transposes, so a 3-D (C, H, W) tensor comes back as an
    (H, W, C) numpy array.
    """
    scaled = (imgs + 1) * 127.5
    channels_last = scaled.transpose(0, 1).transpose(1, 2)
    return channels_last.detach().cpu().numpy()