Exemplo n.º 1
0
def main(conf):
    """Build the dataset, UNet/EMA pair and diffusion process, then train.

    The beta-schedule hyper-parameters are fixed here; everything else is
    drawn from ``conf``.
    """
    if torch.cuda.is_available():
        device = "cuda:0"
    else:
        device = 'cpu'

    # Fixed linear diffusion-schedule settings.
    beta_schedule = "linear"
    beta_start = 1e-4
    beta_end = 2e-2
    n_timestep = 1000

    conf.distributed = dist.get_world_size() > 1

    # Flip augmentation, tensor conversion, then map pixels to [-1, 1].
    augmentations = [
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
    ]
    transform = transforms.Compose(augmentations)

    train_set = MultiResolutionDataset(
        conf.dataset.path, transform, conf.dataset.resolution
    )
    train_sampler = dist.data_sampler(
        train_set, shuffle=True, distributed=conf.distributed
    )
    train_loader = conf.training.dataloader.make(train_set, sampler=train_sampler)

    # The EMA network mirrors the trained model's architecture exactly, so
    # build both from one shared set of keyword arguments.
    unet_kwargs = dict(
        channel_multiplier=conf.model.channel_multiplier,
        n_res_blocks=conf.model.n_res_blocks,
        attn_strides=conf.model.attn_strides,
        dropout=conf.model.dropout,
        fold=conf.model.fold,
    )
    model = UNet(conf.model.in_channel, conf.model.channel, **unet_kwargs).to(device)
    ema = UNet(conf.model.in_channel, conf.model.channel, **unet_kwargs).to(device)

    if conf.distributed:
        local_rank = dist.get_local_rank()
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
        )

    optimizer = conf.training.optimizer.make(model.parameters())
    scheduler = conf.training.scheduler.make(optimizer)

    betas = make_beta_schedule(beta_schedule, beta_start, beta_end, n_timestep)
    diffusion = GaussianDiffusion(betas).to(device)

    train(conf, train_loader, model, ema, diffusion, optimizer, scheduler, device)
def main(conf):
    """Train a denoising-diffusion model described by ``conf``.

    Sets up optional wandb logging on the primary process, builds the
    dataset/loader, instantiates the model and its EMA copy, optionally
    resumes from ``conf.ckpt``, and hands everything to ``train``.
    """
    # Only the primary process logs to wandb, and only when enabled.
    wandb = None
    if dist.is_primary() and conf.evaluate.wandb:
        wandb = load_wandb()
        wandb.init(project="denoising diffusion")

    device = "cuda"
    # NOTE: a dead local `beta_schedule = "linear"` was removed — the betas
    # actually come from `conf.diffusion.beta_schedule.make()` below.

    conf.distributed = dist.get_world_size() > 1

    # Flip augmentation, then map pixels to [-1, 1].
    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
    ])

    train_set = MultiResolutionDataset(conf.dataset.path, transform,
                                       conf.dataset.resolution)
    train_sampler = dist.data_sampler(train_set,
                                      shuffle=True,
                                      distributed=conf.distributed)
    train_loader = conf.training.dataloader.make(train_set,
                                                 sampler=train_sampler)

    model = conf.model.make().to(device)
    # The EMA copy shares the architecture; its weights track the model
    # during training and are used for evaluation/sampling.
    ema = conf.model.make().to(device)

    if conf.distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
        )

    optimizer = conf.training.optimizer.make(model.parameters())
    scheduler = conf.training.scheduler.make(optimizer)

    if conf.ckpt is not None:
        # Load to CPU first to avoid GPU-ordinal mismatches across machines.
        ckpt = torch.load(conf.ckpt, map_location=lambda storage, loc: storage)

        # DDP wraps the model, so the raw module holds the parameters.
        if conf.distributed:
            model.module.load_state_dict(ckpt["model"])

        else:
            model.load_state_dict(ckpt["model"])

        ema.load_state_dict(ckpt["ema"])

    betas = conf.diffusion.beta_schedule.make()
    diffusion = GaussianDiffusion(betas).to(device)

    train(conf, train_loader, model, ema, diffusion, optimizer, scheduler,
          device, wandb)
Exemplo n.º 3
0
 def train_dataloader(self):
     """Return a shuffled DataLoader over the training images.

     Vertical/horizontal flips are enabled via ``self.vflip``/``self.hflip``;
     pixels are normalized to [-1, 1].
     """
     flips = [
         transforms.RandomVerticalFlip(p=0.5 if self.vflip else 0),
         transforms.RandomHorizontalFlip(p=0.5 if self.hflip else 0),
     ]
     transform = transforms.Compose(flips + [
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
     ])
     dataset = MultiResolutionDataset(self.path, transform, self.size)
     return data.DataLoader(
         dataset,
         batch_size=self.batch_size,
         shuffle=True,
         num_workers=0,
     )
Exemplo n.º 4
0
    def set_dataset(self):
        """Populate ``self.dataset`` and ``self.loader`` from ``self.args``."""
        args = self.args

        # Flip augmentation, then map pixels to [-1, 1].
        transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5),
                                 inplace=True),
        ])

        dataset = MultiResolutionDataset(args.path, transform, args.size)
        sampler = data_sampler(dataset,
                               shuffle=True,
                               distributed=args.distributed)
        self.dataset = dataset
        self.loader = data.DataLoader(
            dataset,
            batch_size=args.batch,
            sampler=sampler,
            drop_last=True,
        )
Exemplo n.º 5
0
        discriminator = nn.parallel.DistributedDataParallel(
            discriminator,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            broadcast_buffers=False,
        )

    transform = transforms.Compose([
        transforms.Resize(args.size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
    ])

    # dataset_src = MultiResolutionDataset(args.path_src, transform, args.size)
    dataset_src = MultiResolutionDataset(args.path_src, transform, 256)
    loader_src = data.DataLoader(
        dataset_src,
        batch_size=args.batch,
        sampler=data_sampler(dataset_src,
                             shuffle=True,
                             distributed=args.distributed),
        drop_last=True,
    )
    # dataset_norm = MultiResolutionDataset(args.path_norm, transform, args.size)
    dataset_norm = MultiResolutionDataset(args.path_norm, transform, 256)
    loader_norm = data.DataLoader(
        dataset_norm,
        batch_size=args.batch,
        sampler=data_sampler(dataset_norm,
                             shuffle=True,
Exemplo n.º 6
0
def setup_and_run(device, args):
    """Set up a StyleGAN2 training run and launch it.

    Creates the output directories, initialises distributed training when
    launched with WORLD_SIZE > 1, builds generator/discriminator/g_ema and
    their optimizers, optionally resumes from a checkpoint, builds the
    dataloader, and finally calls ``train``.

    NOTE: ``args`` is mutated in place (distributed, latent, n_mlp,
    start_iter).
    """

    os.makedirs(f'sample_{args.name}', exist_ok=True)
    os.makedirs(f'checkpoint_{args.name}', exist_ok=True)

    # Distributed launchers export WORLD_SIZE; without it, run single-GPU.
    n_gpu = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
    args.distributed = n_gpu > 1

    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        synchronize()

    # Fixed StyleGAN2 architecture hyper-parameters.
    args.latent = 512
    args.n_mlp = 8

    args.start_iter = 0

    generator = Generator(
        args.size,
        args.latent,
        args.n_mlp,
        channel_multiplier=args.channel_multiplier).to(device)
    discriminator = Discriminator(
        args.size, channel_multiplier=args.channel_multiplier).to(device)
    # g_ema tracks a moving average of the generator weights and is kept in
    # eval mode (never trained directly).
    g_ema = Generator(args.size,
                      args.latent,
                      args.n_mlp,
                      channel_multiplier=args.channel_multiplier).to(device)
    g_ema.eval()
    # decay=0: presumably initialises g_ema as an exact copy of generator
    # (assuming ``accumulate`` lerps parameters) — confirm against its impl.
    accumulate(g_ema, generator, 0)

    # Lazy-regularization correction: the R1/path-length regularisers only
    # run every *_reg_every steps, so lr and Adam betas are rescaled.
    g_reg_ratio = args.g_reg_every / (args.g_reg_every + 1)
    d_reg_ratio = args.d_reg_every / (args.d_reg_every + 1)

    g_optim = optim.Adam(
        generator.parameters(),
        lr=args.lr * g_reg_ratio,
        betas=(0**g_reg_ratio, 0.99**g_reg_ratio),
    )
    d_optim = optim.Adam(
        discriminator.parameters(),
        lr=args.lr * d_reg_ratio,
        betas=(0**d_reg_ratio, 0.99**d_reg_ratio),
    )

    if args.ckpt is not None:
        print('load model:', args.ckpt)

        ckpt = torch.load(args.ckpt)

        # Checkpoints appear to be named after their iteration (e.g.
        # "010000.pt"); recover the starting iteration from the filename
        # when the stem parses as an int.
        try:
            ckpt_name = os.path.basename(args.ckpt)
            args.start_iter = int(os.path.splitext(ckpt_name)[0])

        except ValueError:
            pass

        generator.load_state_dict(ckpt['g'], strict=False)
        discriminator.load_state_dict(ckpt['d'], strict=False)
        g_ema.load_state_dict(ckpt['g_ema'], strict=False)

        g_optim.load_state_dict(ckpt['g_optim'])
        d_optim.load_state_dict(ckpt['d_optim'])

    # Wrap in DDP *after* checkpoint loading so state dicts match the
    # unwrapped modules.
    if args.distributed:
        generator = nn.parallel.DistributedDataParallel(
            generator,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            broadcast_buffers=False,
        )

        discriminator = nn.parallel.DistributedDataParallel(
            discriminator,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            broadcast_buffers=False,
        )

    # Flip augmentation; Normalize maps pixels to [-1, 1].
    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
    ])

    dataset = MultiResolutionDataset(args.path, transform, args.size)
    loader = data.DataLoader(
        dataset,
        batch_size=args.batch,
        sampler=data_sampler(dataset,
                             shuffle=True,
                             distributed=args.distributed),
        drop_last=True,
    )

    # Only rank 0 logs to wandb, and only when it is installed and enabled.
    if get_rank() == 0 and wandb is not None and args.wandb:
        wandb.init(project='stylegan 2')

    train(args, loader, generator, discriminator, g_optim, d_optim, g_ema,
          device)
        d_optim.load_state_dict(e_ckpt["d_optim"])

        try:
            ckpt_name = os.path.basename(args.e_ckpt)
            args.start_iter = int(
                os.path.splitext(ckpt_name.split('_')[-1])[0])
        except ValueError:
            pass

    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
    ])

    dataset = MultiResolutionDataset(args.data, transform, args.size)
    loader = data.DataLoader(
        dataset,
        batch_size=args.batch,
        sampler=data_sampler(dataset, shuffle=True),
        drop_last=True,
    )

    test_dataset = MultiResolutionDataset(args.test_data, transform, args.size)

    test_loader = data.DataLoader(
        test_dataset,
        batch_size=args.val_batch,
        sampler=data_sampler(test_dataset, shuffle=True),
        drop_last=True,
    )
Exemplo n.º 8
0
        generator.module.load_state_dict(ckpt['generator'])
        discriminator.module.load_state_dict(ckpt['discriminator'])
        g_running.load_state_dict(ckpt['g_running'])
        g_optimizer.load_state_dict(ckpt['g_optimizer'])
        d_optimizer.load_state_dict(ckpt['d_optimizer'])

    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.Resize((8, 8)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    dataset1 = CondDataset(args.path, transform, transform)
    dataset2 = MultiResolutionDataset(args.path, transform)
    if args.sched:
        args.lr = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}
        args.batch = {
            4: 512,
            8: 256,
            16: 128,
            32: 64,
            64: 32,
            128: 32,
            256: 32
        }

    else:
        args.lr = {}
        args.batch = {}
Exemplo n.º 9
0
        )

    if args.mirror_augment:
        transform = transforms.Compose(
            [
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
            ]
        )
    else:
        transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
            ]
        )

    dataset = MultiResolutionDataset(args.path, transform, args.size, args.use_label, metadata, categories)
    loader = data.DataLoader(
        dataset,
        batch_size=args.batch,
        sampler=data_sampler(dataset, shuffle=True, distributed=args.distributed),
        drop_last=True,
    )

    if get_rank() == 0 and wandb is not None and args.wandb:
        wandb.init(project='stylegan 2')

    train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device)
Exemplo n.º 10
0
def eval(args, latent_sampler, g_ema, inception, device, config):
    """Benchmark a generator with IS or a FID variant (spatial/scaleinv/alis).

    Builds a fake-sample generator and (unless cached real statistics exist)
    a real-data loader, computes the requested metric, prints it, and appends
    the score to a per-experiment log file under ``logs-quant/``.

    NOTE: the name ``eval`` shadows the builtin; kept for interface
    compatibility with existing callers.
    """

    if g_ema is not None:
        g_ema.eval()

    """
    Cast FID calculation spec
    """
    # Real images are compared at the resolution the training pipeline
    # pre-resized to (or the full size for the StyleGAN2 baseline).
    if hasattr(config.train_params, "extra_pre_resize"):
        real_data_res = config.train_params.extra_pre_resize
    else: # StyleGAN2 baseline
        assert config.train_params.styleGAN2_baseline
        real_data_res = config.train_params.full_size
    assert real_data_res in {128, 256}, "In this paper, we only benchmark in size {128, 256}. Got {}.".format(real_data_res)

    eval_gen_res = real_data_res * args.scale

    # InfinityGAN is trained with larger image, so the same resolution equivalents to smaller FoV.
    # Here, we ensures the FoV is the same as the StyleGAN2 baseline
    fov_scale = config.train_params.full_size / real_data_res
    raw_gen_res = int(np.ceil(eval_gen_res * fov_scale))

    if args.seq_inference:
        assert (not hasattr(config.train_params, "styleGAN2_baseline")) or (not config.train_params.styleGAN2_baseline)
        assert args.scale > 1, "Set sequential inference with scale==1 is meaningless"
        use_seq_inf = True
    else:
        use_seq_inf = False

    """
    Create dataloader and generator
    """
    # Pre-rendered image folders are only size-checked and resized; freshly
    # generated images are rescaled/cropped back to the benchmark FoV first.
    if args.img_folder is not None:
        postprocessing_params = [
            ["assert", eval_gen_res],
            ["resize", real_data_res],
        ]
    else:
        postprocessing_params = [
            ["scale", 1 / fov_scale],
            ["crop", eval_gen_res],
            ["resize", real_data_res],
        ]
    fake_generator = \
        QuantEvalSampleGenerator(
            g_ema,
            latent_sampler,
            img_folder=args.img_folder, # if applicable
            output_size=raw_gen_res,
            use_seq_inf=use_seq_inf,
            postprocessing_params=postprocessing_params,
            fid_type=args.type,
            device=device,
            config=config,
            use_pil_resize=args.use_pil_resize)


    stats_key = "benchmark-{}-{}-RealRes{}".format(
        args.type, config.data_params.dataset, real_data_res)
    # FID statistics can be different for different PyTorch version, not sure about cuda
    stats_key += f"_PT{torch.__version__}_cu{torch.version.cuda}"
    fid_cache_path = os.path.join(".fid-cache/", stats_key+".pkl")
    if os.path.exists(fid_cache_path):
        if args.clear_fid_cache:
            os.remove(fid_cache_path)
            use_cache = False
        else:
            use_cache = True
    else:
        use_cache = False

    # Real-data statistics are only recomputed when not served from cache.
    if not use_cache:
        dataset = MultiResolutionDataset(
            split="train",
            config=config,
            is_training=False,
            # return "full" of real full images and crop on-the-fly
            disable_extra_cropping=True,
            simple_return_full=True,
            override_full_size=real_data_res)
        real_dataloader = QuantEvalDataLoader(dataset, real_data_res, device, config)
    else:
        real_dataloader = None

    """
    Eval
    """
    st = time.time()
    if args.metric == "is":
        assert args.scale == 1, "We didn't implement scaleinv IS."
        n_batch = int(np.ceil(config.test_params.n_fid_sample / config.train_params.batch_size))
        all_imgs = []
        for img_batch in tqdm(fake_generator(n_batch), total=n_batch):
            img_batch = ((img_batch + 1) / 2).cpu() # [-1, 1] => [0, 1]
            all_imgs.append(img_batch)
        all_imgs = torch.cat(all_imgs, 0)
        is_mean, is_std = inception_score(all_imgs, device="cuda", batch_size=config.train_params.batch_size, resize=False, splits=10)
        # BUG FIX: the format string has one placeholder but args.type was
        # passed first, so the metric *type* (not the elapsed time) was
        # printed. Pass only the elapsed time.
        print(" [*] IS time spend {}".format(time.time()-st))
        print(" [*] IS at eval_gen_res {} is {}+-{} (ckpt patch FID = {})".format(
            eval_gen_res, is_mean, is_std, config.var.best_fid))
    elif args.metric == "fid":
        if args.type == "spatial":
            fid = eval_fid(
                real_dataloader, fake_generator, inception, stats_key, None, device, config,
                spatial_partition_cat=True, assert_eval_shape=real_data_res)
        elif args.type in {"scaleinv", "alis"}:
            fid = eval_fid(
                real_dataloader, fake_generator, inception, stats_key, None, device, config,
                spatial_partition_cat=False, assert_eval_shape=real_data_res)
        else:
            raise NotImplementedError("Unknown FID variant {}".format(args.type))
        print(" [*] {} FID time spend {}".format(args.type, time.time()-st))
        print(" [*] FID (type {}) at eval_gen_res {} is {} (ckpt patch FID = {})".format(
            args.type, eval_gen_res, fid, config.var.best_fid))

    """
    Setup Logging
    """
    # Append the score to a per-experiment log file.
    if args.metric == "is":
        log_root = os.path.join("logs-quant", "IS")
        filename = f"EvalGenRes{eval_gen_res}-Exp-{config.var.exp_name}.txt"
        score = "{:.6f}+-{:.6f}\n".format(is_mean, is_std)
    else:
        log_root = os.path.join("logs-quant", "FID-"+args.type)
        filename = f"Scale{args.scale}-EvalGenRes{eval_gen_res}-Exp-{config.var.exp_name}.txt"
        score = "{:.6f}\n".format(fid)

    if not os.path.exists(log_root):
        os.makedirs(log_root)
    with open(os.path.join(log_root, filename), "a") as lf:
        lf.write(score)
Exemplo n.º 11
0
def train(learning_rate, lambda_mse):
    """Train a ConvSegNet autoencoder with VGG-perceptual + RMSE loss.

    Logs losses and image grids to wandb and saves checkpoints every ~1% of
    the run.  Relies on module-level ``device``, ``wandb`` and helpers
    (``sample_data``, ``ConvSegNet``, ``VGGLoss``).
    """
    print(
        f"learning_rate={learning_rate:.4f}", f"lambda_mse={lambda_mse:.4f}",
    )

    # Resize + flip augmentation; Normalize maps pixels to [-1, 1].
    transform = transforms.Compose(
        [
            transforms.Resize(128),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
        ]
    )
    batch_size = 72
    data_path = "/home/hans/trainsets/cyphis"
    name = os.path.splitext(os.path.basename(data_path))[0]
    dataset = MultiResolutionDataset(data_path, transform, 256)
    dataloader = data.DataLoader(
        dataset, batch_size=batch_size, sampler=data.RandomSampler(dataset), num_workers=12, drop_last=True,
    )
    # sample_data presumably wraps the loader into an endless iterator —
    # confirm against its definition.
    loader = sample_data(dataloader)
    # Fixed batch of real images, reused for the reconstruction grids below.
    sample_imgs = next(loader)[:24]
    wandb.log({"Real Images": [wandb.Image(utils.make_grid(sample_imgs, nrow=6, normalize=True, range=(-1, 1)))]})

    vae, vae_optim = None, None
    vae = ConvSegNet().to(device)
    vae_optim = th.optim.Adam(vae.parameters(), lr=learning_rate)

    vgg = VGGLoss()

    # Fixed latent batch for the periodic "generated" grids below.
    sample_z = th.randn(size=(24, 512, 16, 16))
    sample_z /= sample_z.abs().max()

    scores = []
    num_iters = 100_000
    pbar = tqdm(range(num_iters), smoothing=0.1)
    for i in pbar:
        vae.train()

        real = next(loader).to(device)

        z = vae.encode(real)
        fake = vae.decode(z)

        # Perceptual (VGG feature) loss plus an RMSE reconstruction term.
        vgg_loss = vgg(fake, real)

        mse_loss = th.sqrt((fake - real).pow(2).mean())

        # diff = fake - real
        # recons_loss = recons_alpha * diff + th.log(1.0 + th.exp(-2 * recons_alpha * diff)) - th.log(th.tensor(2.0))
        # recons_loss = (1.0 / recons_alpha) * recons_loss.mean()
        # recons_loss = recons_loss if not th.isinf(recons_loss).any() else 0

        # x, y = z.chunk(2)
        # align_loss = align(x, y, alpha=align_alpha)
        # unif_loss = -(uniform(x, t=unif_t) + uniform(y, t=unif_t)) / 2.0

        loss = (
            vgg_loss
            + lambda_mse * mse_loss
            # + lambda_recons * recons_loss
            # + lambda_align * align_loss
            # + lambda_unif * unif_loss
        )
        # print(vgg_loss.detach().cpu().item())
        # print(lambda_mse * mse_loss.detach().cpu().item())
        # # print(lambda_recons * recons_loss.detach().cpu().item())
        # print(lambda_align * align_loss.detach().cpu().item())
        # print(lambda_unif * unif_loss.detach().cpu().item())

        loss_dict = {
            "Total": loss,
            "MSE": mse_loss,
            "VGG": vgg_loss,
            # "Reconstruction": recons_loss,
            # "Alignment": align_loss,
            # "Uniformity": unif_loss,
        }

        vae.zero_grad()
        loss.backward()
        vae_optim.step()

        wandb.log(loss_dict)
        # pbar.set_description(" ".join())

        with th.no_grad():
            # Every ~1% of the run (and on the final step): log
            # reconstruction/sample grids and save a checkpoint.
            if i % int(num_iters / 100) == 0 or i + 1 == num_iters:
                vae.eval()

                sample = vae(sample_imgs.to(device))
                grid = utils.make_grid(sample, nrow=6, normalize=True, range=(-1, 1))
                del sample
                wandb.log({"Reconstructed Images VAE": [wandb.Image(grid, caption=f"Step {i}")]})

                sample = vae.decode(sample_z.to(device))
                grid = utils.make_grid(sample, nrow=6, normalize=True, range=(-1, 1))
                del sample
                wandb.log({"Generated Images VAE": [wandb.Image(grid, caption=f"Step {i}")]})

                gc.collect()
                th.cuda.empty_cache()

                th.save(
                    {"vae": vae.state_dict(), "vae_optim": vae_optim.state_dict()},
                    f"/home/hans/modelzoo/maua-sg2/vae-{name}-{wandb.run.dir.split('/')[-1].split('-')[-1]}.pt",
                )

        # Abort on NaN; log a large sentinel loss so a sweep marks this run bad.
        if th.isnan(loss).any():
            print("NaN losses, exiting...")
            wandb.log({"Total": 27000})
            return
Exemplo n.º 12
0
        )

        discriminator = nn.parallel.DistributedDataParallel(
            discriminator,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            broadcast_buffers=False,
        )

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
    ])

    dataset = MultiResolutionDataset(args.path,
                                     transform,
                                     args.resolution,
                                     condition_path=args.condition_path)
    loader = data.DataLoader(
        dataset,
        batch_size=args.batch,
        sampler=data_sampler(dataset,
                             shuffle=True,
                             distributed=args.distributed),
        drop_last=True,
        num_workers=args.num_workers,
    )

    train(args, loader, generator, discriminator, g_optim, d_optim, g_ema,
          device)
Exemplo n.º 13
0
def train(args, generator, discriminator):
    """Progressive-growing GAN training loop, 4-tile coordinate variant.

    Resolution grows from ``args.init_size`` toward ``args.max_size``; each
    phase blends the newly added resolution in via ``alpha``.  The generator
    produces four coordinate-conditioned tiles per sample which are stitched
    into one image before the discriminator sees them.

    NOTE(review): relies on module-level names not visible here —
    ``transform``, ``g_optimizer``, ``d_optimizer``, ``g_running``,
    ``coord_base``, ``code_size``, ``n_critic`` — confirm they are defined
    elsewhere in this file.
    """
    step = int(math.log2(args.max_size)) - 2  #-> 1
    resolution = 4 * 2**step
    batch_size = args.batch.get(resolution, args.batch_default)
    dataset = MultiResolutionDataset(args.path,
                                     transform,
                                     resolution=resolution)

    loader = sample_data(dataset, batch_size, resolution)
    data_loader = iter(loader)

    adjust_lr(g_optimizer, args.lr.get(resolution, 0.001))
    adjust_lr(d_optimizer, args.lr.get(resolution, 0.001))

    pbar = tqdm(range(3000000))

    requires_grad(generator, False)
    requires_grad(discriminator, True)

    disc_loss_val = 0
    gen_loss_val = 0
    grad_loss_val = 0

    # alpha: fade-in weight for the newest resolution block.
    alpha = 0
    used_sample = 0  #-> how many images has been used

    max_step = int(math.log2(args.max_size)) - 2  #-> log2(1024) - 2 = 8
    final_progress = False

    for i in pbar:
        discriminator.zero_grad()

        alpha = min(1, 1 / args.phase *
                    (used_sample + 1))  #-> min(1, (cur+1)/60_0000)
        #-> when more than 60_0000 sampels is used, alpha will be in const to 1.0
        #-> which means we the "skip_rgb" will not be applied

        if (resolution == args.init_size
                and args.ckpt is None) or final_progress:
            alpha = 1
        #-> also, if initially, no previous outputs for skip-connection

        # Phase transition: after args.phase*2 samples, double the resolution,
        # rebuild the loader, checkpoint, and reset the learning rates.
        if used_sample > args.phase * 2:  #-> if > 1_200_000
            ## num_of_epoch_each_phase = args.phase * 2 / training_dataset_size
            used_sample = 0
            step += 1

            if step > max_step:
                step = max_step
                final_progress = True
                ckpt_step = step + 1

            else:
                alpha = 0
                ckpt_step = step

            resolution = 4 * 2**step

            loader = sample_data(
                dataset, args.batch.get(resolution, args.batch_default),
                resolution)
            data_loader = iter(loader)

            torch.save(
                {
                    'generator': generator.module.state_dict(),
                    'discriminator': discriminator.module.state_dict(),
                    'g_optimizer': g_optimizer.state_dict(),
                    'd_optimizer': d_optimizer.state_dict(),
                    'g_running': g_running.state_dict(),
                }, r'checkpoint/train_step-{}.model'.format(ckpt_step))

            adjust_lr(g_optimizer, args.lr.get(resolution, 0.001))
            adjust_lr(d_optimizer, args.lr.get(resolution, 0.001))

        #### update discriminator
        try:
            real_image = next(data_loader)

        except (OSError, StopIteration):
            # Loader exhausted: restart the epoch.
            data_loader = iter(loader)
            real_image = next(data_loader)

        used_sample += real_image.shape[0]

        b_size = real_image.size(0)
        # Four coordinate tiles per sample; `select` interleaves the repeated
        # latents so each sample's 4 tiles are adjacent.
        coords = coord_base.repeat(b_size, 1)
        select = np.hstack([[i * b_size + j for i in range(4)]
                            for j in range(b_size)])
        real_image = real_image.cuda()

        if args.loss == 'wgan-gp':
            real_predict = discriminator(real_image, step=step, alpha=alpha)
            # Small drift penalty keeps critic outputs near zero.
            real_predict = real_predict.mean() - 0.001 * (real_predict**
                                                          2).mean()
            (-real_predict).backward()

        elif args.loss == 'r1':
            # R1: gradient penalty on real images only.
            real_image.requires_grad = True
            real_scores = discriminator(real_image, step=step, alpha=alpha)
            real_predict = F.softplus(-real_scores).mean()
            real_predict.backward(retain_graph=True)

            grad_real = grad(outputs=real_scores.sum(),
                             inputs=real_image,
                             create_graph=True)[0]
            grad_penalty = (grad_real.view(grad_real.size(0),
                                           -1).norm(2, dim=1)**2).mean()
            grad_penalty = 10 / 2 * grad_penalty
            grad_penalty.backward()
            if i % 10 == 0:
                grad_loss_val = grad_penalty.item()

        # Style mixing: with probability 0.9, feed two latent pairs.
        if args.mixing and random.random() < 0.9:
            gen_in11, gen_in12, gen_in21, gen_in22 = torch.randn(
                4, b_size, code_size - 2, device='cuda').chunk(4, 0)

            gen_in11 = gen_in11.squeeze(0)
            gen_in11 = torch.cat([gen_in11.repeat(4, 1)[select], coords],
                                 dim=1)

            gen_in12 = gen_in12.squeeze(0)
            gen_in12 = torch.cat([gen_in12.repeat(4, 1)[select], coords],
                                 dim=1)

            gen_in21 = gen_in21.squeeze(0)
            gen_in21 = torch.cat([gen_in21.repeat(4, 1)[select], coords],
                                 dim=1)

            gen_in22 = gen_in22.squeeze(0)
            gen_in22 = torch.cat([gen_in22.repeat(4, 1)[select], coords],
                                 dim=1)

            gen_in1 = [gen_in11, gen_in12]
            gen_in2 = [gen_in21, gen_in22]

        else:
            gen_in1, gen_in2 = torch.randn(2,
                                           b_size,
                                           code_size - 2,
                                           device='cuda').chunk(
                                               2,
                                               0  # 512
                                           )
            gen_in1 = gen_in1.squeeze(0)  # (B, 254)
            gen_in2 = gen_in2.squeeze(0)  # (B, 254)

            # repeat and copy
            gen_in1 = torch.cat([gen_in1.repeat(4, 1)[select], coords], dim=1)
            gen_in2 = torch.cat([gen_in2.repeat(4, 1)[select], coords], dim=1)

        # Generate 4 half-resolution tiles, then stitch 2x2 into one image.
        fake_image = generator(gen_in1, step=step - 1, alpha=alpha)

        fake_image_up = torch.cat([fake_image[0::4], fake_image[1::4]], dim=3)
        fake_image_dn = torch.cat([fake_image[2::4], fake_image[3::4]], dim=3)
        fake_image = torch.cat([fake_image_up, fake_image_dn], dim=2)

        fake_predict = discriminator(fake_image, step=step, alpha=alpha)

        if args.loss == 'wgan-gp':
            fake_predict = fake_predict.mean()
            fake_predict.backward()

            # WGAN-GP: gradient penalty on real/fake interpolates.
            eps = torch.rand(b_size, 1, 1, 1).cuda()
            x_hat = eps * real_image.data + (1 - eps) * fake_image.data
            x_hat.requires_grad = True
            hat_predict = discriminator(x_hat, step=step, alpha=alpha)
            grad_x_hat = grad(outputs=hat_predict.sum(),
                              inputs=x_hat,
                              create_graph=True)[0]
            grad_penalty = (
                (grad_x_hat.view(grad_x_hat.size(0), -1).norm(2, dim=1) -
                 1)**2).mean()
            grad_penalty = 10 * grad_penalty
            grad_penalty.backward()
            if i % 10 == 0:
                grad_loss_val = grad_penalty.item()
                disc_loss_val = (real_predict - fake_predict).item()

        elif args.loss == 'r1':
            fake_predict = F.softplus(fake_predict).mean()
            fake_predict.backward()
            if i % 10 == 0:
                disc_loss_val = (real_predict + fake_predict).item()

        d_optimizer.step()

        #### update generator
        # Generator trains once per n_critic discriminator updates.
        if (i + 1) % n_critic == 0:
            generator.zero_grad()

            requires_grad(generator, True)
            requires_grad(discriminator, False)

            fake_image = generator(gen_in2, step=step - 1, alpha=alpha)

            fake_image_up = torch.cat([fake_image[0::4], fake_image[1::4]],
                                      dim=3)
            fake_image_dn = torch.cat([fake_image[2::4], fake_image[3::4]],
                                      dim=3)
            fake_image = torch.cat([fake_image_up, fake_image_dn], dim=2)

            predict = discriminator(fake_image, step=step, alpha=alpha)

            if args.loss == 'wgan-gp':
                loss = -predict.mean()

            elif args.loss == 'r1':
                loss = F.softplus(-predict).mean()

            if i % 10 == 0:
                gen_loss_val = loss.item()

            loss.backward()
            g_optimizer.step()
            # Update the running average generator used for sampling.
            accumulate(g_running, generator.module)

            requires_grad(generator, False)
            requires_grad(discriminator, True)

        #### validation
        # Every 100 iterations: save a grid of g_running samples.
        if (i + 1) % 100 == 0:
            images = []

            gen_i, gen_j = args.gen_sample.get(resolution, (10, 5))
            coords = coord_base.repeat(gen_j, 1)
            select = np.hstack([[i * gen_j + j for i in range(4)]
                                for j in range(gen_j)])

            with torch.no_grad():
                for ii in range(gen_i):
                    style = torch.randn(gen_j,
                                        code_size - 2).cuda().repeat(4,
                                                                     1)[select]
                    style = torch.cat([style, coords], dim=1)
                    image = g_running(style, step=step - 1,
                                      alpha=alpha).data.cpu()

                    image_up = torch.cat([image[0::4], image[1::4]], dim=3)
                    image_dn = torch.cat([image[2::4], image[3::4]], dim=3)
                    image = torch.cat([image_up, image_dn], dim=2)

                    images.append(image)

            utils.save_image(
                torch.cat(images, 0),
                r'sample/%06d.png' % (i + 1),
                nrow=gen_i,
                normalize=True,
                range=(-1, 1),
            )

        # Every 10k iterations: checkpoint the running-average generator.
        if (i + 1) % 10000 == 0:
            torch.save(g_running.state_dict(),
                       r'checkpoint/%06d.model' % (i + 1))

        state_msg = (
            r'Size: {}; G: {:.3f}; D: {:.3f}; Grad: {:.3f}; Alpha: {:.5f}'.
            format(4 * 2**step, gen_loss_val, disc_loss_val, grad_loss_val,
                   alpha))

        pbar.set_description(state_msg)
Exemplo n.º 14
0
def train(args, generator, discriminator):
    """Progressive-growing, coordinate-conditioned GAN training loop.

    Alternates discriminator and generator updates; every ``args.phase * 2``
    consumed samples the output resolution doubles (``step`` += 1) and the new
    resolution is blended in via ``alpha`` (0 -> 1).

    NOTE(review): this snippet relies on names not defined here
    (g_optimizer, d_optimizer, g_running, transform, sample_data, adjust_lr,
    coord_handler, patch_handler, config, criterion_mse, spatial_predictor,
    coord_loss_w, num_micro_in_macro, code_size, n_critic, step_D, step_G,
    accumulate, requires_grad, micros_to_macro) -- presumably module-level
    globals; confirm against the full file.
    """
    step = int(math.log2(args.max_size)) - 2 #-> 1
    resolution = 4 * 2 ** step
    batch_size = args.batch.get(resolution, args.batch_default)
    dataset = MultiResolutionDataset(args.path, transform, resolution=resolution)
    
    loader = sample_data(
        dataset, batch_size, resolution
    )
    data_loader = iter(loader)

    adjust_lr(g_optimizer, args.lr.get(resolution, 0.001))
    adjust_lr(d_optimizer, args.lr.get(resolution, 0.001))

    pbar = tqdm(range(3000000))

    requires_grad(generator, False)
    requires_grad(discriminator, True)

    disc_loss_val = 0
    gen_loss_val = 0
    grad_loss_val = 0

    alpha = 0
    used_sample = 0 #-> how many images has been used

    max_step = int(math.log2(args.max_size)) - 2 #-> log2(1024) - 2 = 8
    final_progress = False

    for i in pbar:
        discriminator.zero_grad()

        alpha = min(1, 1 / args.phase * (used_sample + 1)) #-> min(1, (cur+1)/60_0000)
        #-> when more than 60_0000 sampels is used, alpha will be in const to 1.0
        #-> which means we the "skip_rgb" will not be applied

        if (resolution == args.init_size and args.ckpt is None) or final_progress:
            alpha = 1
        #-> also, if initially, no previous outputs for skip-connection

        if used_sample > args.phase * 2: #-> if > 1_200_000
            ## num_of_epoch_each_phase = args.phase * 2 / training_dataset_size
            used_sample = 0
            
            step += 1

            if step > max_step:
                step = max_step
                final_progress = True
                ckpt_step = step + 1

            else:
                alpha = 0
                ckpt_step = step
            

            # NOTE(review): uses the external name `step_D` even though the
            # local `step` was just incremented above -- this looks like it
            # should be `4 * 2 ** step`; verify against the full file.
            resolution = 4 * 2 ** step_D

            loader = sample_data(
                dataset, args.batch.get(resolution, args.batch_default), resolution
            )
            data_loader = iter(loader)

            # Checkpoint everything at each resolution transition.
            torch.save(
                {
                    'generator': generator.module.state_dict(),
                    'discriminator': discriminator.module.state_dict(),
                    'g_optimizer': g_optimizer.state_dict(),
                    'd_optimizer': d_optimizer.state_dict(),
                    'g_running': g_running.state_dict(),
                }, r'checkpoint_coco/train_step-{}.model'.format(ckpt_step))

            adjust_lr(g_optimizer, args.lr.get(resolution, 0.001))
            adjust_lr(d_optimizer, args.lr.get(resolution, 0.001))

        #### update discriminator
        try:
            real_image = next(data_loader)

        except (OSError, StopIteration):
            # Loader exhausted -- restart it for another pass over the data.
            data_loader = iter(loader)
            real_image = next(data_loader)

        used_sample += real_image.shape[0]
        real_image = real_image.cuda()

        b_size = real_image.size(0)
        # Index pattern that interleaves each sample's micro-patch copies.
        select = np.hstack([[i*b_size+j for i in range(num_micro_in_macro)] for j in range(b_size)])
        # get sample coords
        coord_handler.batch_size = b_size
        patch_handler.batch_size = b_size
        d_macro_coord_real, g_micro_coord_real, _ = coord_handler._euclidean_sample_coord()
        d_macro_coord_fake1, g_micro_coord_fake1, _ = coord_handler._euclidean_sample_coord()
        d_macro_coord_fake2, g_micro_coord_fake2, _ = coord_handler._euclidean_sample_coord()
        
        d_macro_coord_real = torch.from_numpy(d_macro_coord_real).float().cuda()
        d_macro_coord_fake1, g_micro_coord_fake1 = torch.from_numpy(d_macro_coord_fake1).float().cuda(), torch.from_numpy(g_micro_coord_fake1).float().cuda()
        d_macro_coord_fake2, g_micro_coord_fake2 = torch.from_numpy(d_macro_coord_fake2).float().cuda(), torch.from_numpy(g_micro_coord_fake2).float().cuda()

        # Crop real micro patches at the sampled coordinates and tile them
        # into one macro patch per sample for the discriminator.
        real_macro = micros_to_macro(patch_handler.crop_micro_from_full_gpu(real_image, g_micro_coord_real[:, 1:2], g_micro_coord_real[:, 0:1]), config["data_params"]["ratio_macro_to_micro"])
        
        if args.loss == 'wgan-gp':
            real_predict, real_H = discriminator(real_macro, d_macro_coord_real, step=step_D, alpha=alpha)
            # Small drift term (0.001 * E[D(x)^2]) keeps the critic's scores bounded.
            real_predict = real_predict.mean() - 0.001 * (real_predict ** 2).mean()
            
            sp_loss_real = criterion_mse(spatial_predictor(real_H), d_macro_coord_real) * coord_loss_w
            (-real_predict+sp_loss_real).backward()

        elif args.loss == 'r1':
            real_macro.requires_grad = True
            real_scores, real_H = discriminator(real_macro, d_macro_coord_real, step=step_D, alpha=alpha)
            real_predict = F.softplus(-real_scores).mean()
            sp_loss_real = criterion_mse(spatial_predictor(real_H), d_macro_coord_real) * coord_loss_w
            (real_predict+sp_loss_real).backward(retain_graph=True)

            # R1 regularizer: squared gradient norm penalty on real samples only.
            grad_real = grad(
                outputs=real_scores.sum(), inputs=real_macro, create_graph=True
            )[0]
            grad_penalty = (
                grad_real.view(grad_real.size(0), -1).norm(2, dim=1) ** 2
            ).mean()
            grad_penalty = 10 / 2 * grad_penalty
            grad_penalty.backward()
            if i%10 == 0:
                grad_loss_val = grad_penalty.item()

        if args.mixing and random.random() < 0.9:
            # Style mixing: two latent pairs; coordinates are appended to each
            # latent so the generator is conditioned on patch position.
            gen_in11, gen_in12, gen_in21, gen_in22 = torch.randn(
                4, b_size, code_size-2, device='cuda'
            ).chunk(4, 0)
            
            gen_in11 = gen_in11.squeeze(0)
            gen_in11 = torch.cat([gen_in11.repeat(num_micro_in_macro, 1)[select], g_micro_coord_fake1], dim=1)
            
            gen_in12 = gen_in12.squeeze(0)
            gen_in12 = torch.cat([gen_in12.repeat(num_micro_in_macro, 1)[select], g_micro_coord_fake1], dim=1)
            
            gen_in21 = gen_in21.squeeze(0)
            gen_in21 = torch.cat([gen_in21.repeat(num_micro_in_macro, 1)[select], g_micro_coord_fake2], dim=1)
            
            gen_in22 = gen_in22.squeeze(0)
            gen_in22 = torch.cat([gen_in22.repeat(num_micro_in_macro, 1)[select], g_micro_coord_fake2], dim=1)
            
            gen_in1 = [gen_in11, gen_in12]
            gen_in2 = [gen_in21, gen_in22]
            
            #print(gen_in11[:16])

        else:
            gen_in1, gen_in2 = torch.randn(2, b_size, code_size-2, device='cuda').chunk(
                2, 0                                  # 512
            )
            gen_in1 = gen_in1.squeeze(0)# (B, 254)
            gen_in2 = gen_in2.squeeze(0)# (B, 254)

            # repeat and copy
            gen_in1 = torch.cat([gen_in1.repeat(num_micro_in_macro, 1)[select], g_micro_coord_fake1], dim=1)
            gen_in2 = torch.cat([gen_in2.repeat(num_micro_in_macro, 1)[select], g_micro_coord_fake2], dim=1)
        
        fake_image = generator(gen_in1, step=step_G, alpha=alpha)
        fake_image = micros_to_macro(fake_image, config["data_params"]["ratio_macro_to_micro"])
        fake_predict, fake_H = discriminator(fake_image, d_macro_coord_fake1, step=step_D, alpha=alpha)
        sp_loss_fake = criterion_mse(spatial_predictor(fake_H), d_macro_coord_fake1) * coord_loss_w

        if args.loss == 'wgan-gp':
            fake_predict = fake_predict.mean()
            (fake_predict+sp_loss_fake).backward()

            # WGAN-GP gradient penalty on a real/fake interpolate.
            # NOTE(review): x_hat mixes the FULL real image with a MACRO fake
            # patch (shapes may differ), and this discriminator call passes no
            # coordinates and does not unpack a (pred, H) tuple like every
            # other call in this function -- this path looks stale; verify
            # before running with args.loss == 'wgan-gp'.
            eps = torch.rand(b_size, 1, 1, 1).cuda()
            x_hat = eps * real_image.data + (1 - eps) * fake_image.data
            x_hat.requires_grad = True
            hat_predict = discriminator(x_hat, step=step_D, alpha=alpha)
            grad_x_hat = grad(
                outputs=hat_predict.sum(), inputs=x_hat, create_graph=True
            )[0]
            grad_penalty = (
                (grad_x_hat.view(grad_x_hat.size(0), -1).norm(2, dim=1) - 1) ** 2
            ).mean()
            grad_penalty = 10 * grad_penalty
            grad_penalty.backward()
            if i%10 == 0:
                grad_loss_val = grad_penalty.item()
                disc_loss_val = (real_predict - fake_predict).item()

        elif args.loss == 'r1':
            fake_predict = F.softplus(fake_predict).mean()
            (fake_predict+sp_loss_fake).backward()
            if i%10 == 0:
                disc_loss_val = (real_predict + fake_predict).item()

        d_optimizer.step()
        if i%10 == 0:
            spatial_loss_D_val = (sp_loss_real.item() + sp_loss_fake.item()) / 2


        #### update generator
        if (i + 1) % n_critic == 0:
            generator.zero_grad()

            requires_grad(generator, True)
            requires_grad(discriminator, False)

            fake_image = generator(gen_in2, step=step_G, alpha=alpha)
            fake_image = micros_to_macro(fake_image, config["data_params"]["ratio_macro_to_micro"])
            predict, H = discriminator(fake_image, d_macro_coord_fake2, step=step_D, alpha=alpha)
            spatial_loss = criterion_mse(spatial_predictor(H), d_macro_coord_fake2) * coord_loss_w

            if args.loss == 'wgan-gp':
                loss = -predict.mean()

            elif args.loss == 'r1':
                loss = F.softplus(-predict).mean()

            if i%10 == 0:
                gen_loss_val = loss.item()
                spatial_loss_G_val = spatial_loss.item()

            (loss+spatial_loss).backward()
            g_optimizer.step()
            # Update the EMA copy of the generator used for sampling.
            accumulate(g_running, generator.module)

            requires_grad(generator, False)
            requires_grad(discriminator, True)


        #### validation
        if (i + 1) % 100 == 0:
            images = []

            gen_i, gen_j = args.gen_sample.get(resolution, (10, 5))
            
            coord_handler.batch_size = gen_i * gen_j
            _, g_micro_coord_val, _ = coord_handler._euclidean_sample_coord()
            g_micro_coord_val = torch.from_numpy(g_micro_coord_val).float().cuda()
            #print(g_micro_coord_val.shape)
            
            select = np.hstack([[i*gen_j+j for i in range(num_micro_in_macro)] for j in range(gen_j)])

            with torch.no_grad():
                for ii in range(gen_i):
                    style = torch.randn(gen_j, code_size-2).cuda().repeat(num_micro_in_macro, 1)[select]
                    #print(style.size())
                    coords = g_micro_coord_val[ii*gen_j*num_micro_in_macro:(ii+1)*gen_j*num_micro_in_macro]
                    #print(coords.size())
                    style = torch.cat([style, coords], dim=1)
                    
                    image = g_running(style, step=step_G, alpha=alpha).data.cpu()
                    image = micros_to_macro(image, config['data_params']['ratio_macro_to_micro'])
                    
                    images.append(
                        image
                    )

            utils.save_image(
                torch.cat(images, 0),
                r'sample_coco/%06d.png'%(i+1),
                nrow=gen_i,
                normalize=True,
                range=(-1, 1),
            )

        if (i + 1) % 10000 == 0:
            torch.save(
                g_running.state_dict(), r'checkpoint_coco/%06d.model'%(i+1)
            )

        state_msg = (
            r'Size: {}; G: {:.3f}; D: {:.3f}; Grad: {:.3f}; sp_G: {:.3f}; sp_D: {:.3f}; Alpha: {:.5f}'.format(4 * 2 ** step, gen_loss_val, disc_loss_val, grad_loss_val, spatial_loss_G_val, spatial_loss_D_val, alpha)
        )

        pbar.set_description(state_msg)
Exemplo n.º 15
0
def main(args, myargs):
    """Build the StyleGAN generator/discriminator pair, their optimizers and
    the dataset, then hand everything off to ``train``.

    Fix: removed the dead local ``batch_size = 16`` (assigned but never used;
    batch sizes actually come from ``args.batch`` / ``args.batch_default``).

    NOTE(review): StyledGenerator, Discriminator, accumulate,
    MultiResolutionDataset, transform helpers and the ``train`` being called
    here are defined elsewhere in the module.
    """
    code_size = 512  # latent code dimensionality
    n_critic = 1     # discriminator updates per generator update

    generator = nn.DataParallel(StyledGenerator(code_size)).cuda()
    discriminator = nn.DataParallel(
        Discriminator(from_rgb_activate=not args.no_from_rgb_activate)).cuda()
    # Exponential-moving-average copy of the generator, kept in eval mode.
    g_running = StyledGenerator(code_size).cuda()
    g_running.train(False)

    g_optimizer = optim.Adam(generator.module.generator.parameters(),
                             lr=args.lr,
                             betas=(0.0, 0.99))
    # The style-mapping network trains with a 100x smaller learning rate.
    g_optimizer.add_param_group({
        'params': generator.module.style.parameters(),
        'lr': args.lr * 0.01,
        'mult': 0.01,
    })
    d_optimizer = optim.Adam(discriminator.parameters(),
                             lr=args.lr,
                             betas=(0.0, 0.99))

    # Initialize the EMA copy to the generator's current weights (decay=0).
    accumulate(g_running, generator.module, 0)

    if args.ckpt is not None:
        # Resume all model and optimizer state from a checkpoint.
        ckpt = torch.load(args.ckpt)

        generator.module.load_state_dict(ckpt['generator'])
        discriminator.module.load_state_dict(ckpt['discriminator'])
        g_running.load_state_dict(ckpt['g_running'])
        g_optimizer.load_state_dict(ckpt['g_optimizer'])
        d_optimizer.load_state_dict(ckpt['d_optimizer'])

    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
    ])

    dataset = MultiResolutionDataset(args.path, transform)

    if args.sched:
        # Per-resolution learning-rate / batch-size schedule (progressive growing).
        args.lr = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}
        args.batch = {
            4: 512,
            8: 256,
            16: 128,
            32: 64,
            64: 32,
            128: 32,
            256: 32
        }

    else:
        args.lr = {}
        args.batch = {}

    args.gen_sample = {512: (8, 4), 1024: (4, 2)}

    args.batch_default = 32

    train(args,
          dataset,
          generator,
          discriminator,
          g_optimizer=g_optimizer,
          d_optimizer=d_optimizer,
          g_running=g_running,
          code_size=code_size,
          n_critic=n_critic,
          myargs=myargs)
Exemplo n.º 16
0
    # NOTE(review): truncated snippet -- the argparse setup and the enclosing
    # script/function header are not visible here.
    parser.add_argument('path', metavar='PATH', help='path to datset lmdb file')

    args = parser.parse_args()

    # Inception-v3 feature extractor used to compute FID reference statistics.
    inception = load_patched_inception_v3()
    inception = nn.DataParallel(inception).eval().to(device)

    transform = transforms.Compose(
        [
            transforms.RandomHorizontalFlip(p=0.5 if args.flip else 0),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]
    )

    dset = MultiResolutionDataset(config.data_params.lmdb_path, transform=transform, resolution=config.train_params.full_size)
    loader = DataLoader(dset, batch_size=config.train_params.batch_size, num_workers=4)

    features = extract_features(loader, inception, device).numpy()

    # NOTE(review): `params` is used here while `config` is used everywhere
    # else in this snippet -- possibly a typo for
    # `config.test_params.n_fid_sample`; verify.
    features = features[: params.test_params.n_fid_sample]

    print(f'extracted {features.shape[0]} features')

    # Mean and covariance of the Inception features (the FID reference stats).
    mean = np.mean(features, 0)
    cov = np.cov(features, rowvar=False)

    name = os.path.splitext(os.path.basename(config.data_params.lmdb_path))[0]

    with open(f'inception_{name}.pkl', 'wb') as f:
        pickle.dump({'mean': mean, 'cov': cov, 'size': config.train_params.full_size, 'path': config.data_params.lmdb_path}, f)
Exemplo n.º 17
0
        # NOTE(review): truncated snippet -- the enclosing function header and
        # the `if` guarding this checkpoint restore are not visible here.
        t_optimizer.load_state_dict(ckpt['t_optimizer'])
        g_optimizer.load_state_dict(ckpt['g_optimizer'])
        d_optimizer.load_state_dict(ckpt['d_optimizer'])

    transform = transforms.Compose(
        [
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
        ]
    )

    # NOTE(review): os.makedirs(..., exist_ok=True) would fold these two lines
    # into one and avoid a race -- consider when editing the full file.
    if not os.path.exists(os.path.join(args.out, 'checkpoint')):
        os.makedirs(os.path.join(args.out, 'checkpoint'))
    
    dataset = MultiResolutionDataset(args.path, transform, max_length=24)
    inception_score = Inception_score(resize=True, splits=1)
    
    
    if args.sched:
        # Per-resolution learning-rate / batch-size schedule.
        args.lr = {4: 1e-3, 8: 1e-3, 16: 5e-4, 32: 1e-4, 64: 1e-4, 128: 1e-4, 256: 1e-4}
        args.batch = {4: 64, 8: 64, 16: 64, 32: 32, 64: 32, 128: 16, 256: 16}

    else:
        args.lr = {}
        args.batch = {}

    args.gen_sample = {512: (8, 4), 1024: (4, 2)}

    args.batch_default = 32
Exemplo n.º 18
0
def train(latent_dim, num_repeats, learning_rate, lambda_vgg, lambda_mse):
    """Train an Inception-style VAE on a fixed image folder, logging to wandb.

    Args:
        latent_dim: size of the VAE latent space.
        num_repeats: blocks repeated per Inception stage.
        learning_rate: Adam learning rate.
        lambda_vgg: weight of the VGG perceptual-loss term.
        lambda_mse: weight of the RMSE term.

    Fixes: removed dead locals (`vae, vae_optim = None, None` was immediately
    overwritten; `scores = []` was never used) and replaced the deprecated
    `size_average=False` with the equivalent `reduction="sum"`.

    NOTE(review): relies on module-level names not visible in this snippet
    (MultiResolutionDataset, sample_data, InceptionVAE, VGGLoss, device,
    wandb, utils, data, transforms) -- confirm in the full file.
    """
    print(
        f"latent_dim={latent_dim:.4f}",
        f"num_repeats={num_repeats:.4f}",
        f"learning_rate={learning_rate:.4f}",
        f"lambda_vgg={lambda_vgg:.4f}",
        f"lambda_mse={lambda_mse:.4f}",
    )

    transform = transforms.Compose([
        transforms.Resize(128),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
    ])
    batch_size = 72
    data_path = "/home/hans/trainsets/cyphis"
    name = os.path.splitext(os.path.basename(data_path))[0]
    dataset = MultiResolutionDataset(data_path, transform, 256)
    dataloader = data.DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=data.RandomSampler(dataset),
        num_workers=12,
        drop_last=True,
    )
    loader = sample_data(dataloader)
    # Fixed batch of real images reused for reconstruction previews.
    sample_imgs = next(loader)[:24]
    wandb.log({
        "Real Images": [
            wandb.Image(
                utils.make_grid(sample_imgs,
                                nrow=6,
                                normalize=True,
                                range=(0, 1)))
        ]
    })

    vae = InceptionVAE(latent_dim=latent_dim,
                       repeat_per_block=num_repeats).to(device)
    vae_optim = th.optim.Adam(vae.parameters(), lr=learning_rate)

    vgg = VGGLoss()

    # sample_z = th.randn(size=(24, 512))

    num_iters = 100_000
    pbar = tqdm(range(num_iters), smoothing=0.1)
    for i in pbar:
        vae.train()

        real = next(loader).to(device)

        fake, mu, log_var = vae(real)

        # Summed BCE reconstruction term (reduction="sum" == old size_average=False).
        bce = F.binary_cross_entropy(fake, real, reduction="sum")
        # Standard VAE KL divergence against a unit Gaussian prior.
        kld = -0.5 * th.sum(1 + log_var - mu.pow(2) - log_var.exp())
        vgg_loss = vgg(fake, real)
        # NOTE: despite the name this is an RMSE (sqrt of the mean squared error).
        mse_loss = th.sqrt((fake - real).pow(2).mean())

        loss = bce + kld + lambda_vgg * vgg_loss + lambda_mse * mse_loss

        loss_dict = {
            "Total": loss,
            "BCE": bce,
            "Kullback Leibler Divergence": kld,
            "MSE": mse_loss,
            "VGG": vgg_loss,
        }

        vae.zero_grad()
        loss.backward()
        vae_optim.step()

        wandb.log(loss_dict)

        with th.no_grad():
            # Log previews / save a checkpoint every 1% of training.
            if i % int(num_iters / 100) == 0 or i + 1 == num_iters:
                vae.eval()

                sample, _, _ = vae(sample_imgs.to(device))
                # NOTE(review): `range=` was renamed `value_range=` in newer
                # torchvision releases -- adjust if upgrading.
                grid = utils.make_grid(sample,
                                       nrow=6,
                                       normalize=True,
                                       range=(0, 1))
                del sample
                wandb.log({
                    "Reconstructed Images VAE":
                    [wandb.Image(grid, caption=f"Step {i}")]
                })

                sample = vae.sampling()
                grid = utils.make_grid(sample,
                                       nrow=6,
                                       normalize=True,
                                       range=(0, 1))
                del sample
                wandb.log({
                    "Generated Images VAE":
                    [wandb.Image(grid, caption=f"Step {i}")]
                })

                gc.collect()
                th.cuda.empty_cache()

                th.save(
                    {
                        "vae": vae.state_dict(),
                        "vae_optim": vae_optim.state_dict()
                    },
                    f"/home/hans/modelzoo/maua-sg2/vae-{name}-{wandb.run.dir.split('/')[-1].split('-')[-1]}.pt",
                )

        # Abort (and flag in wandb) if the loss diverged.
        if th.isnan(loss).any() or th.isinf(loss).any():
            print("NaN losses, exiting...")
            print({
                "Total": loss,
                "\nBCE": bce,
                "\nKullback Leibler Divergence": kld,
                "\nMSE": mse_loss,
                "\nVGG": vgg_loss,
            })
            wandb.log({"Total": 27000})
            return
                        # NOTE(review): truncated snippet -- these lines
                        # continue a parser.add_argument(...) call whose
                        # opening is not visible here.
                        metavar='PATH',
                        help='path to datset lmdb file')

    args = parser.parse_args()

    # Inception-v3 feature extractor used to compute FID reference statistics.
    inception = load_patched_inception_v3()
    inception = nn.DataParallel(inception).eval().to(device)

    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5 if args.flip else 0),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    dset = MultiResolutionDataset(args.path,
                                  transform=transform,
                                  resolution=args.size)
    loader = DataLoader(dset, batch_size=args.batch, num_workers=4)

    features = extract_features(loader, inception, device).numpy()

    features = features[:args.n_sample]

    print(f'extracted {features.shape[0]} features')

    # Mean and covariance of the features (the FID reference statistics).
    mean = np.mean(features, 0)
    cov = np.cov(features, rowvar=False)

    name = os.path.splitext(os.path.basename(args.path))[0]

    # NOTE(review): truncated -- the pickle.dump(...) body of this `with`
    # block is missing from the snippet.
    with open(f'inception_{name}.pkl', 'wb') as f:
Exemplo n.º 20
0
    # NOTE(review): truncated snippet -- the argparse setup and the enclosing
    # script/function header are not visible here.
    args = parser.parse_args()

    # Deterministic seeding across python / torch CPU / all CUDA devices.
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    inception = InceptionV3().cuda()

    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
    ])

    dataset = MultiResolutionDataset(f'./dataset/{args.dataset}_lmdb',
                                     transform)
    loader = sample_data(dataset, args.batch_size, args.image_size)

    pbar = tqdm(total=len(dataset))

    # Collect Inception activations for the whole dataset.
    # NOTE(review): unlike other snippets here, this loader yields
    # (index, image) pairs -- confirm against the sample_data in the full file.
    acts = []
    for real_index, real_image in loader:
        real_image = real_image.cuda()
        with torch.no_grad():
            out = inception(real_image)
            out = out[0].squeeze(-1).squeeze(-1)
        acts.append(out.cpu().numpy())
        pbar.update(len(real_image))
    acts = np.concatenate(acts, axis=0)

    # NOTE(review): truncated -- the body of this `with` block (presumably a
    # pickle.dump of `acts`) is missing from the snippet.
    with open(f'dataset/{args.dataset}_acts.pickle', 'wb') as handle:
Exemplo n.º 21
0
        # NOTE(review): truncated snippet -- the enclosing function header and
        # the `if args.distributed:` guard are not visible here.
        discriminator = nn.parallel.DistributedDataParallel(
            discriminator,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            broadcast_buffers=False,
            find_unused_parameters=True,
        )

    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    dataset = MultiResolutionDataset(args.path)
    loader = data.DataLoader(
        dataset,
        batch_size=args.batch,
        sampler=data_sampler(dataset,
                             shuffle=True,
                             distributed=args.distributed),
        drop_last=True,
    )

    # Only the rank-0 process logs to wandb.
    if get_rank() == 0 and wandb is not None and args.wandb:
        wandb.init(project='stylegan 2')

    train(args, loader, generator, discriminator, g_optim, d_optim, g_ema,
          device)
Exemplo n.º 22
0
            # NOTE(review): truncated snippet -- these first lines close a
            # DistributedDataParallel(...) call whose start is not visible.
            broadcast_buffers=False,
        )

        drs_discriminator = nn.parallel.DistributedDataParallel(
            drs_discriminator,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            broadcast_buffers=False,
        )

    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
    ])
    dataset = MultiResolutionDataset(args.root, transform, args.size)

    # Load per-sample discriminator logits recorded by a baseline run.
    logit_path = f'./exp_results/{args.baseline_exp_name}/logits_netD.pkl'
    print(f'Use logit from: {logit_path}')
    logits = pickle.load(open(logit_path, "rb"))

    # Resampling scores are computed over the last `window` steps before p1_step.
    window = 5000
    score_start_step = (args.p1_step - window)
    score_end_step = args.p1_step + 1
    score_dict = calculate_scores(logits,
                                  start_epoch=score_start_step,
                                  end_epoch=score_end_step)

    sample_weights = score_dict[args.resample_score]

    # NOTE(review): truncated -- the body of print_stats is missing.
    def print_stats(sw):
Exemplo n.º 23
0
    # NOTE(review): truncated snippet -- the enclosing function header is not
    # visible here.
    ### prepare experiments ###
    # Deterministic seeding across python / torch CPU / all CUDA devices.
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    ### load dataset ###

    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
    ])

    dataset = MultiResolutionDataset(f'./dataset/{args.dataset}_lmdb',
                                     transform,
                                     resolution=args.image_size)

    ### load G and D ###

    if args.supervised:
        # Supervised variant: generator gets a per-sample embedding table.
        G_target = nn.DataParallel(
            StyledGenerator(code_size,
                            dataset_size=len(dataset),
                            embed_dim=code_size)).cuda()
        G_running_target = StyledGenerator(code_size,
                                           dataset_size=len(dataset),
                                           embed_dim=code_size).cuda()
        G_running_target.train(False)
        # EMA copy initialized to the target generator's weights (decay=0).
        accumulate(G_running_target, G_target.module, 0)
    # NOTE(review): truncated -- the else-branch body is missing.
    else:
Exemplo n.º 24
0
	# NOTE(review): truncated snippet -- the enclosing function header is not
	# visible here; indentation in this snippet uses tabs.
	### load G and D ###

	gen1, dis1 = load_network(f'./checkpoint/{args.ckpt1}')
	gen2, dis2 = load_network(f'./checkpoint/{args.ckpt2}')
	gen3, dis3 = load_network(f'./checkpoint/{args.ckpt3}')

	### load dataset ###

	transform = transforms.Compose([
		transforms.RandomHorizontalFlip(),
		transforms.ToTensor(),
		transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
	])

	data1 = MultiResolutionDataset(f'./dataset/{args.data1}_lmdb', transform, resolution=args.image_size)
	data2 = MultiResolutionDataset(f'./dataset/{args.data2}_lmdb', transform, resolution=args.image_size)
	data3 = MultiResolutionDataset(f'./dataset/{args.data3}_lmdb', transform, resolution=args.image_size)

	step = int(math.log2(args.image_size)) - 2
	resolution = 4 * 2 ** step
	batch_size = 10

	### run experiment ###

	# Cross-evaluate each discriminator against each dataset/generator pair.
	# acc11 is hard-coded from a previous run of the commented-out call below.
	# acc11, threshold11 = test(dis1, data1, gen1)
	acc11, threshold11 = 77.15, 0.5685
	acc12, threshold12 = test(dis1, data2, gen2)
	acc13, threshold13 = test(dis1, data3, gen3)
	acc21, threshold21 = test(dis2, data1, gen1)
	acc22, threshold22 = test(dis2, data2, gen2)
Exemplo n.º 25
0
        # NOTE(review): truncated snippet -- these first lines close an
        # optim.Adam(...) call (g_optim) whose opening is not visible here.
        generator.proj.parameters(),
        lr=args.lr * g_reg_ratio,
        betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio),
    )
    # Discriminator optimizer; lr and betas rescaled for lazy regularization.
    d_optim = optim.Adam(
        discriminator.parameters(),
        lr=args.lr * d_reg_ratio / 2,
        betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio),
    )

    transform = transforms.Compose(
        [
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
        ]
    )

    dataset = MultiResolutionDataset(args.path, transform, args.size)
    loader = data.DataLoader(
        dataset,
        batch_size=args.batch,
        sampler=data_sampler(dataset, shuffle=True, distributed=args.distributed),
        drop_last=True,
    )

    # Only the rank-0 process logs to wandb.
    if get_rank() == 0 and wandb is not None and args.wandb:
        wandb.init(project='stylegan 2')

    train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device)
Exemplo n.º 26
0
            safe_load_state_dict(g_ema, ckpt["g_ema"])

            safe_load_state_dict(g_optim, ckpt["g_optim"])
            safe_load_state_dict(d_optim, ckpt["d_optim"])
        else:
            print(" [*] Did not find ckpt, fresh start!")
            config.var.start_iter = 0
            config.var.best_fid = 500
            config.var.mean_path_lengths = None 


        """
        Dataset
        """
        train_set = MultiResolutionDataset(
            split="train",
            config=config,
            is_training=True)
        valid_set = None
        #MultiResolutionDataset(
        #    os.path.join(dataset_root, "valid"), 
        #    is_training=False,
        #    config.train_params.full_size)
        train_set_fid = MultiResolutionDataset(
            split="train",
            config=config,
            is_training=False)

        loaders = {
            "train": make_nonstopping(data.DataLoader(
                train_set,
                batch_size=config.train_params.batch_size,