def sample(cfg, logger):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if torch.cuda.is_available():
        torch.cuda.set_device(0)  # keep the active device in sync with "cuda:0" above
    model = SoftIntroVAEModelTL(
        startf=cfg.MODEL.START_CHANNEL_COUNT,
        layer_count=cfg.MODEL.LAYER_COUNT,
        maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
        latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
        dlatent_avg_beta=cfg.MODEL.DLATENT_AVG_BETA,
        style_mixing_prob=cfg.MODEL.STYLE_MIXING_PROB,
        mapping_layers=cfg.MODEL.MAPPING_LAYERS,
        channels=cfg.MODEL.CHANNELS,
        generator=cfg.MODEL.GENERATOR,
        encoder=cfg.MODEL.ENCODER,
        beta_kl=cfg.MODEL.BETA_KL,
        beta_rec=cfg.MODEL.BETA_REC,
        beta_neg=cfg.MODEL.BETA_NEG[cfg.MODEL.LAYER_COUNT - 1],
        scale=cfg.MODEL.SCALE)

    model.to(device)
    model.eval()
    model.requires_grad_(False)

    decoder = model.decoder
    encoder = model.encoder
    mapping_tl = model.mapping_tl
    mapping_fl = model.mapping_fl

    dlatent_avg = model.dlatent_avg

    logger.info("Trainable parameters decoder:")
    print(count_parameters(decoder))

    logger.info("Trainable parameters encoder:")
    print(count_parameters(encoder))

    arguments = dict()
    arguments["iteration"] = 0

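    # Note: training checkpoints store the exponential-moving-average copies
    # under "*_s" keys (see model_dict in train() below), so for sampling we
    # load those smoothed weights directly into this model's submodules.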
    model_dict = {
        'discriminator_s': encoder,
        'generator_s': decoder,
        'mapping_tl_s': mapping_tl,
        'mapping_fl_s': mapping_fl,
        'dlatent_avg': dlatent_avg
    }

    checkpointer = Checkpointer(cfg, model_dict, {}, logger=logger, save=False)

    checkpointer.load()

    model.eval()

    path = './make_figures/output'
    os.makedirs(path, exist_ok=True)
    os.makedirs(os.path.join(path, cfg.NAME), exist_ok=True)
    with torch.no_grad():
        generate_samples(cfg, model, path, 5, device=device)
def sample(cfg, logger):
    torch.cuda.set_device(0)
    model = SoftIntroVAEModelTL(
        startf=cfg.MODEL.START_CHANNEL_COUNT,
        layer_count=cfg.MODEL.LAYER_COUNT,
        maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
        latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
        dlatent_avg_beta=cfg.MODEL.DLATENT_AVG_BETA,
        style_mixing_prob=cfg.MODEL.STYLE_MIXING_PROB,
        mapping_layers=cfg.MODEL.MAPPING_LAYERS,
        channels=cfg.MODEL.CHANNELS,
        generator=cfg.MODEL.GENERATOR,
        encoder=cfg.MODEL.ENCODER,
        beta_kl=cfg.MODEL.BETA_KL,
        beta_rec=cfg.MODEL.BETA_REC,
        beta_neg=cfg.MODEL.BETA_NEG[cfg.MODEL.LAYER_COUNT - 1],
        scale=cfg.MODEL.SCALE)
    model.cuda(0)
    model.eval()
    model.requires_grad_(False)

    decoder = model.decoder
    encoder = model.encoder
    mapping_tl = model.mapping_tl
    mapping_fl = model.mapping_fl
    dlatent_avg = model.dlatent_avg

    logger.info("Trainable parameters decoder:")
    print(count_parameters(decoder))

    logger.info("Trainable parameters encoder:")
    print(count_parameters(encoder))

    arguments = dict()
    arguments["iteration"] = 0

    model_dict = {
        'discriminator_s': encoder,
        'generator_s': decoder,
        'mapping_tl_s': mapping_tl,
        'mapping_fl_s': mapping_fl,
        'dlatent_avg': dlatent_avg
    }

    checkpointer = Checkpointer(cfg, model_dict, {}, logger=logger, save=False)

    extra_checkpoint_data = checkpointer.load()

    model.eval()

    layer_count = cfg.MODEL.LAYER_COUNT

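    # Helpers running the model at its deepest progressive-growing level:
    # lod = layer_count - 1 selects full resolution, and blend = 1 means the
    # transition from the previous level is complete.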
    def encode(x):
        z, mu, _ = model.encode(x, layer_count - 1, 1)
        styles = model.mapping_fl(mu)
        return styles

    def decode(x):
        return model.decoder(x, layer_count - 1, 1, noise=True)

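    # Each layer doubles the resolution, so the trained image side is
    # 2**(layer_count + 1); e.g. layer_count = 6 gives 128 px.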
    path = cfg.DATASET.SAMPLES_PATH
    im_size = 2**(cfg.MODEL.LAYER_COUNT + 1)

    paths = list(os.listdir(path))

    paths = sorted(paths)
    random.seed(1)
    random.shuffle(paths)

    def make(paths):
        canvas = []
        with torch.no_grad():
            for filename in paths:
                img = np.asarray(Image.open(path + '/' + filename))
                if img.shape[2] == 4:
                    img = img[:, :, :3]
                im = img.transpose((2, 0, 1))
                x = torch.tensor(np.asarray(im, dtype=np.float32)).cuda() / 127.5 - 1.
                if x.shape[0] == 4:
                    x = x[:3]
                factor = x.shape[2] // im_size
                if factor != 1:
                    x = torch.nn.functional.avg_pool2d(x[None, ...], factor,
                                                       factor)[0]
                assert x.shape[2] == im_size
                latents = encode(x[None, ...].cuda())
                f = decode(latents)
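                # lay input and reconstruction side by side (dim 3 is width)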
                r = torch.cat([x[None, ...].detach().cpu(),
                               f.detach().cpu()],
                              dim=3)
                canvas.append(r)
        return canvas

    def chunker_list(seq, n):
        return [seq[i * n:(i + 1) * n] for i in range((len(seq) + n - 1) // n)]

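    # 8 * 3 = 24 images per figure; with nrow=3 below this yields 8 grid rows
    # of 3 (input | reconstruction) pairs each.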
    paths = chunker_list(paths, 8 * 3)

    path = './make_figures/output'
    os.makedirs(path, exist_ok=True)
    os.makedirs(os.path.join(path, cfg.NAME), exist_ok=True)

    for i, chunk in enumerate(paths):
        canvas = make(chunk)
        canvas = torch.cat(canvas, dim=0)

        save_path = './make_figures/output/%s/reconstructions_%d.png' % (
            cfg.NAME, i)
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        save_image(canvas * 0.5 + 0.5, save_path, nrow=3, pad_value=1.0)
def train(cfg, logger, local_rank, world_size, distributed):
    torch.cuda.set_device(local_rank)
    model = SoftIntroVAEModelTL(
        startf=cfg.MODEL.START_CHANNEL_COUNT,
        layer_count=cfg.MODEL.LAYER_COUNT,
        maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
        latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
        dlatent_avg_beta=cfg.MODEL.DLATENT_AVG_BETA,
        style_mixing_prob=cfg.MODEL.STYLE_MIXING_PROB,
        mapping_layers=cfg.MODEL.MAPPING_LAYERS,
        channels=cfg.MODEL.CHANNELS,
        generator=cfg.MODEL.GENERATOR,
        encoder=cfg.MODEL.ENCODER,
        beta_kl=cfg.MODEL.BETA_KL,
        beta_rec=cfg.MODEL.BETA_REC,
        beta_neg=cfg.MODEL.BETA_NEG[cfg.MODEL.LAYER_COUNT - 1],
        scale=cfg.MODEL.SCALE)
    model.cuda(local_rank)
    model.train()

    if local_rank == 0:
        model_s = SoftIntroVAEModelTL(
            startf=cfg.MODEL.START_CHANNEL_COUNT,
            layer_count=cfg.MODEL.LAYER_COUNT,
            maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
            latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
            truncation_psi=cfg.MODEL.TRUNCATIOM_PSI,
            truncation_cutoff=cfg.MODEL.TRUNCATIOM_CUTOFF,
            mapping_layers=cfg.MODEL.MAPPING_LAYERS,
            channels=cfg.MODEL.CHANNELS,
            generator=cfg.MODEL.GENERATOR,
            encoder=cfg.MODEL.ENCODER,
            beta_kl=cfg.MODEL.BETA_KL,
            beta_rec=cfg.MODEL.BETA_REC,
            beta_neg=cfg.MODEL.BETA_NEG[cfg.MODEL.LAYER_COUNT - 1],
            scale=cfg.MODEL.SCALE)
        model_s.cuda(local_rank)
        model_s.eval()
        model_s.requires_grad_(False)

    if distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            broadcast_buffers=False,
            bucket_cap_mb=25,
            find_unused_parameters=True)
        model.device_ids = None

        decoder = model.module.decoder
        encoder = model.module.encoder
        mapping_tl = model.module.mapping_tl
        mapping_fl = model.module.mapping_fl
        dlatent_avg = model.module.dlatent_avg
    else:
        decoder = model.decoder
        encoder = model.encoder
        mapping_tl = model.mapping_tl
        mapping_fl = model.mapping_fl
        dlatent_avg = model.dlatent_avg

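    # Route count_parameters' output through the logger (it presumably calls
    # this .print hook internally).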
    count_parameters.print = lambda a: logger.info(a)

    num_vae_epochs = cfg.TRAIN.NUM_VAE

    logger.info("Trainable parameters decoder:")
    print(count_parameters(decoder))

    logger.info("Trainable parameters encoder:")
    print(count_parameters(encoder))

    arguments = dict()
    arguments["iteration"] = 0

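    # Two LREQAdam (learning-rate-equalized Adam) optimizers: the decoder is
    # trained jointly with the f-mapping network, the encoder jointly with
    # the t-mapping network.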
    decoder_optimizer = LREQAdam([{
        'params': decoder.parameters()
    }, {
        'params': mapping_fl.parameters()
    }],
                                 lr=cfg.TRAIN.BASE_LEARNING_RATE,
                                 betas=(cfg.TRAIN.ADAM_BETA_0,
                                        cfg.TRAIN.ADAM_BETA_1),
                                 weight_decay=0)

    encoder_optimizer = LREQAdam([{
        'params': encoder.parameters()
    }, {
        'params': mapping_tl.parameters()
    }],
                                 lr=cfg.TRAIN.BASE_LEARNING_RATE,
                                 betas=(cfg.TRAIN.ADAM_BETA_0,
                                        cfg.TRAIN.ADAM_BETA_1),
                                 weight_decay=0)

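    # Step-decay schedule shared by both optimizers; base_lr is presumably
    # rescaled from reference_batch_size=32 to the actual per-LOD batch size
    # via scheduler.set_batch_size() in the epoch loop.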
    scheduler = ComboMultiStepLR(optimizers={
        'encoder_optimizer': encoder_optimizer,
        'decoder_optimizer': decoder_optimizer
    },
                                 milestones=cfg.TRAIN.LEARNING_DECAY_STEPS,
                                 gamma=cfg.TRAIN.LEARNING_DECAY_RATE,
                                 reference_batch_size=32,
                                 base_lr=cfg.TRAIN.LEARNING_RATES)

    model_dict = {
        'discriminator': encoder,
        'generator': decoder,
        'mapping_fl': mapping_fl,
        'mapping_tl': mapping_tl,
        'dlatent_avg': dlatent_avg
    }
    if local_rank == 0:
        model_dict['discriminator_s'] = model_s.encoder
        model_dict['generator_s'] = model_s.decoder
        model_dict['mapping_fl_s'] = model_s.mapping_fl
        model_dict['mapping_tl_s'] = model_s.mapping_tl

    tracker = LossTracker(cfg.OUTPUT_DIR)

    checkpointer = Checkpointer(cfg,
                                model_dict, {
                                    'encoder_optimizer': encoder_optimizer,
                                    'decoder_optimizer': decoder_optimizer,
                                    'scheduler': scheduler,
                                    'tracker': tracker
                                },
                                logger=logger,
                                save=local_rank == 0)

    extra_checkpoint_data = checkpointer.load()
    logger.info("Starting from epoch: %d" % (scheduler.start_epoch()))

    arguments.update(extra_checkpoint_data)

    layer_to_resolution = decoder.layer_to_resolution

    dataset = TFRecordsDataset(cfg,
                               logger,
                               rank=local_rank,
                               world_size=world_size,
                               buffer_size_mb=1024,
                               channels=cfg.MODEL.CHANNELS)

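    # Fixed latents so the periodically saved sample grids stay comparable
    # across training.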
    rnd = np.random.RandomState(3456)
    latents = rnd.randn(32, cfg.MODEL.LATENT_SPACE_SIZE)
    samplez = torch.tensor(latents).float().cuda()

    lod2batch = lod_driver.LODDriver(cfg,
                                     logger,
                                     world_size,
                                     dataset_size=len(dataset) * world_size)

    if cfg.DATASET.SAMPLES_PATH:
        path = cfg.DATASET.SAMPLES_PATH
        src = []
        with torch.no_grad():
            for filename in list(os.listdir(path))[:32]:
                img = np.asarray(Image.open(os.path.join(path, filename)))
                if img.shape[2] == 4:
                    img = img[:, :, :3]
                im = img.transpose((2, 0, 1))
                x = torch.tensor(np.asarray(im, dtype=np.float32)).cuda() / 127.5 - 1.
                if x.shape[0] == 4:
                    x = x[:3]
                src.append(x)
            sample = torch.stack(src)
    else:
        dataset.reset(cfg.DATASET.MAX_RESOLUTION_LEVEL, 32)
        sample = next(make_dataloader(cfg, logger, dataset, 32, local_rank))
        sample = (sample / 127.5 - 1.)

    lod2batch.set_epoch(scheduler.start_epoch(),
                        [encoder_optimizer, decoder_optimizer])

    kls_real = []
    kls_fake = []
    rec_errs = []

    best_fid = None
    # best_fid = 20.0

    for epoch in range(scheduler.start_epoch(), cfg.TRAIN.TRAIN_EPOCHS):
        new_beta_neg = cfg.MODEL.BETA_NEG[lod2batch.lod]
        if distributed:
            if model.module.beta_neg != new_beta_neg:
                model.module.beta_neg = new_beta_neg
                print("beta negative changed to:", new_beta_neg)
        else:
            if model.beta_neg != new_beta_neg:
                model.beta_neg = new_beta_neg
                print("beta negative changed to:", new_beta_neg)
        if (epoch > cfg.TRAIN.EPOCHS_PER_LOD * (cfg.MODEL.LAYER_COUNT - 1)
                and epoch % 10 == 0 and local_rank == 0):
            print("calculating fid...")
            fid = calc_fid_from_dataset_generate(
                cfg,
                dataset,
                model_s,
                batch_size=50,
                cuda=1,
                dims=2048,
                device=torch.device(
                    "cuda:0" if torch.cuda.is_available() else "cpu"),
                num_images=50000)
            print("epoch: {}, fid: {}".format(epoch, fid))
            if best_fid is None:
                best_fid = fid
            elif fid < best_fid:
                print("best fid updated: {} -> {}".format(best_fid, fid))
                best_fid = fid
                checkpointer.save("model_tmp_lod{}_fid_{}".format(
                    lod2batch.lod, fid))
        lod2batch.set_epoch(epoch, [encoder_optimizer, decoder_optimizer])

        logger.info(
            "Batch size: %d, Batch size per GPU: %d, LOD: %d - %dx%d, blend: %.3f, dataset size: %d"
            % (lod2batch.get_batch_size(), lod2batch.get_per_GPU_batch_size(),
               lod2batch.lod, 2**lod2batch.get_lod_power2(), 2**
               lod2batch.get_lod_power2(), lod2batch.get_blend_factor(),
               len(dataset) * world_size))

        dataset.reset(lod2batch.get_lod_power2(),
                      lod2batch.get_per_GPU_batch_size())
        batches = make_dataloader(cfg, logger, dataset,
                                  lod2batch.get_per_GPU_batch_size(),
                                  local_rank)

        scheduler.set_batch_size(lod2batch.get_batch_size(), lod2batch.lod)

        model.train()

        need_permute = False
        epoch_start_time = time.time()

        i = 0

        diff_kls = []
        batch_kls_real = []
        batch_kls_fake = []
        batch_rec_errs = []

        for x_orig in tqdm(batches):
            i += 1
            with torch.no_grad():
                if x_orig.shape[0] != lod2batch.get_per_GPU_batch_size():
                    continue
                if need_permute:
                    x_orig = x_orig.permute(0, 3, 1, 2)
                x_orig = (x_orig / 127.5 - 1.)

                blend_factor = lod2batch.get_blend_factor()

                needed_resolution = layer_to_resolution[lod2batch.lod]
                x = x_orig

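                # Progressive-growing fade-in: during a LOD transition, blend
                # the full-resolution batch with a 2x-upsampled copy of its
                # half-resolution version.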
                if lod2batch.in_transition:
                    needed_resolution_prev = layer_to_resolution[lod2batch.lod
                                                                 - 1]
                    x_prev = F.avg_pool2d(x_orig, 2, 2)
                    x_prev_2x = F.interpolate(x_prev, needed_resolution)
                    x = x * blend_factor + x_prev_2x * (1.0 - blend_factor)

            x.requires_grad = True
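            # Warm-up: for the first NUM_VAE epochs train as a plain VAE
            # (one combined loss, both optimizers stepped together); after
            # that, switch to the introspective two-step updates below.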
            if epoch < num_vae_epochs:
                encoder_optimizer.zero_grad()
                decoder_optimizer.zero_grad()
                loss = model(x,
                             lod2batch.lod,
                             blend_factor,
                             d_train=False,
                             e_train=False)
                tracker.update(dict(loss_e=loss))
                tracker.update(dict(loss_d=loss))
                loss.backward()
                encoder_optimizer.step()
                decoder_optimizer.step()
            else:
                # ------------- Update Encoder ------------- #
                encoder_optimizer.zero_grad()
                loss_e = model(x,
                               lod2batch.lod,
                               blend_factor,
                               d_train=False,
                               e_train=True)
                tracker.update(dict(loss_e=loss_e))
                loss_e.backward()
                encoder_optimizer.step()

                # ------------- Update Decoder ------------- #
                decoder_optimizer.zero_grad()
                loss_d = model(x,
                               lod2batch.lod,
                               blend_factor,
                               d_train=True,
                               e_train=False)
                loss_d.backward()
                tracker.update(dict(loss_d=loss_d))
                decoder_optimizer.step()

            # ------------- Update Statistics ------------- #
            if distributed:
                tracker.update(dict(rec_loss=model.module.last_rec_loss))
                tracker.update(dict(real_kl=model.module.last_real_kl))
                tracker.update(dict(fake_kl=model.module.last_fake_kl))
                tracker.update(dict(kl_diff=model.module.last_kl_diff))
                tracker.update(dict(expelbo_f=model.module.last_expelbo_fake))
                tracker.update(dict(expelbo_r=model.module.last_expelbo_rec))

                diff_kls.append(model.module.last_kl_diff.data.cpu())
                batch_kls_real.append(model.module.last_real_kl)
                batch_kls_fake.append(model.module.last_fake_kl)
                batch_rec_errs.append(model.module.last_rec_loss)
            else:
                tracker.update(dict(rec_loss=model.last_rec_loss))
                tracker.update(dict(real_kl=model.last_real_kl))
                tracker.update(dict(fake_kl=model.last_fake_kl))
                tracker.update(dict(kl_diff=model.last_kl_diff))
                tracker.update(dict(expelbo_f=model.last_expelbo_fake))
                tracker.update(dict(expelbo_r=model.last_expelbo_rec))

                diff_kls.append(model.last_kl_diff.data.cpu())
                batch_kls_real.append(model.last_real_kl)
                batch_kls_fake.append(model.last_fake_kl)
                batch_rec_errs.append(model.last_rec_loss)

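            # Exponential moving average of the weights into model_s:
            # beta = 0.5**(batch / 10000) gives a half-life of 10k images.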
            if local_rank == 0:
                beta = 0.5**(lod2batch.get_batch_size() / (10 * 1000.0))
                model_s.lerp(model, beta)

            epoch_end_time = time.time()
            per_epoch_ptime = epoch_end_time - epoch_start_time

            lod_for_saving_model = lod2batch.lod
            lod2batch.step()
            if local_rank == 0:
                if lod2batch.is_time_to_save():
                    checkpointer.save("model_tmp_intermediate_lod%d" %
                                      lod_for_saving_model)
                if lod2batch.is_time_to_report():
                    save_sample(lod2batch, tracker, sample, samplez, x, logger,
                                model_s, cfg, encoder_optimizer,
                                decoder_optimizer)

        scheduler.step()
        mean_diff_kl = np.mean(diff_kls)
        print("mean diff kl: ", mean_diff_kl)

        if epoch > num_vae_epochs - 1:
            kls_real.append(np.mean(batch_kls_real))
            kls_fake.append(np.mean(batch_kls_fake))
            rec_errs.append(np.mean(batch_rec_errs))

        if local_rank == 0:
            checkpointer.save("model_tmp_lod%d" % lod_for_saving_model)
            save_sample(lod2batch, tracker, sample, samplez, x, logger,
                        model_s, cfg, encoder_optimizer, decoder_optimizer)

    logger.info("Training finish!... save training results")
    if local_rank == 0:
        checkpointer.save("model_final").wait()


# Example #4
def sample(cfg, logger):
    torch.cuda.set_device(0)
    model = SoftIntroVAEModelTL(
        startf=cfg.MODEL.START_CHANNEL_COUNT,
        layer_count=cfg.MODEL.LAYER_COUNT,
        maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
        latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
        dlatent_avg_beta=cfg.MODEL.DLATENT_AVG_BETA,
        style_mixing_prob=cfg.MODEL.STYLE_MIXING_PROB,
        mapping_layers=cfg.MODEL.MAPPING_LAYERS,
        channels=cfg.MODEL.CHANNELS,
        generator=cfg.MODEL.GENERATOR,
        encoder=cfg.MODEL.ENCODER,
        beta_kl=cfg.MODEL.BETA_KL,
        beta_rec=cfg.MODEL.BETA_REC,
        beta_neg=cfg.MODEL.BETA_NEG[cfg.MODEL.LAYER_COUNT - 1],
        scale=cfg.MODEL.SCALE
    )
    model.cuda(0)
    model.eval()
    model.requires_grad_(False)

    decoder = model.decoder
    encoder = model.encoder
    mapping_tl = model.mapping_tl
    mapping_fl = model.mapping_fl
    dlatent_avg = model.dlatent_avg

    logger.info("Trainable parameters decoder:")
    print(count_parameters(decoder))

    logger.info("Trainable parameters encoder:")
    print(count_parameters(encoder))

    arguments = dict()
    arguments["iteration"] = 0

    model_dict = {
        'discriminator_s': encoder,
        'generator_s': decoder,
        'mapping_tl_s': mapping_tl,
        'mapping_fl_s': mapping_fl,
        'dlatent_avg': dlatent_avg
    }

    checkpointer = Checkpointer(cfg,
                                model_dict,
                                {},
                                logger=logger,
                                save=False)

    extra_checkpoint_data = checkpointer.load()

    model.eval()

    layer_count = cfg.MODEL.LAYER_COUNT

    def encode(x):
        z, mu, _ = model.encode(x, layer_count - 1, 1)
        styles = model.mapping_fl(mu)
        return styles

    def decode(x):
        return model.decoder(x, layer_count - 1, 1, noise=True)

    rnd = np.random.RandomState(5)

    dataset = TFRecordsDataset(cfg, logger, rank=0, world_size=1, buffer_size_mb=10, channels=cfg.MODEL.CHANNELS,
                               train=False)

    dataset.reset(cfg.DATASET.MAX_RESOLUTION_LEVEL, 10)
    b = iter(make_dataloader(cfg, logger, dataset, 10, 0, numpy=True))

    def make(sample):
        canvas = []
        with torch.no_grad():
            for img in sample:
                x = torch.tensor(np.asarray(img, dtype=np.float32)).cuda() / 127.5 - 1.
                if x.shape[0] == 4:
                    x = x[:3]
                latents = encode(x[None, ...].cuda())
                f = decode(latents)
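                # input | reconstruction pairs, concatenated along the width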
                r = torch.cat([x[None, ...].detach().cpu(), f.detach().cpu()], dim=3)
                canvas.append(r)
        return canvas

    sample = next(b)
    canvas = make(sample)
    canvas = torch.cat(canvas, dim=0)

    os.makedirs('./make_figures', exist_ok=True)
    save_image(canvas * 0.5 + 0.5, './make_figures/reconstructions_ffhq_real_1.png', nrow=2, pad_value=1.0)

    sample = next(b)
    canvas = make(sample)
    canvas = torch.cat(canvas, dim=0)

    save_image(canvas * 0.5 + 0.5, './make_figures/reconstructions_ffhq_real_2.png', nrow=2, pad_value=1.0)


# Example #5
def sample(cfg, logger):
    torch.cuda.set_device(0)
    model = SoftIntroVAEModelTL(
        startf=cfg.MODEL.START_CHANNEL_COUNT,
        layer_count=cfg.MODEL.LAYER_COUNT,
        maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
        latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
        dlatent_avg_beta=cfg.MODEL.DLATENT_AVG_BETA,
        style_mixing_prob=cfg.MODEL.STYLE_MIXING_PROB,
        mapping_layers=cfg.MODEL.MAPPING_LAYERS,
        channels=cfg.MODEL.CHANNELS,
        generator=cfg.MODEL.GENERATOR,
        encoder=cfg.MODEL.ENCODER,
        beta_kl=cfg.MODEL.BETA_KL,
        beta_rec=cfg.MODEL.BETA_REC,
        beta_neg=cfg.MODEL.BETA_NEG[cfg.MODEL.LAYER_COUNT - 1],
        scale=cfg.MODEL.SCALE)

    model.cuda(0)
    model.eval()
    model.requires_grad_(False)

    decoder = model.decoder
    encoder = model.encoder
    mapping_tl = model.mapping_tl
    mapping_fl = model.mapping_fl

    dlatent_avg = model.dlatent_avg

    logger.info("Trainable parameters decoder:")
    print(count_parameters(decoder))

    logger.info("Trainable parameters encoder:")
    print(count_parameters(encoder))

    arguments = dict()
    arguments["iteration"] = 0

    model_dict = {
        'discriminator_s': encoder,
        'generator_s': decoder,
        'mapping_tl_s': mapping_tl,
        'mapping_fl_s': mapping_fl,
        'dlatent_avg': dlatent_avg
    }

    checkpointer = Checkpointer(cfg, model_dict, {}, logger=logger, save=False)

    checkpointer.load()

    model.eval()

    im_size = 2**(cfg.MODEL.LAYER_COUNT + 1)
    seed = np.random.randint(0, 999999)
    print("seed:", seed)
    with torch.no_grad():
        path = './make_figures/output'
        os.makedirs(path, exist_ok=True)
        os.makedirs(os.path.join(path, cfg.NAME), exist_ok=True)
        draw_uncurated_result_figure(
            cfg,
            './make_figures/output/%s/generations.jpg' % cfg.NAME,
            model,
            cx=0,
            cy=0,
            cw=im_size,
            ch=im_size,
            rows=6,
            lods=[0, 0, 0, 1, 1, 2],
            seed=seed)


# Example #6
def sample(cfg, logger):
    torch.cuda.set_device(0)
    model = SoftIntroVAEModelTL(
        startf=cfg.MODEL.START_CHANNEL_COUNT,
        layer_count=cfg.MODEL.LAYER_COUNT,
        maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
        latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
        dlatent_avg_beta=cfg.MODEL.DLATENT_AVG_BETA,
        style_mixing_prob=cfg.MODEL.STYLE_MIXING_PROB,
        mapping_layers=cfg.MODEL.MAPPING_LAYERS,
        channels=cfg.MODEL.CHANNELS,
        generator=cfg.MODEL.GENERATOR,
        encoder=cfg.MODEL.ENCODER,
        beta_kl=cfg.MODEL.BETA_KL,
        beta_rec=cfg.MODEL.BETA_REC,
        beta_neg=cfg.MODEL.BETA_NEG[cfg.MODEL.LAYER_COUNT - 1],
        scale=cfg.MODEL.SCALE)
    model.cuda(0)
    model.eval()
    model.requires_grad_(False)

    decoder = model.decoder
    encoder = model.encoder
    mapping_tl = model.mapping_tl
    mapping_fl = model.mapping_fl
    dlatent_avg = model.dlatent_avg

    logger.info("Trainable parameters decoder:")
    print(count_parameters(decoder))

    logger.info("Trainable parameters encoder:")
    print(count_parameters(encoder))

    arguments = dict()
    arguments["iteration"] = 0

    model_dict = {
        'discriminator_s': encoder,
        'generator_s': decoder,
        'mapping_tl_s': mapping_tl,
        'mapping_fl_s': mapping_fl,
        'dlatent_avg': dlatent_avg
    }

    checkpointer = Checkpointer(cfg, model_dict, {}, logger=logger, save=False)

    extra_checkpoint_data = checkpointer.load()

    model.eval()

    layer_count = cfg.MODEL.LAYER_COUNT

    def encode(x):
        z, mu, _ = model.encode(x, layer_count - 1, 1)
        styles = model.mapping_fl(mu)
        return styles

    def decode(x):
        return model.decoder(x, layer_count - 1, 1, noise=True)

    rnd = np.random.RandomState(4)

    path = cfg.DATASET.SAMPLES_PATH
    im_size = 2**(cfg.MODEL.LAYER_COUNT + 1)

    pathA = '17460.jpg'
    pathB = '02973.jpg'

    def open_image(filename):
        img = np.asarray(Image.open(path + '/' + filename))
        if img.shape[2] == 4:
            img = img[:, :, :3]
        im = img.transpose((2, 0, 1))
        x = torch.tensor(np.asarray(im, dtype=np.float32)).cuda() / 127.5 - 1.
        if x.shape[0] == 4:
            x = x[:3]
        factor = x.shape[2] // im_size
        if factor != 1:
            x = torch.nn.functional.avg_pool2d(x[None, ...], factor, factor)[0]
        assert x.shape[2] == im_size
        _latents = encode(x[None, ...].cuda())
        latents = _latents[0, 0]
        return latents

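    # Broadcast a single w vector to every style input of the decoder.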
    def make(w):
        with torch.no_grad():
            w = w[None, None, ...].repeat(1, model.mapping_fl.num_layers, 1)
            x_rec = decode(w)
            return x_rec

    wa = open_image(pathA)
    wb = open_image(pathB)

    width = 7

    images = []

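    # Linear interpolation in w (style) space, in 7 evenly spaced steps
    # from wa to wb.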
    for j in range(width):
        kh = j / (width - 1.0)

        ka = (1.0 - kh)
        kb = kh
        w = ka * wa + kb * wb

        interpolated = make(w)
        images.append(interpolated)

    images = torch.cat(images)

    path = './make_figures/output'
    os.makedirs(path, exist_ok=True)
    os.makedirs(os.path.join(path, cfg.NAME), exist_ok=True)
    save_image(images * 0.5 + 0.5,
               './make_figures/output/%s/interpolations.png' % cfg.NAME,
               nrow=width)
    save_image(images * 0.5 + 0.5,
               './make_figures/output/%s/interpolations.jpg' % cfg.NAME,
               nrow=width)


# Example #7
def sample(cfg, logger):
    torch.cuda.set_device(0)
    model = SoftIntroVAEModelTL(
        startf=cfg.MODEL.START_CHANNEL_COUNT,
        layer_count=cfg.MODEL.LAYER_COUNT,
        maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
        latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
        dlatent_avg_beta=cfg.MODEL.DLATENT_AVG_BETA,
        style_mixing_prob=cfg.MODEL.STYLE_MIXING_PROB,
        mapping_layers=cfg.MODEL.MAPPING_LAYERS,
        channels=cfg.MODEL.CHANNELS,
        generator=cfg.MODEL.GENERATOR,
        encoder=cfg.MODEL.ENCODER,
        beta_kl=cfg.MODEL.BETA_KL,
        beta_rec=cfg.MODEL.BETA_REC,
        beta_neg=cfg.MODEL.BETA_NEG[cfg.MODEL.LAYER_COUNT - 1],
        scale=cfg.MODEL.SCALE)
    model.cuda(0)
    model.eval()
    model.requires_grad_(False)

    decoder = model.decoder
    encoder = model.encoder
    mapping_tl = model.mapping_tl
    mapping_fl = model.mapping_fl
    dlatent_avg = model.dlatent_avg

    logger.info("Trainable parameters decoder:")
    print(count_parameters(decoder))

    logger.info("Trainable parameters encoder:")
    print(count_parameters(encoder))

    arguments = dict()
    arguments["iteration"] = 0

    model_dict = {
        'discriminator_s': encoder,
        'generator_s': decoder,
        'mapping_tl_s': mapping_tl,
        'mapping_fl_s': mapping_fl,
        'dlatent_avg': dlatent_avg
    }

    checkpointer = Checkpointer(cfg, model_dict, {}, logger=logger, save=False)

    extra_checkpoint_data = checkpointer.load()

    model.eval()

    layer_count = cfg.MODEL.LAYER_COUNT

    def encode(x):
        z, mu, _ = model.encode(x, layer_count - 1, 1)
        styles = model.mapping_fl(mu)
        return styles

    def decode(x):
        return model.decoder(x, layer_count - 1, 1, noise=True)

    path = cfg.DATASET.SAMPLES_PATH

    im_size = 2**(cfg.MODEL.LAYER_COUNT + 1)

    paths = list(os.listdir(path))

    paths = sorted(paths)
    random.seed(5)
    random.shuffle(paths)

    def move_to(items, item, new_index):
        items.remove(item)
        items.insert(new_index, item)

    def make(paths):
        src = []
        for filename in paths:
            img = np.asarray(Image.open(path + '/' + filename))
            if img.shape[2] == 4:
                img = img[:, :, :3]
            im = img.transpose((2, 0, 1))
            x = torch.tensor(np.asarray(im, dtype=np.float32)).cuda() / 127.5 - 1.
            if x.shape[0] == 4:
                x = x[:3]
            factor = x.shape[2] // im_size
            if factor != 1:
                x = torch.nn.functional.avg_pool2d(x[None, ...], factor,
                                                   factor)[0]
            assert x.shape[2] == im_size
            src.append(x)

        with torch.no_grad():
            reconstructions = []
            for s in src:
                latents = encode(s[None, ...])
                reconstructions.append(decode(latents).cpu().detach().numpy())
        return src, reconstructions

    def chunker_list(seq, size):
        return list((seq[i::size] for i in range(size)))

    final = chunker_list(paths, 4)
    path0, path1, path2, path3 = final

    path0.reverse()
    path1.reverse()
    path2.reverse()
    path3.reverse()

    src0, rec0 = make(path0)
    src1, rec1 = make(path1)
    src2, rec2 = make(path2)
    src3, rec3 = make(path3)

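    # Multi-resolution mosaic: each part stacks (input | reconstruction)
    # pairs, halving the resolution at every one of the lods_down levels and
    # shrinking the padding accordingly.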
    initial_resolution = im_size

    lods_down = 1
    padding_step = 4

    width = 0
    height = 0

    current_padding = 0

    final_resolution = initial_resolution
    for _ in range(lods_down):
        final_resolution /= 2

    for i in range(lods_down + 1):
        width += current_padding * 2**(lods_down - i)
        height += current_padding * 2**(lods_down - i)
        current_padding += padding_step

    width += 2**(lods_down + 1) * final_resolution
    height += (lods_down + 1) * initial_resolution

    width = int(width)
    height = int(height)

    def make_part(current_padding, src, rec):
        canvas = np.ones([3, height + 20, width + 10])

        padd = 0

        initial_padding = current_padding

        height_padding = 0

        for i in range(lods_down + 1):
            for x in range(2**i):
                for y in range(2**i):
                    try:
                        ims = src.pop()
                        imr = rec.pop()[0]
                        ims = ims.cpu().detach().numpy()

                        res = int(initial_resolution / 2**i)

                        ims = resize(ims, (3, res, res))
                        imr = resize(imr, (3, res, res))

                        place(
                            canvas, ims,
                            current_padding + x * (2 * res + current_padding),
                            i * initial_resolution + height_padding + y *
                            (res + current_padding))

                        place(
                            canvas, imr, current_padding + res + x *
                            (2 * res + current_padding),
                            i * initial_resolution + height_padding + y *
                            (res + current_padding))

                    except IndexError:
                        return canvas

            height_padding += initial_padding * 2

            current_padding -= padding_step
            padd += padding_step
        return canvas

    canvas = [
        make_part(current_padding, src0, rec0),
        make_part(current_padding, src1, rec1),
        make_part(current_padding, src2, rec2),
        make_part(current_padding, src3, rec3)
    ]

    canvas = np.concatenate(canvas, axis=2)

    path = './make_figures/output'
    os.makedirs(path, exist_ok=True)
    os.makedirs(os.path.join(path, cfg.NAME), exist_ok=True)
    print('Saving image')
    save_path = './make_figures/output/%s/reconstructions_multiresolution.png' % cfg.NAME
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    save_image(torch.Tensor(canvas), save_path)
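

# ---------------------------------------------------------------------------
# Hedged usage sketch: one way the sample() entry points above might be
# driven. `get_cfg_defaults` and the config path are hypothetical stand-ins
# for the project's actual config plumbing (a yacs-style CfgNode is assumed);
# the logging setup is standard library.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import logging

    from defaults import get_cfg_defaults  # hypothetical config helper

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("make_figures")

    cfg = get_cfg_defaults()                  # hypothetical: project defaults
    cfg.merge_from_file("configs/ffhq.yaml")  # hypothetical config path
    cfg.freeze()
    sample(cfg, logger)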