Example #1
    def __init__(self, config):
        super().__init__()
        self.config = config
        if config['image_size'][0] != config['image_size'][1]:
            raise Exception('Non-square images are not supported yet.')

        device = config['device']
        self.downsampler_1024_256 = BicubicDownSample(4)
        self.downsampler_1024_image = BicubicDownSample(
            1024 // config['image_size'][0])
        self.downsampler_image_256 = BicubicDownSample(
            config['image_size'][0] // 256)

        # Load models and pre-trained weights
        gen = Generator(1024, 512, 8)
        gen.load_state_dict(torch.load(config["ckpt"])["g_ema"], strict=False)
        gen.eval()
        self.gen = gen.to(device)
        self.gen.start_layer = config['start_layer']
        self.gen.end_layer = config['end_layer']
        self.mpl = MappingProxy(torch.load('gaussian_fit.pt'))
        self.percept = lpips.PerceptualLoss(model="net-lin",
                                            net="vgg",
                                            use_gpu=device.startswith("cuda"))
        self.init_state()
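
A note on usage: the wrappers on this page all build a PerceptualLoss object once (the pre-1.0 lpips API) and then apply it like a loss module. A minimal, self-contained sketch of how a percept module like the one above is typically called; the tensors and sizes here are illustrative, not taken from the example:

import torch
import lpips

use_cuda = torch.cuda.is_available()
percept = lpips.PerceptualLoss(model="net-lin", net="vgg", use_gpu=use_cuda)

device = "cuda" if use_cuda else "cpu"
x = torch.rand(4, 3, 256, 256, device=device) * 2 - 1  # NCHW images in [-1, 1]
y = torch.rand(4, 3, 256, 256, device=device) * 2 - 1
d = percept(x, y)  # one perceptual distance per image pair
loss = d.sum()     # reduce before backward, as the examples on this page do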
Example #2
    def __init__(self, args):
        self.args = args

        self.init_vars()

        self.percept = lpips.PerceptualLoss(
            model='net-lin', net='vgg', use_gpu=self.args.device.startswith('cuda')
        )
Example #3
    def __init__(self,
                 cuda,
                 des="Learned Perceptual Image Patch Similarity",
                 version="0.1"):
        self.des = des
        self.version = version
        self.model = lpips.PerceptualLoss(model='net-lin',
                                          net='alex',
                                          use_gpu=cuda)
Example #4
    def __init__(self, config):
        super().__init__()
        self.config = config
        if config['image_size'][0] != config['image_size'][1]:
            raise Exception('Non-square images are not supported yet.')
        self.reconstruction = config["reconstruction"]
        self.steps = config["steps"]
        self.lr = config['lr']
        # self.mse_record = []

        transform = get_transformation(config['image_size'][0])
        images = []

        for imgfile in config['input_files']:
            images.append(transform(Image.open(imgfile).convert("RGB")))

        self.images = torch.stack(images, 0).to(config['device'])
        self.downsampler_256_image = BicubicDownSample(256 //
                                                       config['image_size'][0])

        biggan = BigGAN(config)
        biggan.load_pretrained()
        self.generator = biggan.generator
        self.generator.eval()
        self.generator.to(config['device'])

        (self.z, y) = prepare_z_y(self.images.shape[0],
                                  self.generator.dim_z,
                                  config["n_classes"],
                                  device=config["device"],
                                  fp16=config["G_fp16"],
                                  z_var=config["z_var"],
                                  target=config["target"],
                                  range=config["range"])
        self.y = self.generator.shared(y)

        self.z.requires_grad = True

        self.perceptual_loss = lpips.PerceptualLoss(
            model="net-lin",
            net="vgg",
            use_gpu=config['device'].startswith("cuda"))
Example #5
import sys
from os.path import join
from time import time
import matplotlib.pylab as plt
from matplotlib import cm
import torch
import numpy as np
sys.path.append(r"E:\Github_Projects\Visual_Neuro_InSilico_Exp")
sys.path.append(r"D:\Github\Visual_Neuro_InSilico_Exp")
import os
# os.system(r"'C:\Program Files (x86)\Microsoft Visual Studio\2019\Professional\VC\Auxiliary\Build\vcvars64.bat'")
import lpips
try:
    ImDist = lpips.LPIPS(net="squeeze").cuda()
except AttributeError:  # older lpips exposes PerceptualLoss instead of LPIPS
    ImDist = lpips.PerceptualLoss(net="squeeze").cuda()
from GAN_hessian_compute import hessian_compute, get_full_hessian
from hessian_analysis_tools import compute_vector_hess_corr, compute_hess_corr, plot_layer_consistency_mat
from GAN_utils import loadStyleGAN2, StyleGAN2_wrapper
#%%
modelname = "ffhq-256-config-e-003810"  # 109 sec
SGAN = loadStyleGAN2(modelname+".pt", size=256, channel_multiplier=1)  # 491 sec per BP
#%%
L2dist_col = []
def Hess_hook(module, fea_in, fea_out):
    print("hooker on %s"%module.__class__)
    ref_feat = fea_out.detach().clone()
    ref_feat.requires_grad_(False)
    L2dist = torch.pow(fea_out - ref_feat, 2).sum()
    L2dist_col.append(L2dist)
    return None
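
Hess_hook accumulates the squared distance between a layer's output and a frozen copy of itself: the value and first derivative vanish at the reference point, but the second derivative carries the Gauss-Newton curvature, which is what makes the hook useful for Hessian computation. A self-contained toy of the same pattern (the tiny Sequential net and hook target are illustrative, not the wrapper's real API):

import torch
import torch.nn as nn

feats = []
def toy_hook(module, fea_in, fea_out):
    ref = fea_out.detach().clone()                    # frozen reference activation
    feats.append(torch.pow(fea_out - ref, 2).sum())   # zero-valued, curvature-carrying

net = nn.Sequential(nn.Linear(8, 16), nn.Tanh(), nn.Linear(16, 4))
handle = net[1].register_forward_hook(toy_hook)

x = torch.randn(1, 8, requires_grad=True)
net(x)
handle.remove()

loss = sum(feats)                                     # analogous to summing L2dist_col
grad = torch.autograd.grad(loss, x, create_graph=True)[0]
hess = torch.stack([torch.autograd.grad(grad[0, i], x, retain_graph=True)[0][0]
                    for i in range(x.shape[1])])      # 8x8 Hessian at the reference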
Example #6
        return image

    def saveImage(self, image, name):
        # save image, mapping [-1, 1] pixel values back to [0, 1]
        utils.save_image(
            image,
            name + ".png",
            nrow=1,
            normalize=True,
            range=(-1, 1),
        )


if __name__ == "__main__":
    loss_fn_vgg = lpips.PerceptualLoss(net='vgg')
    with torch.no_grad():
        ckpt = "/common/users/sm2322/MS-Thesis/GAN-Thesis-Work-Remote/styleGAN2-AE-Ligong-Remote/trainedPts/gan/168000.pt"
        ckpt = torch.load(ckpt, map_location=lambda storage, loc: storage)
        device = "cuda"
        generator = Generator(128, 512, 8, 2).to(device)
        generator.load_state_dict(ckpt["g_ema"])
        generator.eval()
        # print(list(generator.children()))
        z0, z1 = torch.randn([1 * 2, 512], device=device).chunk(2)
        # sample_z = torch.randn(1, 512, device=device)

        genBlocks = GeneratorBlocks(generator)
        w0, stackW0 = genBlocks.G_Mapper(z0)
        w1, stackW1 = genBlocks.G_Mapper(z1)
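
        # A possible continuation (sketch; assumes the rosinality-style
        # Generator call used elsewhere on this page, returning (image, latent)):
        img0, _ = generator([z0])
        img1, _ = generator([z1])
        dist = loss_fn_vgg(img0, img1)  # LPIPS distance between the two samples
        print("LPIPS(z0, z1) =", dist.item())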
Example #7
            args.start_iter = int(ckpt_name.split('.')[0].split('-')[-1])

        except ValueError:
            pass

        generator.load_state_dict(ckpt['g'])
        discriminator.load_my_state_dict(ckpt['d'])
        g_ema.load_state_dict(ckpt['g_ema'])

        del ckpt
        torch.cuda.empty_cache()

    percept = None
    if args.with_gan_feat_loss:
        import lpips
        percept = lpips.PerceptualLoss(model='net-lin', net='vgg')

    if args.distributed:
        generator = nn.parallel.DistributedDataParallel(
            generator,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            broadcast_buffers=False,
            # find_unused_parameters=True,
        )

        discriminator = nn.parallel.DistributedDataParallel(
            discriminator,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            broadcast_buffers=False,
        )
Example #8
def train():
    from benchmark import calc_fid, extract_feature_from_generator_fn, load_patched_inception_v3, real_image_loader, image_generator, image_generator_perm
    import lpips

    from config import IM_SIZE_GAN, BATCH_SIZE_GAN, NFC, NBR_CLS, DATALOADER_WORKERS, EPOCH_GAN, ITERATION_AE, GAN_CKECKPOINT
    from config import SAVE_IMAGE_INTERVAL, SAVE_MODEL_INTERVAL, LOG_INTERVAL, SAVE_FOLDER, TRIAL_NAME, DATA_NAME, MULTI_GPU
    from config import FID_INTERVAL, FID_BATCH_NBR, PRETRAINED_AE_PATH
    from config import data_root_colorful, data_root_sketch_1, data_root_sketch_2, data_root_sketch_3

    real_features = None
    inception = load_patched_inception_v3().cuda()
    inception.eval()

    percept = lpips.PerceptualLoss(model='net-lin', net='vgg', use_gpu=True)

    saved_image_folder = saved_model_folder = None
    log_file_path = None
    if saved_image_folder is None:
        saved_image_folder, saved_model_folder = make_folders(
            SAVE_FOLDER, 'GAN_' + TRIAL_NAME)
        log_file_path = saved_image_folder + '/../gan_log.txt'
        log_file = open(log_file_path, 'w')
        log_file.close()

    dataset = PairedMultiDataset(data_root_colorful,
                                 data_root_sketch_1,
                                 data_root_sketch_2,
                                 data_root_sketch_3,
                                 im_size=IM_SIZE_GAN,
                                 rand_crop=True)
    print('the dataset contains %d images.' % len(dataset))
    dataloader = iter(
        DataLoader(dataset,
                   BATCH_SIZE_GAN,
                   sampler=InfiniteSamplerWrapper(dataset),
                   num_workers=DATALOADER_WORKERS,
                   pin_memory=True))

    from datasets import ImageFolder
    from datasets import trans_maker_augment as trans_maker

    dataset_rgb = ImageFolder(data_root_colorful, trans_maker(512))
    dataset_skt = ImageFolder(data_root_sketch_3, trans_maker(512))

    net_ae = AE(nfc=NFC, nbr_cls=NBR_CLS)

    if PRETRAINED_AE_PATH is None:
        PRETRAINED_AE_PATH = 'train_results/' + 'AE_' + TRIAL_NAME + '/models/%d.pth' % ITERATION_AE
    else:
        from config import PRETRAINED_AE_ITER
        PRETRAINED_AE_PATH = PRETRAINED_AE_PATH + '/models/%d.pth' % PRETRAINED_AE_ITER

    net_ae.load_state_dicts(PRETRAINED_AE_PATH)
    net_ae.cuda()
    net_ae.eval()

    RefineGenerator = None
    if DATA_NAME == 'celeba':
        from models import RefineGenerator_face as RefineGenerator
    elif DATA_NAME == 'art' or DATA_NAME == 'shoe':
        from models import RefineGenerator_art as RefineGenerator
    net_ig = RefineGenerator(nfc=NFC, im_size=IM_SIZE_GAN).cuda()
    # we use the patch GAN, so the input size for D stays 512 even if the
    # training image size is 1024
    net_id = Discriminator(nc=3).cuda()

    if MULTI_GPU:
        net_ae = nn.DataParallel(net_ae)
        net_ig = nn.DataParallel(net_ig)
        net_id = nn.DataParallel(net_id)

    net_ig_ema = copy_G_params(net_ig)

    opt_ig = optim.Adam(net_ig.parameters(), lr=2e-4, betas=(0.5, 0.999))
    opt_id = optim.Adam(net_id.parameters(), lr=2e-4, betas=(0.5, 0.999))

    if GAN_CKECKPOINT is not None:
        ckpt = torch.load(GAN_CKECKPOINT)
        net_ig.load_state_dict(ckpt['ig'])
        net_id.load_state_dict(ckpt['id'])
        net_ig_ema = ckpt['ig_ema']
        opt_ig.load_state_dict(ckpt['opt_ig'])
        opt_id.load_state_dict(ckpt['opt_id'])

    ## loss meters for logging
    losses_g_img = AverageMeter()
    losses_d_img = AverageMeter()
    losses_mse = AverageMeter()
    losses_rec_s = AverageMeter()

    losses_rec_ae = AverageMeter()

    fixed_skt = fixed_rgb = fixed_perm = None

    fid = [[0, 0]]

    for epoch in range(EPOCH_GAN):
        for iteration in tqdm(range(10000)):
            rgb_img, skt_img_1, skt_img_2, skt_img_3 = next(dataloader)

            rgb_img = rgb_img.cuda()

            rd = random.randint(0, 3)
            if rd == 0:
                skt_img = skt_img_1.cuda()
            elif rd == 1:
                skt_img = skt_img_2.cuda()
            else:
                skt_img = skt_img_3.cuda()

            if iteration == 0:
                fixed_skt = skt_img_3[:8].clone().cuda()
                fixed_rgb = rgb_img[:8].clone()
                fixed_perm = true_randperm(fixed_rgb.shape[0], 'cuda')

            ### 1. train D
            gimg_ae, style_feats = net_ae(skt_img, rgb_img)
            g_image = net_ig(gimg_ae, style_feats)

            pred_r = net_id(rgb_img)
            pred_f = net_id(g_image.detach())

            loss_d = d_hinge_loss(pred_r, pred_f)

            net_id.zero_grad()
            loss_d.backward()
            opt_id.step()

            loss_rec_ae = F.mse_loss(gimg_ae, rgb_img) + F.l1_loss(
                gimg_ae, rgb_img)
            losses_rec_ae.update(loss_rec_ae.item(), BATCH_SIZE_GAN)

            ### 2. train G
            pred_g = net_id(g_image)
            loss_g = g_hinge_loss(pred_g)

            if DATA_NAME == 'shoe':
                loss_mse = 10 * (F.l1_loss(g_image, rgb_img) +
                                 F.mse_loss(g_image, rgb_img))
            else:
                loss_mse = 10 * percept(
                    F.adaptive_avg_pool2d(g_image, output_size=256),
                    F.adaptive_avg_pool2d(rgb_img, output_size=256)).sum()
            losses_mse.update(loss_mse.item() / BATCH_SIZE_GAN, BATCH_SIZE_GAN)

            loss_all = loss_g + loss_mse

            if DATA_NAME == 'shoe':
                ### the grey image reconstruction
                perm = true_randperm(BATCH_SIZE_GAN)
                img_ae_perm, style_feats_perm = net_ae(skt_img, rgb_img[perm])

                gimg_grey = net_ig(img_ae_perm, style_feats_perm)
                gimg_grey = gimg_grey.mean(dim=1, keepdim=True)
                real_grey = rgb_img.mean(dim=1, keepdim=True)
                loss_rec_grey = F.mse_loss(gimg_grey, real_grey)
                loss_all += 10 * loss_rec_grey

            net_ig.zero_grad()
            loss_all.backward()
            opt_ig.step()

            for p, avg_p in zip(net_ig.parameters(), net_ig_ema):
                avg_p.mul_(0.999).add_(p.data, alpha=0.001)

            ### 3. logging
            losses_g_img.update(pred_g.mean().item(), BATCH_SIZE_GAN)
            losses_d_img.update(pred_r.mean().item(), BATCH_SIZE_GAN)

            if iteration % SAVE_IMAGE_INTERVAL == 0:  #show the current images
                with torch.no_grad():

                    backup_para_g = copy_G_params(net_ig)
                    load_params(net_ig, net_ig_ema)

                    gimg_ae, style_feats = net_ae(fixed_skt, fixed_rgb)
                    gmatch = net_ig(gimg_ae, style_feats)

                    gimg_ae_perm, style_feats = net_ae(fixed_skt,
                                                       fixed_rgb[fixed_perm])
                    gmismatch = net_ig(gimg_ae_perm, style_feats)

                    gimg = torch.cat([
                        F.interpolate(fixed_rgb, IM_SIZE_GAN),
                        F.interpolate(fixed_skt.repeat(1, 3, 1, 1),
                                      IM_SIZE_GAN), gmatch,
                        F.interpolate(gimg_ae, IM_SIZE_GAN), gmismatch,
                        F.interpolate(gimg_ae_perm, IM_SIZE_GAN)
                    ])

                    vutils.save_image(
                        gimg,
                        f'{saved_image_folder}/img_iter_{epoch}_{iteration}.jpg',
                        normalize=True,
                        range=(-1, 1))
                    del gimg

                    make_matrix(
                        dataset_rgb, dataset_skt, net_ae, net_ig, 5,
                        f'{saved_image_folder}/img_iter_{epoch}_{iteration}_matrix.jpg'
                    )

                    load_params(net_ig, backup_para_g)

            if iteration % LOG_INTERVAL == 0:
                log_msg = 'Iter: [{0}/{1}] G: {losses_g_img.avg:.4f}  D: {losses_d_img.avg:.4f}  MSE: {losses_mse.avg:.4f}  Rec: {losses_rec_s.avg:.5f}  FID: {fid:.4f}'.format(
                    epoch,
                    iteration,
                    losses_g_img=losses_g_img,
                    losses_d_img=losses_d_img,
                    losses_mse=losses_mse,
                    losses_rec_s=losses_rec_s,
                    fid=fid[-1][0])

                print(log_msg)
                print('%.5f' % (losses_rec_ae.avg))

                if log_file_path is not None:
                    log_file = open(log_file_path, 'a')
                    log_file.write(log_msg + '\n')
                    log_file.close()

                losses_g_img.reset()
                losses_d_img.reset()
                losses_mse.reset()
                losses_rec_s.reset()
                losses_rec_ae.reset()

            if iteration % SAVE_MODEL_INTERVAL == 0 or iteration + 1 == 10000:
                print('Saving history model')
                torch.save(
                    {
                        'ig': net_ig.state_dict(),
                        'id': net_id.state_dict(),
                        'ae': net_ae.state_dict(),
                        'ig_ema': net_ig_ema,
                        'opt_ig': opt_ig.state_dict(),
                        'opt_id': opt_id.state_dict(),
                    }, '%s/%d.pth' % (saved_model_folder, epoch))

            if iteration % FID_INTERVAL == 0 and iteration > 1:
                print("calculating FID ...")
                fid_batch_images = FID_BATCH_NBR
                if real_features is None:
                    if os.path.exists('%s_fid_feats.npy' % (DATA_NAME)):
                        real_features = pickle.load(
                            open('%s_fid_feats.npy' % (DATA_NAME), 'rb'))
                    else:
                        real_features = extract_feature_from_generator_fn(
                            real_image_loader(dataloader,
                                              n_batches=fid_batch_images),
                            inception)
                        real_mean = np.mean(real_features, 0)
                        real_cov = np.cov(real_features, rowvar=False)
                        pickle.dump(
                            {
                                'feats': real_features,
                                'mean': real_mean,
                                'cov': real_cov
                            }, open('%s_fid_feats.npy' % (DATA_NAME), 'wb'))
                        real_features = pickle.load(
                            open('%s_fid_feats.npy' % (DATA_NAME), 'rb'))

                sample_features = extract_feature_from_generator_fn(
                    image_generator(dataset,
                                    net_ae,
                                    net_ig,
                                    n_batches=fid_batch_images),
                    inception,
                    total=fid_batch_images)
                cur_fid = calc_fid(sample_features,
                                   real_mean=real_features['mean'],
                                   real_cov=real_features['cov'])
                sample_features_perm = extract_feature_from_generator_fn(
                    image_generator_perm(dataset,
                                         net_ae,
                                         net_ig,
                                         n_batches=fid_batch_images),
                    inception,
                    total=fid_batch_images)
                cur_fid_perm = calc_fid(sample_features_perm,
                                        real_mean=real_features['mean'],
                                        real_cov=real_features['cov'])

                fid.append([cur_fid, cur_fid_perm])
                print('fid:', fid)
                if log_file_path is not None:
                    log_file = open(log_file_path, 'a')
                    log_msg = 'fid: %.5f, %.5f' % (fid[-1][0], fid[-1][1])
                    log_file.write(log_msg + '\n')
                    log_file.close()
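
The generator averaging in this example (copy_G_params before the loop, load_params around evaluation, and the mul_/add_ update after each step) is a weight-space exponential moving average. A minimal self-contained sketch of the same pattern, with a plain nn.Module standing in for the refinement generator:

import torch
import torch.nn as nn

def ema_update(model, shadow, decay=0.999):
    # shadow <- decay * shadow + (1 - decay) * current, as in the loop above
    for p, avg_p in zip(model.parameters(), shadow):
        avg_p.mul_(decay).add_(p.data, alpha=1 - decay)

net = nn.Linear(8, 8)
shadow = [p.data.clone() for p in net.parameters()]  # analogous to copy_G_params
opt = torch.optim.SGD(net.parameters(), lr=0.1)

loss = net(torch.randn(4, 8)).pow(2).mean()
opt.zero_grad()
loss.backward()
opt.step()
ema_update(net, shadow)  # once per optimizer step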
Example #9
def lerp(a, b, t):
    return a + (b - a) * t


if __name__ == '__main__':
    # `device` and `args` are set by the argument parsing earlier in the full
    # script (Example #17 shows that part of the same file)
    latent_dim = 512

    ckpt = torch.load(args.ckpt)

    g = Generator(args.size, latent_dim, 8).to(device)
    g.load_state_dict(ckpt['g_ema'])
    g.eval()

    percept = lpips.PerceptualLoss(model='net-lin',
                                   net='vgg',
                                   use_gpu=device.startswith('cuda'))

    distances = []

    n_batch = args.n_sample // args.batch
    resid = args.n_sample - (n_batch * args.batch)
    batch_sizes = [args.batch] * n_batch + [resid]

    with torch.no_grad():
        for batch in tqdm(batch_sizes):
            noise = g.make_noise()

            inputs = torch.randn([batch * 2, latent_dim], device=device)
            lerp_t = torch.rand(batch, device=device)

            if args.space == 'w':
                latent = g.get_latent(inputs)
                latent_t0, latent_t1 = latent[::2], latent[1::2]
                latent_e0 = lerp(latent_t0, latent_t1, lerp_t[:, None])
                latent_e1 = lerp(latent_t0, latent_t1,
                                 lerp_t[:, None] + args.eps)
                latent_e = torch.stack([latent_e0, latent_e1],
                                       1).view(*latent.shape)

            image, _ = g([latent_e], input_is_latent=True, noise=noise)

            if args.crop:
                c = image.shape[2] // 8
                image = image[:, :, c * 3:c * 7, c * 2:c * 6]

            factor = image.shape[2] // 256

            if factor > 1:
                image = F.interpolate(image,
                                      size=(256, 256),
                                      mode='bilinear',
                                      align_corners=False)

            dist = percept(image[::2], image[1::2]).view(
                image.shape[0] // 2) / (args.eps**2)
            distances.append(dist.to('cpu').numpy())

    distances = np.concatenate(distances, 0)

    lo = np.percentile(distances, 1, interpolation='lower')
    hi = np.percentile(distances, 99, interpolation='higher')
    filtered_dist = np.extract(
        np.logical_and(lo <= distances, distances <= hi), distances)

    print('ppl:', filtered_dist.mean())
Example #10
elif args.gan == "wgan-gp":

    discriminator = ResidualDiscriminator(1, dim=128).to(device)
    discriminator.eval()
    discriminator.load_state_dict(torch.load(args.ckpt)['d'])

    g_ema = ResidualGenerator(512, dim=128).to(device)
    g_ema.load_state_dict(torch.load(args.ckpt)['g_ema'])
    g_ema.eval()

    # encoder = Encoder(1, dim=128).to(device)

    # encoder.load_state_dict(torch.load(args.ckptE))
    # encoder.eval()

percept = lpips.PerceptualLoss(model='net-lin', net='vgg', use_gpu=True)


def get_npy_idx(patient_n):
    # print(test_metadata[test_metadata["patient_n"].isin(patient_n)])
    return np.arange(95)[test_metadata["patient_n"].isin(patient_n)]


def get_lr_features(
    losses,
    query_images,
    anomaly_score=False,
):

    latents = losses["latents"]
Example #11
import random
import argparse
from tqdm import tqdm
import os

os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
from models import weights_init, Discriminator, Generator
from operation import copy_G_params, load_params, get_dir
from operation import ImageFolder, InfiniteSamplerWrapper
from diffaug import DiffAugment

policy = 'color,translation'
import lpips

percept = lpips.PerceptualLoss(
    model='net-lin',
    net='vgg',
    use_gpu=True if torch.cuda.is_available() else False)
#torch.backends.cudnn.benchmark = True


def crop_image_by_part(image, part):
    hw = image.shape[2] // 2
    if part == 0:
        return image[:, :, :hw, :hw]
    if part == 1:
        return image[:, :, :hw, hw:]
    if part == 2:
        return image[:, :, hw:, :hw]
    if part == 3:
        return image[:, :, hw:, hw:]
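
# Quick illustration of crop_image_by_part's quadrant indexing on an NCHW
# batch -- 0: top-left, 1: top-right, 2: bottom-left, 3: bottom-right:
#
#   x = torch.arange(16.).reshape(1, 1, 4, 4)
#   crop_image_by_part(x, 0)   # == x[:, :, :2, :2]  (top-left)
#   crop_image_by_part(x, 3)   # == x[:, :, 2:, 2:]  (bottom-right)
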
def train():
    from config import IM_SIZE_AE, BATCH_SIZE_AE, NFC, NBR_CLS, DATALOADER_WORKERS, ITERATION_AE
    from config import SAVE_IMAGE_INTERVAL, SAVE_MODEL_INTERVAL, SAVE_FOLDER, TRIAL_NAME, LOG_INTERVAL
    from config import DATA_NAME
    from config import data_root_colorful, data_root_sketch_1, data_root_sketch_2, data_root_sketch_3

    dataset = PairedMultiDataset(data_root_colorful,
                                 data_root_sketch_1,
                                 data_root_sketch_2,
                                 data_root_sketch_3,
                                 im_size=IM_SIZE_AE,
                                 rand_crop=True)
    print(len(dataset))
    dataloader = iter(
        DataLoader(dataset,
                   BATCH_SIZE_AE,
                   sampler=InfiniteSamplerWrapper(dataset),
                   num_workers=DATALOADER_WORKERS,
                   pin_memory=True))

    dataset_ss = SelfSupervisedDataset(data_root_colorful,
                                       data_root_sketch_3,
                                       im_size=IM_SIZE_AE,
                                       nbr_cls=NBR_CLS,
                                       rand_crop=True)
    print(len(dataset_ss), len(dataset_ss.frame))
    dataloader_ss = iter(
        DataLoader(dataset_ss,
                   BATCH_SIZE_AE,
                   sampler=InfiniteSamplerWrapper(dataset_ss),
                   num_workers=DATALOADER_WORKERS,
                   pin_memory=True))

    style_encoder = StyleEncoder(nfc=NFC, nbr_cls=NBR_CLS).cuda()
    content_encoder = ContentEncoder(nfc=NFC).cuda()
    decoder = Decoder(nfc=NFC).cuda()

    opt_c = optim.Adam(content_encoder.parameters(),
                       lr=2e-4,
                       betas=(0.5, 0.999))
    opt_s = optim.Adam(style_encoder.parameters(), lr=2e-4, betas=(0.5, 0.999))
    opt_d = optim.Adam(decoder.parameters(), lr=2e-4, betas=(0.5, 0.999))

    style_encoder.reset_cls()
    style_encoder.final_cls.cuda()

    from config import PRETRAINED_AE_PATH, PRETRAINED_AE_ITER
    if PRETRAINED_AE_PATH is not None:
        PRETRAINED_AE_PATH = PRETRAINED_AE_PATH + '/models/%d.pth' % PRETRAINED_AE_ITER
        ckpt = torch.load(PRETRAINED_AE_PATH)

        print(PRETRAINED_AE_PATH)

        style_encoder.load_state_dict(ckpt['s'])
        content_encoder.load_state_dict(ckpt['c'])
        decoder.load_state_dict(ckpt['d'])

        opt_c.load_state_dict(ckpt['opt_c'])
        opt_s.load_state_dict(ckpt['opt_s'])
        opt_d.load_state_dict(ckpt['opt_d'])
        print('loaded pre-trained AE')

    style_encoder.reset_cls()
    style_encoder.final_cls.cuda()
    opt_s_cls = optim.Adam(style_encoder.final_cls.parameters(),
                           lr=2e-4,
                           betas=(0.5, 0.999))

    saved_image_folder, saved_model_folder = make_folders(
        SAVE_FOLDER, 'AE_' + TRIAL_NAME)
    log_file_path = saved_image_folder + '/../ae_log.txt'
    log_file = open(log_file_path, 'w')
    log_file.close()
    ## for logging
    losses_sf_consist = AverageMeter()
    losses_cf_consist = AverageMeter()
    losses_cls = AverageMeter()
    losses_rec_rd = AverageMeter()
    losses_rec_org = AverageMeter()
    losses_rec_grey = AverageMeter()

    import lpips
    percept = lpips.PerceptualLoss(model='net-lin', net='vgg', use_gpu=True)

    for iteration in tqdm(range(ITERATION_AE)):

        if iteration % ((NBR_CLS * 100) // BATCH_SIZE_AE) == 0 and iteration > 1:
            dataset_ss._next_set()
            dataloader_ss = iter(
                DataLoader(dataset_ss,
                           BATCH_SIZE_AE,
                           sampler=InfiniteSamplerWrapper(dataset_ss),
                           num_workers=DATALOADER_WORKERS,
                           pin_memory=True))
            style_encoder.reset_cls()
            opt_s_cls = optim.Adam(style_encoder.final_cls.parameters(),
                                   lr=2e-4,
                                   betas=(0.5, 0.999))

            opt_s.param_groups[0]['lr'] = 1e-4
            opt_d.param_groups[0]['lr'] = 1e-4

        ### 1. train the encoder with self-supervision methods
        rgb_img_rd, rgb_img_org, skt_org, skt_bold, skt_erased, skt_erased_bold, img_idx = next(
            dataloader_ss)
        rgb_img_rd = rgb_img_rd.cuda()
        rgb_img_org = rgb_img_org.cuda()
        img_idx = img_idx.cuda()

        skt_org = F.interpolate(skt_org, size=512).cuda()
        skt_bold = F.interpolate(skt_bold, size=512).cuda()
        skt_erased = F.interpolate(skt_erased, size=512).cuda()
        skt_erased_bold = F.interpolate(skt_erased_bold, size=512).cuda()

        style_encoder.zero_grad()
        decoder.zero_grad()
        content_encoder.zero_grad()

        style_vector_rd, pred_cls_rd = style_encoder(rgb_img_rd)
        style_vector_org, pred_cls_org = style_encoder(rgb_img_org)

        content_feats = content_encoder(skt_org)
        content_feats_bold = content_encoder(skt_bold)
        content_feats_erased = content_encoder(skt_erased)
        content_feats_eb = content_encoder(skt_erased_bold)

        rd = random.randint(0, 3)
        gimg_rd = None
        if rd == 0:
            gimg_rd = decoder(content_feats, style_vector_rd)
        elif rd == 1:
            gimg_rd = decoder(content_feats_bold, style_vector_rd)
        elif rd == 2:
            gimg_rd = decoder(content_feats_erased, style_vector_rd)
        elif rd == 3:
            gimg_rd = decoder(content_feats_eb, style_vector_rd)


        loss_cf_consist = (
            loss_for_list_perm(F.mse_loss, content_feats_bold, content_feats) +
            loss_for_list_perm(F.mse_loss, content_feats_erased, content_feats) +
            loss_for_list_perm(F.mse_loss, content_feats_eb, content_feats))

        loss_sf_consist = 0
        for loss_idx in range(3):
            loss_sf_consist += (
                -F.cosine_similarity(style_vector_rd[loss_idx],
                                     style_vector_org[loss_idx].detach()).mean() +
                F.cosine_similarity(
                    style_vector_rd[loss_idx],
                    style_vector_org[loss_idx][
                        torch.randperm(BATCH_SIZE_AE)].detach()).mean())

        loss_cls = F.cross_entropy(pred_cls_rd, img_idx) + F.cross_entropy(
            pred_cls_org, img_idx)
        loss_rec_rd = F.mse_loss(gimg_rd, rgb_img_org)
        if DATA_NAME != 'shoe':
            loss_rec_rd += percept(
                F.adaptive_avg_pool2d(gimg_rd, output_size=256),
                F.adaptive_avg_pool2d(rgb_img_org, output_size=256)).sum()
        else:
            loss_rec_rd += F.l1_loss(gimg_rd, rgb_img_org)

        loss_total = loss_cls + loss_sf_consist + loss_rec_rd + loss_cf_consist  #+ loss_kl_c + loss_kl_s
        loss_total.backward()

        opt_s.step()
        opt_s_cls.step()
        opt_c.step()
        opt_d.step()

        ### 2. train as AutoEncoder
        rgb_img, skt_img_1, skt_img_2, skt_img_3 = next(dataloader)

        rgb_img = rgb_img.cuda()

        rd = random.randint(0, 3)
        if rd == 0:
            skt_img = skt_img_1
        elif rd == 1:
            skt_img = skt_img_2
        else:
            skt_img = skt_img_3

        skt_img = F.interpolate(skt_img, size=512).cuda()

        style_encoder.zero_grad()
        decoder.zero_grad()
        content_encoder.zero_grad()

        style_vector, _ = style_encoder(rgb_img)
        content_feats = content_encoder(skt_img)
        gimg = decoder(content_feats, style_vector)

        loss_rec_org = F.mse_loss(gimg, rgb_img)
        if DATA_NAME != 'shoe':
            loss_rec_org += percept(
                F.adaptive_avg_pool2d(gimg, output_size=256),
                F.adaptive_avg_pool2d(rgb_img, output_size=256)).sum()
        #else:
        #    loss_rec_org += F.l1_loss(gimg, rgb_img)

        loss_rec = loss_rec_org
        if DATA_NAME == 'shoe':
            ### the grey image reconstruction
            perm = true_randperm(BATCH_SIZE_AE)
            gimg_perm = decoder(content_feats, [s[perm] for s in style_vector])
            gimg_grey = gimg_perm.mean(dim=1, keepdim=True)
            real_grey = rgb_img.mean(dim=1, keepdim=True)
            loss_rec_grey = F.mse_loss(gimg_grey, real_grey)
            loss_rec += loss_rec_grey
        loss_rec.backward()

        opt_s.step()
        opt_d.step()
        opt_c.step()

        ### Logging
        losses_cf_consist.update(loss_cf_consist.mean().item(), BATCH_SIZE_AE)
        losses_sf_consist.update(loss_sf_consist.mean().item(), BATCH_SIZE_AE)
        losses_cls.update(loss_cls.mean().item(), BATCH_SIZE_AE)
        losses_rec_rd.update(loss_rec_rd.item(), BATCH_SIZE_AE)
        losses_rec_org.update(loss_rec_org.item(), BATCH_SIZE_AE)
        if DATA_NAME == 'shoe':
            losses_rec_grey.update(loss_rec_grey.item(), BATCH_SIZE_AE)

        if iteration % LOG_INTERVAL == 0:
            log_msg = ('Train Stage 1: AE: \nrec_rd: %.4f  rec_org: %.4f  '
                       'cls: %.4f  style_consist: %.4f  content_consist: %.4f  '
                       'rec_grey: %.4f' %
                       (losses_rec_rd.avg, losses_rec_org.avg, losses_cls.avg,
                        losses_sf_consist.avg, losses_cf_consist.avg,
                        losses_rec_grey.avg))

            print(log_msg)

            if log_file_path is not None:
                log_file = open(log_file_path, 'a')
                log_file.write(log_msg + '\n')
                log_file.close()

            losses_sf_consist.reset()
            losses_cls.reset()
            losses_rec_rd.reset()
            losses_rec_org.reset()
            losses_cf_consist.reset()
            losses_rec_grey.reset()

        if iteration % SAVE_IMAGE_INTERVAL == 0:
            vutils.save_image(torch.cat([
                rgb_img_rd,
                F.interpolate(skt_org.repeat(1, 3, 1, 1), size=512), gimg_rd
            ]),
                              '%s/rd_%d.jpg' % (saved_image_folder, iteration),
                              normalize=True,
                              range=(-1, 1))
            if DATA_NAME != 'shoe':
                with torch.no_grad():
                    perm = true_randperm(BATCH_SIZE_AE)
                    gimg_perm = decoder([c for c in content_feats],
                                        [s[perm] for s in style_vector])
            vutils.save_image(torch.cat([
                rgb_img,
                F.interpolate(skt_img.repeat(1, 3, 1, 1), size=512), gimg,
                gimg_perm
            ]),
                              '%s/org_%d.jpg' %
                              (saved_image_folder, iteration),
                              normalize=True,
                              range=(-1, 1))

        if iteration % SAVE_MODEL_INTERVAL == 0:
            print('Saving history model')
            torch.save(
                {
                    's': style_encoder.state_dict(),
                    'd': decoder.state_dict(),
                    'c': content_encoder.state_dict(),
                    'opt_c': opt_c.state_dict(),
                    'opt_s_cls': opt_s_cls.state_dict(),
                    'opt_s': opt_s.state_dict(),
                    'opt_d': opt_d.state_dict(),
                }, '%s/%d.pth' % (saved_model_folder, iteration))

    torch.save(
        {
            's': style_encoder.state_dict(),
            'd': decoder.state_dict(),
            'c': content_encoder.state_dict(),
            'opt_c': opt_c.state_dict(),
            'opt_s_cls': opt_s_cls.state_dict(),
            'opt_s': opt_s.state_dict(),
            'opt_d': opt_d.state_dict(),
        }, '%s/%d.pth' % (saved_model_folder, ITERATION_AE))
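
The style-consistency term above (loss_sf_consist) is an attract/repel cosine objective: matched style vectors are pulled together while randomly permuted pairs are pushed apart. A self-contained sketch of the same term with dummy style vectors:

import torch
import torch.nn.functional as F

s_rd = torch.randn(8, 128)   # dummy style vectors from the augmented image
s_org = torch.randn(8, 128)  # dummy style vectors from the original image

# attract matched pairs, repel randomly permuted pairs
loss_sf = (-F.cosine_similarity(s_rd, s_org.detach()).mean()
           + F.cosine_similarity(s_rd, s_org[torch.randperm(8)].detach()).mean())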
Example #13
def projector_factor_fn(ckpt, fact, files):
    device = "cuda"

    parser = argparse.ArgumentParser()
    # parser.add_argument("--ckpt", type=str, required=True)
    parser.add_argument("--size", type=int, default=256)
    parser.add_argument("--lr_rampup", type=float, default=0.05)
    parser.add_argument("--lr_rampdown", type=float, default=0.25)
    parser.add_argument("--lr", type=float, default=0.1)
    parser.add_argument("--noise", type=float, default=0.05)
    parser.add_argument("--noise_ramp", type=float, default=0.75)
    parser.add_argument("--step", type=int, default=1000)
    parser.add_argument("--noise_regularize", type=float, default=1e5)
    parser.add_argument("--mse", type=float, default=0)
    parser.add_argument("--w_plus", action="store_true")
    parser.add_argument("--device", type=str, default="cuda")
    # parser.add_argument("--fact", type=str, required=True)
    # parser.add_argument("files", metavar="FILES", nargs="+")

    args = parser.parse_args()
    eigvec = torch.load(fact)["eigvec"].to(args.device)  # args.fact
    eigvec.requires_grad = False

    n_mean_degree = 10000
    n_mean_weight = 10000

    resize = min(args.size, 256)

    transform = transforms.Compose([
        transforms.Resize(resize),
        transforms.CenterCrop(resize),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    imgs = []

    for imgfile in files:  # args.files
        img = transform(Image.open(imgfile).convert("RGB"))
        imgs.append(img)

    imgs = torch.stack(imgs, 0).to(device)

    g_ema = Generator(args.size, 512, 8)
    g_ema.load_state_dict(torch.load(ckpt)["g_ema"], strict=False)  # args.ckpt
    g_ema.eval()
    g_ema = g_ema.to(device)
    trunc = g_ema.mean_latent(4096)

    with torch.no_grad():
        noise_sample = torch.randn(1, 512, device=device)
        latent = g_ema.get_latent(noise_sample)

        degree_sample = torch.randn(n_mean_degree, 1, device=device) * 1000
        weight_sample = torch.randn(
            n_mean_weight, 512, device=device).unsqueeze(1)  # row of factor

        degree_mean = degree_sample.mean(0)
        degree_std = ((degree_sample - degree_mean).pow(2).sum() /
                      n_mean_degree)**0.5

        weight_mean = weight_sample.mean(0)
        weight_std = ((weight_sample - weight_mean).pow(2).sum() /
                      n_mean_weight)**0.5

        direction = torch.mm(eigvec, weight_mean.T)

    percept = lpips.PerceptualLoss(model="net-lin",
                                   net="vgg",
                                   use_gpu=device.startswith("cuda"))

    noises_single = g_ema.make_noise()
    noises = []
    for noise in noises_single:
        noises.append(noise.repeat(imgs.shape[0], 1, 1, 1).normal_())

    direction_in = direction.T
    weight_in = weight_mean

    weight_mean.requires_grad = True

    for noise in noises:
        noise.requires_grad = True

    optimizer = optim.Adam([weight_mean] + noises, lr=args.lr)

    pbar = tqdm(range(args.step))
    latent_path = []
    weight_path = []

    imgs.detach()

    for i in pbar:
        t = i / args.step
        lr = get_lr(t, args.lr)
        optimizer.param_groups[0]["lr"] = lr
        noise_strength = weight_std * args.noise * max(
            0, 1 - t / args.noise_ramp)**2
        weight_n = latent_noise(weight_in, noise_strength.item())
        # project the noise-perturbed weights back through the eigenbasis
        direction_n = torch.mm(weight_n, eigvec)

        img_gen, _ = g_ema([direction_n], input_is_latent=True, noise=noises)

        batch, channel, height, width = img_gen.shape

        if height > 256:
            factor = height // 256

            img_gen = img_gen.reshape(batch, channel, height // factor, factor,
                                      width // factor, factor)
            img_gen = img_gen.mean([3, 5])

        p_loss = percept(img_gen, imgs).sum()
        n_loss = noise_regularize(noises)
        mse_loss = F.mse_loss(img_gen, imgs)

        loss = p_loss + args.noise_regularize * n_loss + args.mse * mse_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        noise_normalize_(noises)

        if (i + 1) % 100 == 0:
            direction_in = torch.mm(weight_in, eigvec)
            latent_path.append(direction_in.detach().clone())
            weight_path.append(weight_in.detach().clone())

        pbar.set_description((
            f"perceptual: {p_loss.item():.4f}; noise regularize: {n_loss.item():.4f};"
            f" mse: {mse_loss.item():.4f}; lr: {lr:.4f}"))

    img_gen, _ = g_ema([latent_path[-1]], input_is_latent=True, noise=noises)

    filename = os.path.splitext(os.path.basename(
        files[0]))[0] + ".pt"  # args.files[0]

    img_ar = make_image(img_gen)

    result_file = {}

    # factor_base_path_1 = './models/'

    for i, input_name in enumerate(files):  # args.files
        noise_single = []
        for noise in noises:
            noise_single.append(noise[i:i + 1])

        result_file[input_name] = {
            "img": img_gen[i],
            "latent": latent_path[-1][i],  # last checkpoint, i-th image
            "weight": weight_path[-1][i],
            "noise": noise_single,
        }

        img_name = os.path.splitext(
            os.path.basename(input_name))[0] + "-project.png"
        pil_img = Image.fromarray(img_ar[i])

        # factor_base_path (output dir) is expected at module level; not shown here
        save_latent_image_path = factor_base_path + img_name
        pil_img.save(save_latent_image_path)  # add factor_base_path

    save_latent_code_path = factor_base_path + filename
    torch.save(result_file, save_latent_code_path)  # factor_base_path
    return save_latent_image_path, save_latent_code_path
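
The get_lr schedule called at the top of the optimization loop is not included in these excerpts; in rosinality's stylegan2-pytorch projector it is, approximately, a cosine ramp:

import math

def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):
    # linear ramp-up over the first `rampup` fraction of steps, cosine
    # ramp-down over the last `rampdown` fraction (t runs from 0 to 1)
    lr_ramp = min(1, (1 - t) / rampdown)
    lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi)
    lr_ramp = lr_ramp * min(1, t / rampup)
    return initial_lr * lr_ramp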
Example #14
    #================================
    loss_l1_fn = nn.L1Loss()
    loss_vgg_fn = VGGLoss(device, n_channels=3)
    if args.gan_loss_type == "lsgan":
        loss_adv_fn = LSGANLoss(device)
    elif args.gan_loss_type == "hinge":
        loss_adv_fn = HingeGANLoss()
    else:
        raise NotImplementedError()

    if args.rec_loss_type == "vgg":
        loss_rec_fn = VGGLoss(device, n_channels=3)
    elif args.rec_loss_type == "lpips":
        loss_rec_fn = lpips.PerceptualLoss(model='net-lin',
                                           net='vgg',
                                           use_gpu=(args.device == "gpu"))
    else:
        raise NotImplementedError()

    #================================
    # Train the model
    #================================
    print("Starting Training Loop...")
    n_print = 1
    step = 0
    for epoch in tqdm(range(args.n_epoches), desc="epochs"):
Example #15
def project(path_ckpt, path_files, step=1000):
    device = "cuda"

    parser = argparse.ArgumentParser()
    parser.add_argument('-f', type=str, help='jup kernel')
    # parser.add_argument("--ckpt", type=str, required=True)
    parser.add_argument("--size", type=int, default=512)
    parser.add_argument("--lr_rampup", type=float, default=0.05)
    parser.add_argument("--lr_rampdown", type=float, default=0.25)
    parser.add_argument("--lr", type=float, default=0.1)
    parser.add_argument("--noise", type=float, default=0.05)
    parser.add_argument("--noise_ramp", type=float, default=0.75)
    parser.add_argument("--step", type=int, default=1000)
    parser.add_argument("--noise_regularize", type=float, default=1e5)
    parser.add_argument("--mse", type=float, default=0)
    # parser.add_argument("--w_plus", action="store_true")
    # parser.add_argument("files", metavar="FILES", nargs="+")

    args = parser.parse_args()
    args.ckpt = path_ckpt
    args.files = path_files
    args.w_plus = False
    args.step = step

    n_mean_latent = 10000
    resize = min(args.size, 256)
    transform = transforms.Compose([
        transforms.Resize(resize),
        transforms.CenterCrop(resize),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    imgs = []
    for imgfile in args.files:
        img = transform(Image.open(imgfile).convert("RGB"))
        imgs.append(img)
    imgs = torch.stack(imgs, 0).to(device)

    g_ema = Generator(args.size, 512, 8)
    g_ema.load_state_dict(torch.load(args.ckpt)["g_ema"], strict=False)
    g_ema.eval()
    g_ema = g_ema.to(device)

    with torch.no_grad():
        noise_sample = torch.randn(n_mean_latent, 512, device=device)
        latent_out = g_ema.style(noise_sample)

        latent_mean = latent_out.mean(0)
        latent_std = ((latent_out - latent_mean).pow(2).sum() / n_mean_latent) ** 0.5

    percept = lpips.PerceptualLoss(
        model="net-lin", net="vgg", use_gpu=device.startswith("cuda")
    )

    noises_single = g_ema.make_noise()
    noises = []
    for noise in noises_single:
        noises.append(noise.repeat(imgs.shape[0], 1, 1, 1).normal_())

    latent_in = latent_mean.detach().clone().unsqueeze(0).repeat(imgs.shape[0], 1)

    if args.w_plus:
        latent_in = latent_in.unsqueeze(1).repeat(1, g_ema.n_latent, 1)

    latent_in.requires_grad = True

    for noise in noises:
        noise.requires_grad = True

    optimizer = optim.Adam([latent_in] + noises, lr=args.lr)

    pbar = tqdm(range(args.step))
    latent_path = []

    for i in pbar:
        t = i / args.step
        lr = get_lr(t, args.lr)
        optimizer.param_groups[0]["lr"] = lr
        noise_strength = latent_std * args.noise * max(0, 1 - t / args.noise_ramp) ** 2
        latent_n = latent_noise(latent_in, noise_strength.item())

        img_gen, _ = g_ema([latent_n], input_is_latent=True, noise=noises)

        batch, channel, height, width = img_gen.shape

        if height > resize:
            factor = height // resize

            img_gen = img_gen.reshape(
                batch, channel, height // factor, factor, width // factor, factor
            )
            img_gen = img_gen.mean([3, 5])

        p_loss = percept(img_gen, imgs).sum()
        n_loss = noise_regularize(noises)
        mse_loss = F.mse_loss(img_gen, imgs)

        loss = p_loss + args.noise_regularize * n_loss + args.mse * mse_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        noise_normalize_(noises)

        if (i + 1) % 100 == 0:
            latent_path.append(latent_in.detach().clone())

        pbar.set_description((
            f"perceptual: {p_loss.item():.8f}; noise regularize: {n_loss.item():.8f}; mse: {mse_loss.item():.8f}; lr: {lr:.4f}"
        ))

    img_gen, _ = g_ema([latent_path[-1]], input_is_latent=True, noise=noises)

    filename = os.path.splitext(os.path.basename(args.files[0]))[0] + ".pt"

    img_ar = make_image(img_gen)

    result_file = {}
    for i, input_name in enumerate(args.files):
        noise_single = []
        for noise in noises:
            noise_single.append(noise[i: i + 1])

        result_file[input_name] = {
            "img": img_gen[i],
            "latent": latent_in[i],
            "noise": noise_single,
        }

        img_name = os.path.splitext(os.path.basename(input_name))[0] + "-project.png"
        pil_img = Image.fromarray(img_ar[i])
        pil_img.save(img_name)

    torch.save(result_file, filename)
    print(filename)

    return img_gen, latent_path, latent_in
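
A hypothetical invocation of project (the checkpoint and image paths below are placeholders, not taken from the excerpt):

if __name__ == "__main__":
    img_gen, latent_path, latent_in = project(
        path_ckpt="stylegan2-ffhq.pt",   # placeholder checkpoint file
        path_files=["input_face.png"],   # placeholder input image
        step=500,
    )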
Example #16
    g_ema.eval()
    g_ema = g_ema.to(device)
    g_ema = MyDataParallel(g_ema, device_ids=range(args.n_gpu))

    if args.feature_extractor == "d":
        discriminator = Discriminator(args.size, channel_multiplier=2)
        discriminator.load_state_dict(torch.load(args.ckpt)['d'])
        discriminator.eval()
        discriminator = discriminator.to(device)
        discriminator = MyDataParallel(discriminator,
                                       device_ids=range(args.n_gpu))
        percept = d_feat_loss

    elif args.feature_extractor == "vgg":
        percept = lpips.PerceptualLoss(model='net-lin',
                                       net='vgg',
                                       use_gpu=device.startswith('cuda'),
                                       gpu_ids=range(args.n_gpu))

    # checkpoint = torch.load("/scratch/gobi2/lechang/trained_vae.pth")
    # vae = VariationalAutoEncoderLite(5)
    # vae.load_state_dict(checkpoint['ae'])
    # vae = vae.to(device)
    # vae.eval()

    print("loaded models")
    if args.latent_space == "w":
        with torch.no_grad():
            noise_sample = torch.randn(n_mean_latent, 512, device=device)
            latent_out = g_ema.style(noise_sample)

            latent_mean = latent_out.mean(0)
Example #17
    parser.add_argument('--eps', type=float, default=1e-4)
    parser.add_argument('--crop', action='store_true')
    parser.add_argument('ckpt', metavar='CHECKPOINT')

    args = parser.parse_args()

    latent_dim = 512

    ckpt = torch.load(args.ckpt)

    g = Generator(args.size, latent_dim, 8).to(device)
    g.load_state_dict(ckpt['g_ema'])
    g.eval()

    percept = lpips.PerceptualLoss(model='net-lin',
                                   net='vgg',
                                   use_gpu=device.startswith('cuda'))

    distances = []

    n_batch = args.n_sample // args.batch
    resid = args.n_sample - (n_batch * args.batch)
    batch_sizes = [args.batch] * n_batch + [resid]

    with torch.no_grad():
        for batch in tqdm(batch_sizes):
            noise = g.make_noise()

            inputs = torch.randn([batch * 2, latent_dim], device=device)
            lerp_t = torch.rand(batch, device=device)
Example #18
    def __init__(self,
                 x,
                 gma,
                 device,
                 lr=0.1,
                 steps='1000',
                 task='invert',
                 search_space='W+',
                 search_noise=True,
                 project=True,
                 start_layer=0,
                 end_layer=5,
                 discriminator=None,
                 cls_alpha=0,
                 mask=1,
                 mse_weight=1,
                 lpips_alpha=0,
                 r_alpha=0.1):
        """

        :param x:
        :param gma:
        :param device:
        :param lr:
        :param steps:
        :param task:
        :param search_space: W, W+, Z, Z+
        :param search_noise:
        :param project:
        :param start_layer:
        :param end_layer:
        :param discriminator:
        :param cls_alpha:
        :param mask:
        :param mse_weight:
        :param lpips_alpha:
        :param r_alpha:
        """
        super().__init__()
        # self.config = config
        # if config['image_size'][0] != config['image_size'][1]:
        #     raise Exception('Non-square images are not supported yet.')

        self.task = task
        self.search_space = search_space
        self.search_noise = search_noise
        self.project = project
        self.steps = steps
        self.start_layer = start_layer
        self.end_layer = end_layer
        self.dead_zone_linear = 0
        self.dead_zone_linear_alpha = 0.1
        self.device = device
        self.geocross_alpha = 0.1
        self.cls_alpha = cls_alpha
        self.lpips_alpha = lpips_alpha
        self.r_alpha = r_alpha
        self.mask = mask
        self.mse_weight = mse_weight

        self.layer_in = None
        self.best = None
        self.skip = None
        self.lr = lr
        self.lr_record = []
        self.current_step = 0

        self.original_imgs = x.to(device)

        self.discriminator = discriminator
        if self.discriminator is not None:
            self.discriminator = self.discriminator.to(device)

        if self.task == 'separate':
            bs = self.original_imgs.shape[0] * 2
        else:
            bs = self.original_imgs.shape[0]

        # self.downsampler_1024_256 = BicubicDownSample(4)
        # self.downsampler_1024_image = BicubicDownSample(1024 // config['image_size'][0])
        # self.downsampler_image_256 = BicubicDownSample(config['image_size'][0] // 256)
        # Load models and pre-trained weights
        self.gen = gma.to(device)
        self.gen.start_layer = start_layer
        self.gen.end_layer = end_layer

        for p in self.gen.parameters():
            p.requires_grad = False

        self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2)
        self.plrelu = torch.nn.LeakyReLU(negative_slope=5)

        # if self.verbose: print("\tRunning Mapping Network")
        with torch.no_grad():
            # torch.manual_seed(0)
            # latent = torch.randn((1000000, 512), dtype=torch.float32, device="cuda")
            # latent_out = torch.nn.LeakyReLU(5)(self.gen.style(latent))

            latent_p = self.plrelu(
                self.gen.style(
                    torch.randn((500000, 512),
                                dtype=torch.float32,
                                device="cuda"))).double()
            self.mu = latent_p.mean(dim=0, keepdim=True)
            self.Sigma = (latent_p - self.mu).T @ (
                latent_p - self.mu) / latent_p.shape[0]
            # NOTE: torch.symeig is deprecated (removed in recent PyTorch);
            # torch.linalg.eigh(self.Sigma) is the drop-in replacement.
            d, V = torch.symeig(self.Sigma, eigenvectors=True)
            # so that small eigenvalues do not get overamplified
            D = torch.diag(1. / torch.sqrt(d + 1e-18))
            # whitening matrix
            # W = np.dot(np.dot(V, D), V.T) # ZCA whitening
            self.W = (V @ D).float()  # PCA whitening
            self.W_inv = torch.inverse(self.W)

            latent_p = latent_p.float()
            self.mu = self.mu.float().unsqueeze(0).to(device)
            self.Sigma = self.Sigma.float().to(device)
            self.gaussian_fit = {
                "mean": latent_p.mean(0).to(device),
                "std": latent_p.std(0).to(device)
            }

            del latent_p
            torch.cuda.empty_cache()

        # self.mpl = MappingProxy(torch.load('gaussian_fit.pt'))

        self.percept = lpips.PerceptualLoss(model="net-lin",
                                            net="vgg",
                                            use_gpu=device.startswith("cuda"))

        # # load a classifier
        # self.cls = imagenet_models.resnet50()
        # state_dict = torch.load('imagenet_l2_3_0.pt')['model']
        # new_dict = OrderedDict()
        #
        # for key in state_dict.keys():
        #     if 'module.model' in key:
        #         new_dict[key[13:]] = state_dict[key]
        #
        # self.cls.load_state_dict(new_dict)
        # self.cls.to(config['device'])

        # initialization
        # self.scalar = torch.ones(bs, requires_grad=True).to(device)
        self.scalar = torch.ones((bs, 1, 1, 1),
                                 dtype=torch.float,
                                 requires_grad=True,
                                 device='cuda')
        if start_layer == 0:
            noises_single = self.gen.make_noise(bs)
            self.noises = []
            for noise in noises_single:
                self.noises.append(noise.normal_())
            if self.search_space == 'W':
                # # self.latent_z = torch.randn(
                # #             (bs, self.gen.n_latent, self.gen.style_dim),
                # #             dtype=torch.float,
                # #             requires_grad=True, device='cuda')
                # with torch.no_grad():
                #     self.latent_z = self.gen.style(F.normalize(torch.randn(bs, self.gen.n_latent, self.gen.style_dim), p=2, dim=2).to(device)) # random w
                #     # self.latent_z = self.gen.style(F.normalize(torch.randn(bs, 1, self.gen.style_dim), p=2, dim=2).repeat(1, self.gen.n_latent, 1).to(device))  # random w
                #     # self.latent_z = self.gen.mean_latent(16384).unsqueeze(1).repeat(bs, self.gen.n_latent, 1).to(device) # mean w
                # self.latent_z.requires_grad = True
                # Generate latent tensor
                self.latent = torch.randn((bs, 1, self.gen.style_dim),
                                          dtype=torch.float,
                                          requires_grad=True,
                                          device='cuda')
            elif self.search_space == 'W+':
                self.latent = torch.randn(
                    (bs, self.gen.n_latent, self.gen.style_dim),
                    dtype=torch.float,
                    requires_grad=True,
                    device='cuda')
                # with torch.no_grad():
                #     self.latent = self.gen.style(torch.randn(bs, self.gen.n_latent, self.gen.style_dim).to(device)) # random w
                # self.latent.requires_grad = True
            elif self.search_space == 'Z':
                self.latent = torch.randn((bs, 1, self.gen.style_dim),
                                          dtype=torch.float,
                                          requires_grad=True,
                                          device='cuda')
            elif self.search_space == 'Z+':
                # self.latent_z = torch.randn(
                #             (bs, self.gen.style_dim),
                #             dtype=torch.float,
                #             requires_grad=True, device='cuda')
                self.latent_z = torch.randn(
                    (bs, self.gen.n_latent, self.gen.style_dim),
                    dtype=torch.float,
                    requires_grad=True,
                    device='cuda')
                self.latent_w = self.gen.style(self.latent_z)
            else:
                raise ValueError("searching_space incorrect")

            self.gen_outs = [None]
        else:
            # restore noises; `config` with the saved paths is supplied by the
            # caller in the full script (it is not defined in this excerpt)
            self.noises = torch.load(config['saved_noises'][0])
            self.latent_z = torch.load(config['saved_noises'][1]).to(
                config['device'])
            self.gen_outs = torch.load(config['saved_noises'][2])
            self.latent_z.requires_grad = True
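
The whitening computed in __init__ above is plain PCA whitening of the mapped latents. A standalone sketch with smaller dimensions (torch.linalg.eigh stands in for the deprecated torch.symeig used in the excerpt):

import torch

X = torch.randn(4096, 64).double()           # stand-in for the mapped latents
mu = X.mean(dim=0, keepdim=True)
Sigma = (X - mu).T @ (X - mu) / X.shape[0]
d, V = torch.linalg.eigh(Sigma)              # eigendecomposition of the covariance
D = torch.diag(1.0 / torch.sqrt(d + 1e-18))  # guard against tiny eigenvalues
W = V @ D                                    # whitening matrix: (X - mu) @ W ~ white
W_inv = torch.inverse(W)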
Example #19
    imgs = torch.stack(imgs, 0).to(device)

    g_ema = Generator(args.size, 512, 8)
    g_ema.load_state_dict(torch.load(args.ckpt)["g_ema"], strict=False)
    g_ema.eval()
    g_ema = g_ema.to(device)

    with torch.no_grad():
        noise_sample = torch.randn(n_mean_latent, 512, device=device)
        latent_out = g_ema.mapping_network(noise_sample)

        latent_mean = latent_out.mean(0)
        latent_std = ((latent_out - latent_mean).pow(2).sum() / n_mean_latent) ** 0.5

    percept = lpips.PerceptualLoss(
        model="net-lin", net="vgg", use_gpu=device.startswith("cuda")
    )

    noises_single = g_ema.make_noise()
    noises = []
    for noise in noises_single:
        noises.append(noise.repeat(imgs.shape[0], 1, 1, 1).normal_())

    latent_in = latent_mean.detach().clone().unsqueeze(0).repeat(imgs.shape[0], 1)

    if args.w_plus:
        latent_in = latent_in.unsqueeze(1).repeat(1, g_ema.n_latent, 1)

    latent_in.requires_grad = True

    for noise in noises:
Example #20
    def train(self, ref_images, epochs=100, start_lr=0.1):
        '''Train ...'''

        if ref_images.dim() < 4:
            ref_images = ref_images.unsqueeze(0)
        ref_images = ref_images.to(os.environ["DEVICE"])

        ref_batch, channel, ref_height, ref_width = ref_images.shape
        assert ref_height == ref_width

        noise_var_list = []
        for noise in self.generator.make_noise():
            normal_noise = noise.repeat(ref_batch, 1, 1, 1).normal_()
            normal_noise.requires_grad = True
            noise_var_list.append(normal_noise)

        n_mean_latent = 10000
        latent_out = torch.randn(n_mean_latent, 512, device=self.device)
        latent_mean = latent_out.mean(0)
        latent_std = ((latent_out - latent_mean).pow(2).sum() /
                      n_mean_latent)**0.5
        del latent_out

        latent_var = latent_mean.detach().clone().unsqueeze(0).repeat(
            ref_batch, 1)
        latent_var.requires_grad = True

        optimizer = optim.Adam([latent_var] + noise_var_list, lr=start_lr)

        percept = lpips.PerceptualLoss(
            model="net-lin",
            net="vgg",
            use_gpu=os.environ["DEVICE"].startswith("cuda"))

        progress_bar = tqdm(range(epochs))
        noise_level = 0.05
        noise_ramp = 0.07
        self.generator.train()

        for i in progress_bar:
            t = i / epochs

            lr = get_lr(t, start_lr)
            optimizer.param_groups[0]["lr"] = lr

            # Inject Noise to latent_n
            noise_strength = latent_std * noise_level * max(
                0, 1 - t / noise_ramp)**2
            latent_n = latent_var + torch.randn_like(
                latent_var) * noise_strength.item()
            gen_images = self.generator(latent_n, noise=noise_var_list)

            batch, channel, height, width = gen_images.shape
            if height > ref_height:
                factor = height // ref_height
                gen_images = gen_images.reshape(batch, channel,
                                                height // factor, factor,
                                                width // factor, factor)
                gen_images = gen_images.mean([3, 5])

            p_loss = percept(gen_images, ref_images).sum()
            n_loss = noise_loss(noise_var_list)
            mse_loss = F.mse_loss(gen_images, ref_images)

            loss = p_loss + 1e5 * n_loss + 0.2 * mse_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            noise_normalize_(noise_var_list)

            progress_bar.set_description((
                f"Loss = perceptual: {p_loss.item():.4f}; noise: {n_loss.item():.4f};"
                f" mse: {mse_loss.item():.4f}; lr: {lr:.4f}"))
        last_latent = latent_n.detach().clone()

        self.generator.eval()

        del noise_var_list
        torch.cuda.empty_cache()

        # maybe the best latent ?
        return last_latent