Esempio n. 1
0
def main(cfg, logger):
    """Restore a trained model and run ``Predictions.evaluate`` per index.

    The model is built from ``cfg.MODEL`` on CUDA device 0, switched to
    inference mode with gradients disabled, and its sub-modules are filled
    from the checkpoint through ``Checkpointer``. ``Predictions.evaluate`` is
    then called once for every entry of a fixed index list.

    Args:
        cfg: Project configuration; the ``MODEL`` and ``DATASET`` sections
            are read here and the whole object is forwarded to helpers.
        logger: Logger used for progress messages and forwarded to helpers.
    """
    torch.cuda.set_device(0)
    model = Model(startf=cfg.MODEL.START_CHANNEL_COUNT,
                  layer_count=cfg.MODEL.LAYER_COUNT,
                  maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
                  latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
                  truncation_psi=cfg.MODEL.TRUNCATIOM_PSI,
                  truncation_cutoff=cfg.MODEL.TRUNCATIOM_CUTOFF,
                  mapping_layers=cfg.MODEL.MAPPING_LAYERS,
                  channels=cfg.MODEL.CHANNELS,
                  generator=cfg.MODEL.GENERATOR,
                  encoder=cfg.MODEL.ENCODER)

    model.cuda(0)
    model.eval()
    model.requires_grad_(False)

    decoder = model.decoder
    encoder = model.encoder
    mapping_fl = model.mapping_fl

    logger.info("Trainable parameters generator:")
    count_parameters(decoder)

    logger.info("Trainable parameters discriminator:")
    count_parameters(encoder)

    # The checkpoint entries are looked up by these keys; the un-wrapped
    # decoder is registered here, before it is wrapped in DataParallel below.
    model_dict = {
        'discriminator_s': encoder,
        'generator_s': decoder,
        'mapping_tl_s': model.mapping_tl,
        'mapping_fl_s': mapping_fl,
        'dlatent_avg': model.dlatent_avg
    }

    checkpointer = Checkpointer(cfg, model_dict, {}, logger=logger, save=False)
    checkpointer.load()

    model.eval()

    logger.info("Extracting attributes")

    decoder = nn.DataParallel(decoder)

    indices = [0, 1, 2, 3, 4, 10, 11, 17, 19]
    # Loop-invariant: every index is evaluated at the same level.
    lod = cfg.DATASET.MAX_RESOLUTION_LEVEL - 2
    with torch.no_grad():
        predictions = Predictions(cfg, minibatch_gpu=4)
        for index in indices:
            predictions.evaluate(logger, mapping_fl, decoder, lod, index)
Esempio n. 2
0
def train(cfg, logger, local_rank, world_size, distributed):
    """Train the model with separate discriminator, decoder and encoder steps.

    Builds the model on CUDA device ``local_rank``, creates one ``LREQAdam``
    optimizer each for the decoder (plus ``mapping_fl``), the encoder and the
    discriminator, restores state through ``Checkpointer`` and then iterates
    over epochs. Every batch runs three forward/backward passes of ``model``
    (``d_train=True``, ``d_train=False``, and ``ae=True``). On rank 0 a second
    model ``model_s`` is updated through ``lerp`` after every batch and is the
    one used for sample images; rank 0 is also the only rank that saves
    checkpoints.

    Args:
        cfg: Project configuration (``MODEL``, ``TRAIN``, ``DATASET`` and
            ``OUTPUT_DIR`` are read here).
        logger: Logger used for progress messages and forwarded to helpers.
        local_rank: CUDA device index of this process; rank 0 owns
            ``model_s``, checkpoint saving and sample reporting.
        world_size: Number of participating processes; forwarded to the
            dataset and the LOD driver.
        distributed: When true, ``model`` is wrapped in
            ``DistributedDataParallel``.
    """
    torch.cuda.set_device(local_rank)
    model = Model(startf=cfg.MODEL.START_CHANNEL_COUNT,
                  layer_count=cfg.MODEL.LAYER_COUNT,
                  maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
                  latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
                  dlatent_avg_beta=cfg.MODEL.DLATENT_AVG_BETA,
                  style_mixing_prob=cfg.MODEL.STYLE_MIXING_PROB,
                  mapping_layers=cfg.MODEL.MAPPING_LAYERS,
                  channels=cfg.MODEL.CHANNELS,
                  generator=cfg.MODEL.GENERATOR,
                  encoder=cfg.MODEL.ENCODER)
    model.cuda(local_rank)
    model.train()

    # Rank 0 keeps a second, frozen copy of the model. It is never trained
    # directly: it is only updated through model_s.lerp(...) in the batch loop
    # and provides the "*_s" checkpoint entries and the sample images.
    if local_rank == 0:
        model_s = Model(startf=cfg.MODEL.START_CHANNEL_COUNT,
                        layer_count=cfg.MODEL.LAYER_COUNT,
                        maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
                        latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
                        truncation_psi=cfg.MODEL.TRUNCATIOM_PSI,
                        truncation_cutoff=cfg.MODEL.TRUNCATIOM_CUTOFF,
                        mapping_layers=cfg.MODEL.MAPPING_LAYERS,
                        channels=cfg.MODEL.CHANNELS,
                        generator=cfg.MODEL.GENERATOR,
                        encoder=cfg.MODEL.ENCODER)
        model_s.cuda(local_rank)
        model_s.eval()
        model_s.requires_grad_(False)

    # When wrapped for distributed training the sub-modules are reached
    # through ``model.module``; otherwise they are attributes of ``model``.
    if distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            broadcast_buffers=False,
            bucket_cap_mb=25,
            find_unused_parameters=True)
        # NOTE(review): the reason for clearing device_ids after wrapping is
        # not evident from this function - confirm before changing.
        model.device_ids = None

        decoder = model.module.decoder
        encoder = model.module.encoder
        discriminator = model.module.discriminator
        mapping_fl = model.module.mapping_fl
        dlatent_avg = model.module.dlatent_avg
    else:
        decoder = model.decoder
        encoder = model.encoder
        discriminator = model.discriminator
        mapping_fl = model.mapping_fl
        dlatent_avg = model.dlatent_avg

    # Presumably makes count_parameters report through the logger instead of
    # printing - verify against count_param_override.
    count_param_override.print = lambda a: logger.info(a)

    logger.info("Trainable parameters generator:")
    count_parameters(decoder)

    logger.info("Trainable parameters discriminator:")
    count_parameters(encoder)

    arguments = dict()
    arguments["iteration"] = 0

    # The decoder optimizer also owns the mapping_fl parameters.
    decoder_optimizer = LREQAdam([{
        'params': decoder.parameters()
    }, {
        'params': mapping_fl.parameters()
    }],
                                 lr=cfg.TRAIN.BASE_LEARNING_RATE,
                                 betas=(cfg.TRAIN.ADAM_BETA_0,
                                        cfg.TRAIN.ADAM_BETA_1),
                                 weight_decay=0)

    encoder_optimizer = LREQAdam([
        {
            'params': encoder.parameters()
        },
    ],
                                 lr=cfg.TRAIN.BASE_LEARNING_RATE,
                                 betas=(cfg.TRAIN.ADAM_BETA_0,
                                        cfg.TRAIN.ADAM_BETA_1),
                                 weight_decay=0)

    discriminator_optimizer = LREQAdam([
        {
            'params': discriminator.parameters()
        },
    ],
                                       lr=cfg.TRAIN.BASE_LEARNING_RATE,
                                       betas=(cfg.TRAIN.ADAM_BETA_0,
                                              cfg.TRAIN.ADAM_BETA_1),
                                       weight_decay=0)

    # One scheduler drives the learning rate of all three optimizers.
    scheduler = ComboMultiStepLR(optimizers={
        'encoder_optimizer': encoder_optimizer,
        'discriminator_optimizer': discriminator_optimizer,
        'decoder_optimizer': decoder_optimizer
    },
                                 milestones=cfg.TRAIN.LEARNING_DECAY_STEPS,
                                 gamma=cfg.TRAIN.LEARNING_DECAY_RATE,
                                 reference_batch_size=32,
                                 base_lr=cfg.TRAIN.LEARNING_RATES)

    # Modules handed to the checkpointer; the "*_s" entries exist only on
    # rank 0 because model_s is only created there.
    model_dict = {
        'discriminator': discriminator,
        'encoder': encoder,
        'generator': decoder,
        'mapping_fl': mapping_fl,
        'dlatent_avg': dlatent_avg
    }
    if local_rank == 0:
        model_dict['discriminator_s'] = model_s.discriminator
        model_dict['encoder_s'] = model_s.encoder
        model_dict['generator_s'] = model_s.decoder
        model_dict['mapping_fl_s'] = model_s.mapping_fl

    tracker = LossTracker(cfg.OUTPUT_DIR)

    # Optimizers, scheduler and loss tracker are checkpointed alongside the
    # modules; only rank 0 is allowed to write.
    checkpointer = Checkpointer(
        cfg,
        model_dict, {
            'encoder_optimizer': encoder_optimizer,
            'discriminator_optimizer': discriminator_optimizer,
            'decoder_optimizer': decoder_optimizer,
            'scheduler': scheduler,
            'tracker': tracker
        },
        logger=logger,
        save=local_rank == 0)

    extra_checkpoint_data = checkpointer.load()
    logger.info("Starting from epoch: %d" % (scheduler.start_epoch()))

    # NOTE(review): ``arguments`` is not read again in this function.
    arguments.update(extra_checkpoint_data)

    layer_to_resolution = decoder.layer_to_resolution

    dataset = TFRecordsDataset(cfg,
                               logger,
                               rank=local_rank,
                               world_size=world_size,
                               buffer_size_mb=1024,
                               channels=cfg.MODEL.CHANNELS)

    # Fixed-seed batch of 32 latent vectors handed to save_sample, so the
    # same latents are used for every report.
    rnd = np.random.RandomState(3456)
    latents = rnd.randn(32, cfg.MODEL.LATENT_SPACE_SIZE)
    samplez = torch.tensor(latents).float().cuda()

    lod2batch = lod_driver.LODDriver(cfg,
                                     logger,
                                     world_size,
                                     dataset_size=len(dataset) * world_size)

    # Reference images handed to save_sample: up to 32 files from
    # SAMPLES_PATH when configured, otherwise one batch of 32 drawn from the
    # dataset at the maximum resolution level. Both are rescaled with
    # ``/ 127.5 - 1`` (maps to [-1, 1] assuming 8-bit pixel values).
    if cfg.DATASET.SAMPLES_PATH:
        path = cfg.DATASET.SAMPLES_PATH
        src = []
        with torch.no_grad():
            # os.listdir returns entries in arbitrary order, so which 32
            # files are picked is not guaranteed to be stable.
            for filename in list(os.listdir(path))[:32]:
                img = np.asarray(Image.open(os.path.join(path, filename)))
                # Drop the alpha channel of 4-channel images.
                if img.shape[2] == 4:
                    img = img[:, :, :3]
                # Height-width-channel -> channel-height-width.
                im = img.transpose((2, 0, 1))
                x = torch.tensor(np.asarray(im, dtype=np.float32),
                                 requires_grad=True).cuda() / 127.5 - 1.
                # Redundant with the alpha strip above.
                if x.shape[0] == 4:
                    x = x[:3]
                src.append(x)
            sample = torch.stack(src)
    else:
        dataset.reset(cfg.DATASET.MAX_RESOLUTION_LEVEL, 32)
        sample = next(make_dataloader(cfg, logger, dataset, 32, local_rank))
        sample = (sample / 127.5 - 1.)

    # NOTE(review): only the encoder and decoder optimizers are passed to the
    # LOD driver here and below; the discriminator optimizer is not - confirm
    # this is intended.
    lod2batch.set_epoch(scheduler.start_epoch(),
                        [encoder_optimizer, decoder_optimizer])

    for epoch in range(scheduler.start_epoch(), cfg.TRAIN.TRAIN_EPOCHS):
        model.train()
        lod2batch.set_epoch(epoch, [encoder_optimizer, decoder_optimizer])

        logger.info(
            "Batch size: %d, Batch size per GPU: %d, LOD: %d - %dx%d, blend: %.3f, dataset size: %d"
            % (lod2batch.get_batch_size(), lod2batch.get_per_GPU_batch_size(),
               lod2batch.lod, 2**lod2batch.get_lod_power2(), 2**
               lod2batch.get_lod_power2(), lod2batch.get_blend_factor(),
               len(dataset) * world_size))

        # Re-create the data pipeline for the resolution and per-GPU batch
        # size the LOD driver selected for this epoch.
        dataset.reset(lod2batch.get_lod_power2(),
                      lod2batch.get_per_GPU_batch_size())
        batches = make_dataloader(cfg, logger, dataset,
                                  lod2batch.get_per_GPU_batch_size(),
                                  local_rank)

        scheduler.set_batch_size(lod2batch.get_batch_size(), lod2batch.lod)

        model.train()

        # Always False, so the permute branch below is never taken.
        need_permute = False
        epoch_start_time = time.time()

        # ``i`` counts fetched batches (including skipped ones); it is not
        # read anywhere else in this function.
        i = 0
        # The profiler is constructed with enabled=False and ``prof`` is
        # never read, so this context manager is effectively a no-op.
        with torch.autograd.profiler.profile(use_cuda=True,
                                             enabled=False) as prof:
            for x_orig in tqdm(batches):
                i += 1
                with torch.no_grad():
                    # Skip incomplete batches (e.g. the last one of an epoch).
                    if x_orig.shape[0] != lod2batch.get_per_GPU_batch_size():
                        continue
                    if need_permute:
                        x_orig = x_orig.permute(0, 3, 1, 2)
                    x_orig = (x_orig / 127.5 - 1.)

                    blend_factor = lod2batch.get_blend_factor()

                    needed_resolution = layer_to_resolution[lod2batch.lod]
                    x = x_orig

                    # During a LOD transition the input is a blend of the
                    # full-resolution batch and a copy that was downsampled
                    # by 2 and resized back to the current resolution.
                    if lod2batch.in_transition:
                        # needed_resolution_prev is computed but not used.
                        needed_resolution_prev = layer_to_resolution[
                            lod2batch.lod - 1]
                        x_prev = F.avg_pool2d(x_orig, 2, 2)
                        x_prev_2x = F.interpolate(x_prev, needed_resolution)
                        x = x * blend_factor + x_prev_2x * (1.0 - blend_factor)

                x.requires_grad = True

                # Phase 1: discriminator step.
                # NOTE(review): the encoder optimizer's gradients are only
                # zeroed after phase 3, so any gradient that phases 1 and 2
                # leave on encoder parameters is still present when
                # encoder_optimizer.step() runs - confirm this is intended.
                loss_d = model(x,
                               lod2batch.lod,
                               blend_factor,
                               d_train=True,
                               ae=False)
                tracker.update(dict(loss_d=loss_d))
                loss_d.backward()
                discriminator_optimizer.step()
                decoder_optimizer.zero_grad()
                discriminator_optimizer.zero_grad()

                # Phase 2: decoder (generator) step.
                loss_g = model(x,
                               lod2batch.lod,
                               blend_factor,
                               d_train=False,
                               ae=False)
                tracker.update(dict(loss_g=loss_g))
                loss_g.backward()
                decoder_optimizer.step()
                decoder_optimizer.zero_grad()
                discriminator_optimizer.zero_grad()

                # Phase 3: ``ae=True`` pass; both encoder and decoder step.
                # NOTE(review): the discriminator optimizer is not zeroed
                # after this backward pass, so any gradient it leaves on
                # discriminator parameters carries into the next phase 1.
                lae = model(x,
                            lod2batch.lod,
                            blend_factor,
                            d_train=True,
                            ae=True)
                tracker.update(dict(lae=lae))
                (lae).backward()
                encoder_optimizer.step()
                decoder_optimizer.step()
                encoder_optimizer.zero_grad()
                decoder_optimizer.zero_grad()

                # Update the frozen copy from the trained model; the weight
                # is 0.5 ** (total batch size / 10000).
                if local_rank == 0:
                    betta = 0.5**(lod2batch.get_batch_size() / (10 * 1000.0))
                    model_s.lerp(model, betta)

                # Elapsed time since the epoch started; computed every batch
                # but not used afterwards.
                epoch_end_time = time.time()
                per_epoch_ptime = epoch_end_time - epoch_start_time

                # Remember the LOD before step() so intermediate checkpoints
                # are named after the level this batch was trained at.
                lod_for_saving_model = lod2batch.lod
                lod2batch.step()
                if local_rank == 0:
                    if lod2batch.is_time_to_save():
                        checkpointer.save("model_tmp_intermediate_lod%d" %
                                          lod_for_saving_model)
                    if lod2batch.is_time_to_report():
                        save_sample(lod2batch, tracker, sample, samplez, x,
                                    logger, model_s, cfg, encoder_optimizer,
                                    decoder_optimizer)

        scheduler.step()

        # NOTE(review): ``lod_for_saving_model`` and ``x`` are only bound
        # inside the batch loop. If the first epoch processes no full batch
        # this raises NameError; in later epochs stale values are reused.
        if local_rank == 0:
            checkpointer.save("model_tmp_lod%d" % lod_for_saving_model)
            save_sample(lod2batch, tracker, sample, samplez, x, logger,
                        model_s, cfg, encoder_optimizer, decoder_optimizer)

    logger.info("Training finish!... save training results")
    if local_rank == 0:
        checkpointer.save("model_final").wait()
Esempio n. 3
0
def _main(cfg, logger):
    """Render a style-mixing grid from source and destination images.

    Restores the model from a checkpoint, loads ``src_len`` source and
    ``dst_len`` destination images from ``cfg.DATASET.STYLE_MIX_PATH``,
    encodes them, and for every destination row decodes latents in which a
    row-specific range of layers is replaced by the source latents. All
    originals and mixes are written as individual PNG files and assembled
    into ``stylemix.png`` under ``style_mixing/output/<cfg.NAME>/``.

    ``src_len``, ``dst_len`` and ``place`` are not defined in this function
    and must be available at module level.

    Args:
        cfg: Project configuration (``MODEL``, ``DATASET`` and ``NAME`` are
            read here).
        logger: Logger used for progress messages and forwarded to helpers.
    """
    torch.cuda.set_device(0)
    model = Model(startf=cfg.MODEL.START_CHANNEL_COUNT,
                  layer_count=cfg.MODEL.LAYER_COUNT,
                  maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
                  latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
                  truncation_psi=cfg.MODEL.TRUNCATIOM_PSI,
                  truncation_cutoff=cfg.MODEL.TRUNCATIOM_CUTOFF,
                  mapping_layers=cfg.MODEL.MAPPING_LAYERS,
                  channels=cfg.MODEL.CHANNELS,
                  generator=cfg.MODEL.GENERATOR,
                  encoder=cfg.MODEL.ENCODER)
    model.cuda(0)
    model.eval()
    model.requires_grad_(False)

    decoder = model.decoder
    encoder = model.encoder

    logger.info("Trainable parameters generator:")
    count_parameters(decoder)

    logger.info("Trainable parameters discriminator:")
    count_parameters(encoder)

    model_dict = {
        'discriminator_s': encoder,
        'generator_s': decoder,
        'mapping_tl_s': model.mapping_tl,
        'mapping_fl_s': model.mapping_fl,
        'dlatent_avg': model.dlatent_avg
    }

    checkpointer = Checkpointer(cfg, model_dict, {}, logger=logger, save=False)

    extra_checkpoint_data = checkpointer.load()
    # The epoch count is read from the first scheduler entry stored in the
    # checkpoint's auxiliary data.
    last_epoch = list(extra_checkpoint_data['auxiliary']
                      ['scheduler'].values())[0]['last_epoch']
    logger.info("Model trained for %d epochs" % last_epoch)

    model.eval()

    layer_count = cfg.MODEL.LAYER_COUNT

    def encode(x):
        # Images are encoded one at a time; each latent is then repeated
        # once per mapping_fl layer along dim 1.
        zlist = []
        for i in range(x.shape[0]):
            Z, _ = model.encode(x[i][None, ...], layer_count - 1, 1)
            zlist.append(Z)
        Z = torch.cat(zlist)
        return Z.repeat(1, model.mapping_fl.num_layers, 1)

    def decode(x):
        # Latents are decoded one at a time and concatenated along dim 0.
        decoded = [
            model.decoder(x[i][None, ...], layer_count - 1, 1, noise=True)
            for i in range(x.shape[0])
        ]
        return torch.cat(decoded)

    path = cfg.DATASET.STYLE_MIX_PATH
    im_size = 2**(cfg.MODEL.LAYER_COUNT + 1)

    def load_originals(subdir, count):
        # Load "<path>/<subdir>/<i>.png" (falling back to ".jpg") for
        # i in [0, count), rescale with "/ 127.5 - 1", drop a fourth channel
        # and average-pool down to im_size. Returns one stacked tensor.
        originals = []
        for i in range(count):
            try:
                im = np.asarray(
                    Image.open(os.path.join(path, '%s/%d.png' % (subdir, i))))
            except FileNotFoundError:
                im = np.asarray(
                    Image.open(os.path.join(path, '%s/%d.jpg' % (subdir, i))))
            # Height-width-channel -> channel-height-width.
            im = im.transpose((2, 0, 1))
            x = torch.tensor(np.asarray(im, dtype=np.float32),
                             requires_grad=True).cuda() / 127.5 - 1.
            if x.shape[0] == 4:
                x = x[:3]
            # Integer downscale factor derived from the last dimension.
            factor = x.shape[2] // im_size
            if factor != 1:
                x = torch.nn.functional.avg_pool2d(x[None, ...], factor,
                                                   factor)[0]
            assert x.shape[2] == im_size
            originals.append(x)
        return torch.stack(originals)

    src_originals = load_originals('src', src_len)
    dst_originals = load_originals('dst', dst_len)

    src_latents = encode(src_originals)
    src_images = decode(src_latents)

    dst_latents = encode(dst_originals)
    dst_images = decode(dst_latents)

    # One extra row and column of cells hold the original images.
    canvas = np.zeros([3, im_size * (dst_len + 1), im_size * (src_len + 1)])

    out_dir = 'style_mixing/output/%s/' % cfg.NAME
    os.makedirs(out_dir, exist_ok=True)

    # The last two arguments of place() are presumably the (column, row)
    # cell of the canvas - verify against place().
    for i in range(src_len):
        save_image(src_originals[i] * 0.5 + 0.5,
                   out_dir + 'source_%d.png' % i)
        place(canvas, src_originals[i], 1 + i, 0)

    for i in range(dst_len):
        save_image(dst_originals[i] * 0.5 + 0.5,
                   out_dir + 'dst_coarse_%d.png' % i)
        place(canvas, dst_originals[i], 0, 1 + i)

    # Layer ranges copied from the source latents, one entry per
    # destination row. NOTE(review): only six entries exist, so dst_len > 6
    # raises IndexError below.
    style_ranges = [range(0, 4)] * 3 + [range(4, 8)] * 2 + [
        range(8, layer_count * 2)
    ]

    def mix_styles(style_src, style_dst, r):
        # Copy of style_dst with the layers in ``r`` taken from style_src.
        style = style_dst.clone()
        style[:, r] = style_src[:, r]
        return style

    for row in range(dst_len):
        row_latents = torch.stack([dst_latents[row]] * src_len)
        style = mix_styles(src_latents, row_latents, style_ranges[row])
        rec = model.decoder(style, layer_count - 1, 1, noise=True)
        for j in range(rec.shape[0]):
            save_image(rec[j] * 0.5 + 0.5,
                       out_dir + 'rec_coarse_%d_%d.png' % (row, j))
            place(canvas, rec[j], 1 + j, 1 + row)

    save_image(torch.Tensor(canvas), out_dir + 'stylemix.png')
def sample(cfg, logger):
    """Save a 7x7 grid that interpolates between the latents of four images.

    Restores the model from a checkpoint, encodes four fixed files from
    ``cfg.DATASET.SAMPLES_PATH``, bilinearly blends their latents over a
    7x7 grid, decodes every blend and writes the grid as
    ``interpolations.png`` and ``interpolations.jpg`` under
    ``make_figures/output/<cfg.NAME>/``. The output directory is created if
    it does not exist.

    Args:
        cfg: Project configuration (``MODEL``, ``DATASET`` and ``NAME`` are
            read here).
        logger: Logger used for progress messages and forwarded to helpers.
    """
    import os  # local import: the original block did not rely on ``os``

    torch.cuda.set_device(0)
    model = Model(startf=cfg.MODEL.START_CHANNEL_COUNT,
                  layer_count=cfg.MODEL.LAYER_COUNT,
                  maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
                  latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
                  truncation_psi=cfg.MODEL.TRUNCATIOM_PSI,
                  truncation_cutoff=cfg.MODEL.TRUNCATIOM_CUTOFF,
                  mapping_layers=cfg.MODEL.MAPPING_LAYERS,
                  channels=cfg.MODEL.CHANNELS,
                  generator=cfg.MODEL.GENERATOR,
                  encoder=cfg.MODEL.ENCODER)
    model.cuda(0)
    model.eval()
    model.requires_grad_(False)

    decoder = model.decoder
    encoder = model.encoder

    logger.info("Trainable parameters generator:")
    count_parameters(decoder)

    logger.info("Trainable parameters discriminator:")
    count_parameters(encoder)

    model_dict = {
        'discriminator_s': encoder,
        'generator_s': decoder,
        'mapping_tl_s': model.mapping_tl,
        'mapping_fl_s': model.mapping_fl,
        'dlatent_avg': model.dlatent_avg
    }

    checkpointer = Checkpointer(cfg, model_dict, {}, logger=logger, save=False)
    checkpointer.load()

    model.eval()

    layer_count = cfg.MODEL.LAYER_COUNT

    def encode(x):
        # Latent of the batch, repeated once per mapping_fl layer on dim 1.
        Z, _ = model.encode(x, layer_count - 1, 1)
        return Z.repeat(1, model.mapping_fl.num_layers, 1)

    def decode(x):
        # No truncation is applied: the latents are decoded as given.
        return model.decoder(x, layer_count - 1, 1, noise=True)

    path = cfg.DATASET.SAMPLES_PATH
    im_size = 2**(cfg.MODEL.LAYER_COUNT + 1)

    # Files whose latents form the four corners of the grid
    # (top-left, top-right, bottom-left, bottom-right).
    corner_files = ('00001.png', '00022.png', '00077.png', '00016.png')

    def open_image(filename):
        # Load one image, rescale with "/ 127.5 - 1", pool it down to
        # im_size and return the first layer's latent vector.
        img = np.asarray(Image.open(os.path.join(path, filename)))
        # Drop the alpha channel of 4-channel images.
        if img.shape[2] == 4:
            img = img[:, :, :3]
        # Height-width-channel -> channel-height-width.
        im = img.transpose((2, 0, 1))
        x = torch.tensor(np.asarray(im, dtype=np.float32),
                         device='cpu',
                         requires_grad=True).cuda() / 127.5 - 1.
        factor = x.shape[2] // im_size
        if factor != 1:
            x = torch.nn.functional.avg_pool2d(x[None, ...], factor, factor)[0]
        assert x.shape[2] == im_size
        return encode(x[None, ...])[0, 0]

    def make(w):
        # Broadcast one latent vector to every layer and decode it.
        with torch.no_grad():
            w = w[None, None, ...].repeat(1, model.mapping_fl.num_layers, 1)
            return decode(w)

    wa, wb, wc, wd = (open_image(name) for name in corner_files)

    height = 7
    width = 7

    images = []
    for i in range(height):
        kv = i / (height - 1.0)  # vertical position in [0, 1]
        for j in range(width):
            kh = j / (width - 1.0)  # horizontal position in [0, 1]

            # Bilinear weights of the four corner latents.
            ka = (1.0 - kh) * (1.0 - kv)
            kb = kh * (1.0 - kv)
            kc = (1.0 - kh) * kv
            kd = kh * kv

            images.append(make(ka * wa + kb * wb + kc * wc + kd * wd))

    grid = torch.cat(images) * 0.5 + 0.5

    # save_image does not create missing directories.
    os.makedirs('make_figures/output/%s' % cfg.NAME, exist_ok=True)
    save_image(grid,
               'make_figures/output/%s/interpolations.png' % cfg.NAME,
               nrow=width)
    save_image(grid,
               'make_figures/output/%s/interpolations.jpg' % cfg.NAME,
               nrow=width)
Esempio n. 5
0
def sample(cfg, logger):
    """Restore a trained model and run the FID evaluation on it.

    The model is built without truncation (``truncation_psi`` and
    ``truncation_cutoff`` are ``None``), filled from the checkpoint, and its
    decoder is replaced by a ``DataParallel`` wrapper before
    ``FID.evaluate`` is called under ``torch.no_grad()``.

    Args:
        cfg: Project configuration (``MODEL`` and ``DATASET`` are read
            here; the whole object is forwarded to ``FID``).
        logger: Logger used for progress messages and forwarded to helpers.
    """
    torch.cuda.set_device(0)
    model = Model(startf=cfg.MODEL.START_CHANNEL_COUNT,
                  layer_count=cfg.MODEL.LAYER_COUNT,
                  maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
                  latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
                  truncation_psi=None,
                  truncation_cutoff=None,
                  style_mixing_prob=cfg.MODEL.STYLE_MIXING_PROB,
                  mapping_layers=cfg.MODEL.MAPPING_LAYERS,
                  channels=cfg.MODEL.CHANNELS,
                  generator=cfg.MODEL.GENERATOR,
                  encoder=cfg.MODEL.ENCODER)

    model.cuda(0)
    model.eval()
    model.requires_grad_(False)

    decoder = model.decoder
    encoder = model.encoder
    mapping_fl = model.mapping_fl

    logger.info("Trainable parameters generator:")
    count_parameters(decoder)

    logger.info("Trainable parameters discriminator:")
    count_parameters(encoder)

    model_dict = {
        'discriminator_s': encoder,
        'generator_s': decoder,
        'mapping_fl_s': mapping_fl,
        'dlatent_avg_s': model.dlatent_avg
    }

    checkpointer = Checkpointer(cfg, model_dict, {}, logger=logger, save=False)

    extra_checkpoint_data = checkpointer.load()
    # The epoch count is read from the first scheduler entry stored in the
    # checkpoint's auxiliary data.
    scheduler_states = extra_checkpoint_data['auxiliary']['scheduler']
    last_epoch = list(scheduler_states.values())[0]['last_epoch']
    logger.info("Model trained for %d epochs" % last_epoch)

    model.eval()

    logger.info("Evaluating FID metric")

    # Replace the decoder on the model itself so that code going through
    # ``model.decoder`` uses the data-parallel wrapper.
    model.decoder = nn.DataParallel(decoder)

    with torch.no_grad():
        # Mini-batch size scales with the number of visible GPUs.
        fid = FID(cfg,
                  num_images=50000,
                  minibatch_size=16 * torch.cuda.device_count())
        fid.evaluate(logger, mapping_fl, model.decoder, model,
                     cfg.DATASET.MAX_RESOLUTION_LEVEL - 2)
Esempio n. 6
0
def sample(cfg, logger):
    """Save side-by-side original/reconstruction grids for sample images.

    Restores the model from a checkpoint, shuffles the sorted file list of
    ``cfg.DATASET.SAMPLES_PATH`` with a fixed seed, and for the first two
    chunks of 40 files encodes and decodes every image. Each original is
    concatenated with its reconstruction along the width and the chunk is
    written to ``make_figures/output/reconstructions_bed_<n>.png``. The
    output directory is created if missing and empty chunks are skipped.

    Note that this seeds the global ``random`` module.

    Args:
        cfg: Project configuration (``MODEL`` and ``DATASET`` are read here).
        logger: Logger used for progress messages and forwarded to helpers.
    """
    torch.cuda.set_device(0)
    model = Model(
        startf=cfg.MODEL.START_CHANNEL_COUNT,
        layer_count=cfg.MODEL.LAYER_COUNT,
        maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
        latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
        truncation_psi=cfg.MODEL.TRUNCATIOM_PSI,
        truncation_cutoff=cfg.MODEL.TRUNCATIOM_CUTOFF,
        mapping_layers=cfg.MODEL.MAPPING_LAYERS,
        channels=cfg.MODEL.CHANNELS,
        generator=cfg.MODEL.GENERATOR,
        encoder=cfg.MODEL.ENCODER)
    model.cuda(0)
    model.eval()
    model.requires_grad_(False)

    decoder = model.decoder
    encoder = model.encoder

    logger.info("Trainable parameters generator:")
    count_parameters(decoder)

    logger.info("Trainable parameters discriminator:")
    count_parameters(encoder)

    model_dict = {
        'discriminator_s': encoder,
        'generator_s': decoder,
        'mapping_tl_s': model.mapping_tl,
        'mapping_fl_s': model.mapping_fl,
        'dlatent_avg': model.dlatent_avg
    }

    checkpointer = Checkpointer(cfg,
                                model_dict,
                                {},
                                logger=logger,
                                save=False)
    checkpointer.load()

    model.eval()

    layer_count = cfg.MODEL.LAYER_COUNT

    def encode(x):
        # Latent of the batch, repeated once per mapping_fl layer on dim 1.
        Z, _ = model.encode(x, layer_count - 1, 1)
        return Z.repeat(1, model.mapping_fl.num_layers, 1)

    def decode(x):
        # No truncation is applied: the latents are decoded as given.
        return model.decoder(x, layer_count - 1, 1, noise=True)

    path = cfg.DATASET.SAMPLES_PATH

    # Sort first so the fixed-seed shuffle yields the same order regardless
    # of the order os.listdir happens to return.
    paths = sorted(os.listdir(path))
    random.seed(3456)
    random.shuffle(paths)

    def make(filenames):
        # One (original | reconstruction) tensor per file, on the CPU.
        pairs = []
        with torch.no_grad():
            for filename in filenames:
                img = np.asarray(Image.open(path + '/' + filename))
                # Drop the alpha channel of 4-channel images.
                if img.shape[2] == 4:
                    img = img[:, :, :3]
                # Height-width-channel -> channel-height-width.
                im = img.transpose((2, 0, 1))
                x = torch.tensor(np.asarray(im, dtype=np.float32),
                                 device='cpu',
                                 requires_grad=True).cuda() / 127.5 - 1.
                batch = x[None, ...]
                reconstruction = decode(encode(batch))
                pairs.append(
                    torch.cat([batch.detach().cpu(),
                               reconstruction.detach().cpu()], dim=3))
        return pairs

    # save_image does not create missing directories.
    os.makedirs('make_figures/output', exist_ok=True)

    chunk_size = 40
    for index in (1, 2):
        start = (index - 1) * chunk_size
        chunk = paths[start:start + chunk_size]
        if not chunk:
            # torch.cat would raise on an empty list.
            logger.info("No images left for reconstructions_bed_%d.png" % index)
            continue
        canvas = torch.cat(make(chunk), dim=0)
        save_image(canvas * 0.5 + 0.5,
                   'make_figures/output/reconstructions_bed_%d.png' % index,
                   nrow=4,
                   pad_value=1.0)
Esempio n. 7
0
def sample(cfg, logger):
    """Restore a trained model and draw the uncurated generations figure.

    The model is built from ``cfg.MODEL`` on CUDA device 0, filled from the
    checkpoint, and passed to ``draw_uncurated_result_figure`` under
    ``torch.no_grad()`` with the target path
    ``make_figures/output/<cfg.NAME>/generations.jpg``.

    Args:
        cfg: Project configuration (``MODEL`` and ``NAME`` are read here).
        logger: Logger used for progress messages and forwarded to helpers.
    """
    torch.cuda.set_device(0)
    model = Model(
        startf=cfg.MODEL.START_CHANNEL_COUNT,
        layer_count=cfg.MODEL.LAYER_COUNT,
        maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
        latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
        truncation_psi=cfg.MODEL.TRUNCATIOM_PSI,
        truncation_cutoff=cfg.MODEL.TRUNCATIOM_CUTOFF,
        style_mixing_prob=cfg.MODEL.STYLE_MIXING_PROB,
        mapping_layers=cfg.MODEL.MAPPING_LAYERS,
        channels=cfg.MODEL.CHANNELS,
        generator=cfg.MODEL.GENERATOR,
        encoder=cfg.MODEL.ENCODER)

    model.cuda(0)
    model.eval()
    model.requires_grad_(False)

    decoder = model.decoder
    encoder = model.encoder

    logger.info("Trainable parameters generator:")
    count_parameters(decoder)

    logger.info("Trainable parameters discriminator:")
    count_parameters(encoder)

    model_dict = {
        'discriminator_s': encoder,
        'generator_s': decoder,
        'mapping_tl_s': model.mapping_tl,
        'mapping_fl_s': model.mapping_fl,
        'dlatent_avg': model.dlatent_avg
    }

    checkpointer = Checkpointer(cfg,
                                model_dict,
                                {},
                                logger=logger,
                                save=False)
    checkpointer.load()

    model.eval()

    # The figure helper receives ``model`` itself, so no separately wrapped
    # decoder is needed here.
    im_size = 2 ** (cfg.MODEL.LAYER_COUNT + 1)
    output_path = 'make_figures/output/%s/generations.jpg' % cfg.NAME
    with torch.no_grad():
        draw_uncurated_result_figure(cfg, output_path, model,
                                     cx=0, cy=0, cw=im_size, ch=im_size,
                                     rows=6, lods=[0, 0, 0, 1, 1, 2], seed=5)
def sample(cfg, logger):
    """Save original/reconstruction grids for two batches of dataset images.

    Restores the model from a checkpoint, draws two batches of 10 images
    from a ``TFRecordsDataset`` created with ``train=False`` at the maximum
    resolution level, encodes and decodes every image, and writes each
    batch as ``make_figures/reconstructions_ffhq_real_<n>.png`` with the
    original and its reconstruction concatenated along the width.

    Args:
        cfg: Project configuration (``MODEL`` and ``DATASET`` are read here).
        logger: Logger used for progress messages and forwarded to helpers.
    """
    torch.cuda.set_device(0)
    model = Model(startf=cfg.MODEL.START_CHANNEL_COUNT,
                  layer_count=cfg.MODEL.LAYER_COUNT,
                  maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
                  latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
                  truncation_psi=cfg.MODEL.TRUNCATIOM_PSI,
                  truncation_cutoff=cfg.MODEL.TRUNCATIOM_CUTOFF,
                  mapping_layers=cfg.MODEL.MAPPING_LAYERS,
                  channels=cfg.MODEL.CHANNELS,
                  generator=cfg.MODEL.GENERATOR,
                  encoder=cfg.MODEL.ENCODER)
    model.cuda(0)
    model.eval()
    model.requires_grad_(False)

    decoder = model.decoder
    encoder = model.encoder

    logger.info("Trainable parameters generator:")
    count_parameters(decoder)

    logger.info("Trainable parameters discriminator:")
    count_parameters(encoder)

    model_dict = {
        'discriminator_s': encoder,
        'generator_s': decoder,
        'mapping_tl_s': model.mapping_tl,
        'mapping_fl_s': model.mapping_fl,
        'dlatent_avg': model.dlatent_avg
    }

    checkpointer = Checkpointer(cfg, model_dict, {}, logger=logger, save=False)
    checkpointer.load()

    model.eval()

    layer_count = cfg.MODEL.LAYER_COUNT

    def encode(x):
        # Latent of the batch, repeated once per mapping_fl layer on dim 1.
        Z, _ = model.encode(x, layer_count - 1, 1)
        return Z.repeat(1, model.mapping_fl.num_layers, 1)

    def decode(x):
        # No truncation is applied: the latents are decoded as given.
        return model.decoder(x, layer_count - 1, 1, noise=True)

    dataset = TFRecordsDataset(cfg,
                               logger,
                               rank=0,
                               world_size=1,
                               buffer_size_mb=10,
                               channels=cfg.MODEL.CHANNELS,
                               train=False)

    dataset.reset(cfg.DATASET.MAX_RESOLUTION_LEVEL, 10)
    b = iter(make_dataloader(cfg, logger, dataset, 10, 0, numpy=True))

    def make(images):
        # One (original | reconstruction) tensor per image, on the CPU.
        pairs = []
        with torch.no_grad():
            for img in images:
                x = torch.tensor(np.asarray(img, dtype=np.float32),
                                 device='cpu',
                                 requires_grad=True).cuda() / 127.5 - 1.
                # Drop a fourth channel if present.
                if x.shape[0] == 4:
                    x = x[:3]
                batch = x[None, ...]
                reconstruction = decode(encode(batch))
                pairs.append(
                    torch.cat([batch.detach().cpu(),
                               reconstruction.detach().cpu()], dim=3))
        return pairs

    # Two consecutive batches, one output file each.
    for index in (1, 2):
        batch = next(b)
        canvas = torch.cat(make(batch), dim=0)
        save_image(canvas * 0.5 + 0.5,
                   'make_figures/reconstructions_ffhq_real_%d.png' % index,
                   nrow=2,
                   pad_value=1.0)
def sample(cfg, logger):
    """Reconstruct every image of the pioneer sample folder and save the result.

    Builds the model described by ``cfg.MODEL``, restores it through the
    checkpointer, then encodes and decodes each file found in
    ``dataset_samples/faces/pioneer256x256``.  Only the reconstruction is
    written, to ``make_figures/output/pioneer/<name>_alae.png`` where
    ``<name>`` is the file name with its last nine characters removed.

    Args:
        cfg: Experiment configuration; the ``cfg.MODEL.*`` values read below
            define the network.
        logger: Logger used for status messages and handed to the checkpointer.
    """
    torch.cuda.set_device(0)
    model = Model(startf=cfg.MODEL.START_CHANNEL_COUNT,
                  layer_count=cfg.MODEL.LAYER_COUNT,
                  maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
                  latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
                  truncation_psi=cfg.MODEL.TRUNCATIOM_PSI,
                  truncation_cutoff=cfg.MODEL.TRUNCATIOM_CUTOFF,
                  mapping_layers=cfg.MODEL.MAPPING_LAYERS,
                  channels=cfg.MODEL.CHANNELS,
                  generator=cfg.MODEL.GENERATOR,
                  encoder=cfg.MODEL.ENCODER)
    model.cuda(0)
    model.eval()
    model.requires_grad_(False)

    decoder = model.decoder
    encoder = model.encoder

    logger.info("Trainable parameters generator:")
    count_parameters(decoder)

    logger.info("Trainable parameters discriminator:")
    count_parameters(encoder)

    # Sub-modules the checkpointer restores, keyed by the names it looks up.
    model_dict = {
        'discriminator_s': encoder,
        'generator_s': decoder,
        'mapping_tl_s': model.mapping_tl,
        'mapping_fl_s': model.mapping_fl,
        'dlatent_avg': model.dlatent_avg
    }

    checkpointer = Checkpointer(cfg, model_dict, {}, logger=logger, save=False)
    checkpointer.load()

    model.eval()

    layer_count = cfg.MODEL.LAYER_COUNT

    def encode(x):
        """Encode a batch of images and repeat the code once per style layer."""
        Z, _ = model.encode(x, layer_count - 1, 1)
        return Z.repeat(1, model.mapping_fl.num_layers, 1)

    def decode(x):
        """Decode per-layer latents at the highest level of detail."""
        return model.decoder(x, layer_count - 1, 1, noise=True)

    path = 'dataset_samples/faces/pioneer256x256'
    output_dir = 'make_figures/output/pioneer/'

    def make(paths):
        """Reconstruct each file name in ``paths`` and write it to disk."""
        # Create the destination once instead of once per image.
        os.makedirs(output_dir, exist_ok=True)
        with torch.no_grad():
            for filename in paths:
                img = np.asarray(Image.open(path + '/' + filename))
                if img.shape[2] == 4:
                    # Drop the alpha channel.
                    img = img[:, :, :3]
                im = img.transpose((2, 0, 1))  # HWC -> CHW

                # No gradients are ever taken here, so the input does not
                # need requires_grad; pixels are rescaled by 1/127.5 - 1.
                x = torch.tensor(np.asarray(im, dtype=np.float32),
                                 device='cpu').cuda() / 127.5 - 1.

                # Halve the image until its width matches the resolution of
                # decoder layer 6.
                while x.shape[2] != model.decoder.layer_to_resolution[6]:
                    x = F.avg_pool2d(x, 2, 2)

                f = decode(encode(x[None, ...]))
                save_image(f.detach().cpu() * 0.5 + 0.5,
                           'make_figures/output/pioneer/%s_alae.png' %
                           filename[:-9],
                           nrow=1,
                           pad_value=1.0)

    make(list(os.listdir(path)))
Esempio n. 10
0
def sample(cfg, logger):
    """Render attribute-traversal strips for a fixed list of sample faces.

    Builds the model described by ``cfg.MODEL``, restores it through the
    checkpointer and, for each (image, attribute) pair listed at the bottom,
    encodes the image, moves its latent code along the direction stored in
    ``principal_directions/direction_<idx>.npy`` in seven evenly spaced steps
    and saves the seven decoded images as
    ``make_figures/output/<cfg.NAME>/traversal_<label>.jpg``.

    Args:
        cfg: Experiment configuration (``cfg.MODEL.*``,
            ``cfg.DATASET.SAMPLES_PATH`` and ``cfg.NAME`` are read).
        logger: Logger used for status messages and handed to the checkpointer.
    """
    # Local import: only needed here to create the output directory.
    import os

    torch.cuda.set_device(0)
    model = Model(startf=cfg.MODEL.START_CHANNEL_COUNT,
                  layer_count=cfg.MODEL.LAYER_COUNT,
                  maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
                  latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
                  truncation_psi=cfg.MODEL.TRUNCATIOM_PSI,
                  truncation_cutoff=cfg.MODEL.TRUNCATIOM_CUTOFF,
                  mapping_layers=cfg.MODEL.MAPPING_LAYERS,
                  channels=cfg.MODEL.CHANNELS,
                  generator=cfg.MODEL.GENERATOR,
                  encoder=cfg.MODEL.ENCODER)
    model.cuda(0)
    model.eval()
    model.requires_grad_(False)

    decoder = model.decoder
    encoder = model.encoder

    logger.info("Trainable parameters generator:")
    count_parameters(decoder)

    logger.info("Trainable parameters discriminator:")
    count_parameters(encoder)

    # Sub-modules the checkpointer restores, keyed by the names it looks up.
    model_dict = {
        'discriminator_s': encoder,
        'generator_s': decoder,
        'mapping_tl_s': model.mapping_tl,
        'mapping_fl_s': model.mapping_fl,
        'dlatent_avg': model.dlatent_avg
    }

    checkpointer = Checkpointer(cfg, model_dict, {}, logger=logger, save=False)
    checkpointer.load()

    model.eval()

    layer_count = cfg.MODEL.LAYER_COUNT

    def encode(x):
        """Encode a batch of images and repeat the code once per style layer."""
        Z, _ = model.encode(x, layer_count - 1, 1)
        return Z.repeat(1, model.mapping_fl.num_layers, 1)

    def decode(x):
        """Decode per-layer latents at the highest level of detail."""
        return model.decoder(x, layer_count - 1, 1, noise=True)

    path = cfg.DATASET.SAMPLES_PATH
    im_size = 2**(cfg.MODEL.LAYER_COUNT + 1)

    # Direction index -> label used in the output file name (parallel lists).
    indices = [0, 1, 2, 3, 4, 10, 11, 17, 19]
    labels = [
        "gender",
        "smile",
        "attractive",
        "wavy-hair",
        "young",
        "big_lips",
        "big_nose",
        "chubby",
        "glasses",
    ]

    def do_attribute_traversal(path, attrib_idx, start, end):
        """Save a 7-step traversal of one attribute for the image at ``path``.

        Args:
            path: Image file to encode.
            attrib_idx: Index of the principal direction; must be in ``indices``.
            start: Offset added to the image's own attribute value at step 0.
            end: Offset added to the image's own attribute value at the last step.

        Raises:
            ValueError: If ``attrib_idx`` has no label.
        """
        # Resolve the label first so an unknown index fails before any GPU work.
        label = labels[indices.index(attrib_idx)]

        img = np.asarray(Image.open(path))
        if img.shape[2] == 4:
            # Drop the alpha channel.
            img = img[:, :, :3]
        im = img.transpose((2, 0, 1))  # HWC -> CHW
        # No backward pass is ever run, so the input does not need
        # requires_grad; pixels are rescaled by 1/127.5 - 1.
        x = torch.tensor(np.asarray(im, dtype=np.float32),
                         device='cpu').cuda() / 127.5 - 1.
        factor = x.shape[2] // im_size
        if factor != 1:
            x = torch.nn.functional.avg_pool2d(x[None, ...], factor, factor)[0]
        assert x.shape[2] == im_size
        _latents = encode(x[None, ...])

        # Clone before the in-place subtraction: indexing returns a view, and
        # without the copy the first style row of _latents would be modified
        # even though _latents is reused when building the styles below.
        latents = _latents[0, 0].clone()
        latents -= model.dlatent_avg.buff.data[0]

        w0 = torch.tensor(np.load("principal_directions/direction_%d.npy" %
                                  attrib_idx),
                          dtype=torch.float32)

        # Projection of the centred code on the direction, then remove it.
        attr0 = (latents * w0).sum()
        latents = latents - attr0 * w0

        def update_image(w):
            """Decode the centred latent ``w`` after re-adding the average."""
            with torch.no_grad():
                w = w + model.dlatent_avg.buff.data[0]
                w = w[None, None, ...].repeat(1, model.mapping_fl.num_layers,
                                              1)

                layer_idx = torch.arange(
                    model.mapping_fl.num_layers)[np.newaxis, :, np.newaxis]
                # Layers below the cutoff take the edited code, the remaining
                # ones keep the originally encoded latents.
                mixing_cutoff = (7 + 1) * 2
                styles = torch.where(layer_idx < mixing_cutoff, w, _latents[0])

                return decode(styles)

        traversal = []

        r = 7
        inc = (end - start) / (r - 1)

        for _ in range(r):
            traversal.append(update_image(latents + w0 * (attr0 + start)))
            attr0 += inc
        res = torch.cat(traversal)

        out_path = ("make_figures/output/%s/traversal_%s.jpg" %
                    (cfg.NAME, label))
        # save_image does not create missing directories.
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        save_image(res * 0.5 + 0.5, out_path, pad_value=1)

    do_attribute_traversal(path + '/00049.png', 0, 0.6, -34)
    do_attribute_traversal(path + '/00125.png', 1, -3, 15.0)
    do_attribute_traversal(path + '/00057.png', 3, -2, 30.0)
    do_attribute_traversal(path + '/00031.png', 4, -10, 30.0)
    do_attribute_traversal(path + '/00088.png', 10, -0.3, 30.0)
    do_attribute_traversal(path + '/00004.png', 11, -25, 20.0)
    do_attribute_traversal(path + '/00012.png', 17, -40, 40.0)
    do_attribute_traversal(path + '/00017.png', 19, 0, 30.0)
Esempio n. 11
0
def sample(cfg, logger):
    """Evaluate the PPL metric in W space with 'full' and 'end' sampling.

    Builds the model described by ``cfg.MODEL``, restores it through the
    checkpointer, wraps the decoder in ``nn.DataParallel`` and runs two
    ``PPL`` evaluations (50000 samples, epsilon 1e-4) that differ only in the
    ``sampling`` mode.

    Args:
        cfg: Experiment configuration (``cfg.MODEL.*``,
            ``cfg.DATASET.MAX_RESOLUTION_LEVEL`` and
            ``cfg.PPL_CELEBA_ADJUSTMENT`` are read).
        logger: Logger used for status messages and handed to the
            checkpointer and the metric.
    """
    torch.cuda.set_device(0)
    model = Model(startf=cfg.MODEL.START_CHANNEL_COUNT,
                  layer_count=cfg.MODEL.LAYER_COUNT,
                  maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
                  latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
                  truncation_psi=cfg.MODEL.TRUNCATIOM_PSI,
                  truncation_cutoff=cfg.MODEL.TRUNCATIOM_CUTOFF,
                  mapping_layers=cfg.MODEL.MAPPING_LAYERS,
                  channels=cfg.MODEL.CHANNELS,
                  generator=cfg.MODEL.GENERATOR,
                  encoder=cfg.MODEL.ENCODER)

    model.cuda(0)
    model.eval()
    model.requires_grad_(False)

    decoder = model.decoder
    encoder = model.encoder
    mapping_fl = model.mapping_fl

    logger.info("Trainable parameters generator:")
    count_parameters(decoder)

    logger.info("Trainable parameters discriminator:")
    count_parameters(encoder)

    # Sub-modules the checkpointer restores, keyed by the names it looks up.
    # Built before the decoder is wrapped so the bare module is registered.
    model_dict = {
        'discriminator_s': encoder,
        'generator_s': decoder,
        'mapping_tl_s': model.mapping_tl,
        'mapping_fl_s': mapping_fl,
        'dlatent_avg': model.dlatent_avg
    }

    checkpointer = Checkpointer(cfg, model_dict, {}, logger=logger, save=False)
    checkpointer.load()

    model.eval()

    logger.info("Evaluating PPL metric")

    decoder = nn.DataParallel(decoder)

    with torch.no_grad():
        # Same settings for both runs; only the sampling mode changes.
        for sampling in ('full', 'end'):
            ppl = PPL(cfg,
                      num_samples=50000,
                      epsilon=1e-4,
                      space='w',
                      sampling=sampling,
                      minibatch_size=16 * torch.cuda.device_count())
            ppl.evaluate(logger,
                         mapping_fl,
                         decoder,
                         cfg.DATASET.MAX_RESOLUTION_LEVEL - 2,
                         celeba_style=cfg.PPL_CELEBA_ADJUSTMENT)
Esempio n. 12
0
def sample(cfg, logger):
    """Build the multi-resolution source/reconstruction figure.

    Builds the model described by ``cfg.MODEL``, restores it through the
    checkpointer, reconstructs every image in ``cfg.DATASET.SAMPLES_PATH``
    and tiles source/reconstruction pairs on four side-by-side canvases: one
    full-size pair on the first row followed by half-size pairs below.  The
    result is written to
    ``make_figures/output/<cfg.NAME>/reconstructions_multiresolution.png``.

    Args:
        cfg: Experiment configuration (``cfg.MODEL.*``,
            ``cfg.DATASET.SAMPLES_PATH`` and ``cfg.NAME`` are read).
        logger: Logger used for status messages and handed to the checkpointer.
    """
    torch.cuda.set_device(0)
    model = Model(
        startf=cfg.MODEL.START_CHANNEL_COUNT,
        layer_count=cfg.MODEL.LAYER_COUNT,
        maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
        latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
        truncation_psi=cfg.MODEL.TRUNCATIOM_PSI,
        truncation_cutoff=cfg.MODEL.TRUNCATIOM_CUTOFF,
        mapping_layers=cfg.MODEL.MAPPING_LAYERS,
        channels=cfg.MODEL.CHANNELS,
        generator=cfg.MODEL.GENERATOR,
        encoder=cfg.MODEL.ENCODER)
    model.cuda(0)
    model.eval()
    model.requires_grad_(False)

    decoder = model.decoder
    encoder = model.encoder

    logger.info("Trainable parameters generator:")
    count_parameters(decoder)

    logger.info("Trainable parameters discriminator:")
    count_parameters(encoder)

    # Sub-modules the checkpointer restores, keyed by the names it looks up.
    model_dict = {
        'discriminator_s': encoder,
        'generator_s': decoder,
        'mapping_tl_s': model.mapping_tl,
        'mapping_fl_s': model.mapping_fl,
        'dlatent_avg': model.dlatent_avg
    }

    checkpointer = Checkpointer(cfg,
                                model_dict,
                                {},
                                logger=logger,
                                save=False)
    checkpointer.load()

    model.eval()

    layer_count = cfg.MODEL.LAYER_COUNT

    def encode(x):
        """Encode a batch of images and repeat the code once per style layer."""
        Z, _ = model.encode(x, layer_count - 1, 1)
        return Z.repeat(1, model.mapping_fl.num_layers, 1)

    def decode(x):
        """Decode per-layer latents at the highest level of detail."""
        return model.decoder(x, layer_count - 1, 1, noise=True)

    path = cfg.DATASET.SAMPLES_PATH
    # path = 'dataset_samples/faces/realign1024x1024_paper'

    im_size = 2 ** (cfg.MODEL.LAYER_COUNT + 1)

    # Sort before the seeded shuffle so the order does not depend on the
    # order in which the file system lists the directory.
    paths = sorted(os.listdir(path))
    random.seed(5)
    random.shuffle(paths)

    def move_to(items, item, new_index):
        """Move ``item`` to ``new_index`` inside ``items`` (in place).

        Hook for pinning specific files to fixed positions, e.g.
        ``move_to(paths, '00026.png', 0)``.
        """
        items.remove(item)
        items.insert(new_index, item)

    def make(paths):
        """Load the given files and reconstruct them.

        Returns:
            A pair ``(src, reconstructions)``: the preprocessed input tensors
            and, in the same order, the decoder outputs as numpy arrays.
        """
        src = []
        for filename in paths:
            img = np.asarray(Image.open(path + '/' + filename))
            if img.shape[2] == 4:
                # Drop the alpha channel.
                img = img[:, :, :3]
            im = img.transpose((2, 0, 1))  # HWC -> CHW
            # No backward pass is ever run, so the input does not need
            # requires_grad; pixels are rescaled by 1/127.5 - 1.
            x = torch.tensor(np.asarray(im, dtype=np.float32)).cuda() / 127.5 - 1.
            factor = x.shape[2] // im_size
            if factor != 1:
                x = torch.nn.functional.avg_pool2d(x[None, ...], factor, factor)[0]
            assert x.shape[2] == im_size
            src.append(x)

        with torch.no_grad():
            reconstructions = [
                decode(encode(s[None, ...])).cpu().detach().numpy()
                for s in src
            ]
        return src, reconstructions

    def chunker_list(seq, size):
        """Split ``seq`` into ``size`` interleaved sub-lists."""
        return [seq[i::size] for i in range(size)]

    path0, path1, path2, path3 = chunker_list(paths, 4)

    # make_part consumes its lists with pop(), i.e. from the end; reversing
    # first makes it visit the files in their shuffled order.
    for chunk in (path0, path1, path2, path3):
        chunk.reverse()

    src0, rec0 = make(path0)
    src1, rec1 = make(path1)
    src2, rec2 = make(path2)
    src3, rec3 = make(path3)

    initial_resolution = im_size

    lods_down = 1
    padding_step = 4

    width = 0
    height = 0

    current_padding = 0

    # Integer division keeps the whole canvas geometry in ints.
    final_resolution = initial_resolution // 2 ** lods_down

    for i in range(lods_down + 1):
        width += current_padding * 2 ** (lods_down - i)
        height += current_padding * 2 ** (lods_down - i)
        current_padding += padding_step

    width += 2 ** (lods_down + 1) * final_resolution
    height += (lods_down + 1) * initial_resolution

    width = int(width)
    height = int(height)

    def make_part(current_padding, src, rec):
        """Fill one canvas with source/reconstruction pairs.

        Row ``i`` holds a ``2**i`` by ``2**i`` grid of pairs at resolution
        ``initial_resolution / 2**i``.  ``src`` and ``rec`` are consumed with
        ``pop()``; the canvas is returned early once they run out.
        """
        canvas = np.ones([3, height + 20, width + 10])

        initial_padding = current_padding
        height_padding = 0

        for i in range(lods_down + 1):
            res = initial_resolution // 2 ** i
            for x in range(2 ** i):
                for y in range(2 ** i):
                    # Only running out of images ends the figure early; errors
                    # raised by resize/place are no longer swallowed.
                    try:
                        ims = src.pop()
                        imr = rec.pop()[0]
                    except IndexError:
                        return canvas

                    # Integer target shape instead of float dimensions.
                    ims = resize(ims.cpu().detach().numpy(), (3, res, res))
                    imr = resize(imr, (3, res, res))

                    left = current_padding + x * (2 * res + current_padding)
                    top = i * initial_resolution + height_padding + y * (res + current_padding)

                    # Source on the left, its reconstruction right next to it.
                    place(canvas, ims, left, top)
                    place(canvas, imr, left + res, top)

            height_padding += initial_padding * 2
            current_padding -= padding_step
        return canvas

    canvas = [make_part(current_padding, src0, rec0), make_part(current_padding, src1, rec1),
              make_part(current_padding, src2, rec2), make_part(current_padding, src3, rec3)]

    # Put the four canvases side by side.
    canvas = np.concatenate(canvas, axis=2)

    print('Saving image')
    save_path = 'make_figures/output/%s/reconstructions_multiresolution.png' % cfg.NAME
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    save_image(torch.Tensor(canvas), save_path)
def sample(cfg, logger):
    """Run the interactive attribute-editing demo in a bimpy window.

    Builds the model described by ``cfg.MODEL``, restores it through the
    checkpointer and opens a window with one slider per principal direction.
    Images come either from ``dataset_samples/faces/realign1024x1024`` (cycled
    through with "Next" or picked by typing a file name) or from a random
    latent ("Generate random").  The sliders edit the projection of the
    latent code on each direction and the decoded result is shown live.

    NOTE(review): ``indices`` and ``labels`` are not defined in this function;
    they are presumably module-level lists of direction indices and their
    display names - confirm.

    Args:
        cfg: Experiment configuration (``cfg.MODEL.*`` is read).
        logger: Logger used for status messages and handed to the checkpointer.
    """
    torch.cuda.set_device(0)
    model = Model(
        startf=cfg.MODEL.START_CHANNEL_COUNT,
        layer_count=cfg.MODEL.LAYER_COUNT,
        maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
        latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
        truncation_psi=cfg.MODEL.TRUNCATIOM_PSI,
        truncation_cutoff=cfg.MODEL.TRUNCATIOM_CUTOFF,
        mapping_layers=cfg.MODEL.MAPPING_LAYERS,
        channels=cfg.MODEL.CHANNELS,
        generator=cfg.MODEL.GENERATOR,
        encoder=cfg.MODEL.ENCODER)
    model.cuda(0)
    model.eval()
    model.requires_grad_(False)

    decoder = model.decoder
    encoder = model.encoder
    mapping_fl = model.mapping_fl

    logger.info("Trainable parameters generator:")
    count_parameters(decoder)

    logger.info("Trainable parameters discriminator:")
    count_parameters(encoder)

    # Sub-modules the checkpointer restores, keyed by the names it looks up.
    model_dict = {
        'discriminator_s': encoder,
        'generator_s': decoder,
        'mapping_tl_s': model.mapping_tl,
        'mapping_fl_s': mapping_fl,
        'dlatent_avg': model.dlatent_avg
    }

    checkpointer = Checkpointer(cfg,
                                model_dict,
                                {},
                                logger=logger,
                                save=False)
    checkpointer.load()

    model.eval()

    layer_count = cfg.MODEL.LAYER_COUNT

    def encode(x):
        """Encode a batch of images and repeat the code once per style layer."""
        Z, _ = model.encode(x, layer_count - 1, 1)
        return Z.repeat(1, model.mapping_fl.num_layers, 1)

    def decode(x):
        """Decode per-layer latents at the highest level of detail."""
        return model.decoder(x, layer_count - 1, 1, noise=True)

    def to_uint8_hwc(x):
        """Map a CHW image in [-1, 1] to a uint8 HWC tensor on the CPU."""
        x = ((x * 0.5 + 0.5) * 255).type(torch.long).clamp(0, 255)
        return x.cpu().type(torch.uint8).transpose(0, 2).transpose(0, 1)

    path = 'dataset_samples/faces/realign1024x1024'

    paths = list(os.listdir(path))
    paths.sort()
    paths_backup = paths[:]
    randomize = bimpy.Bool(True)
    current_file = bimpy.String("")

    ctx = bimpy.Context()

    # One slider value and one direction vector per attribute index.
    attribute_values = [bimpy.Float(0) for i in indices]

    W = [torch.tensor(np.load("principal_directions/direction_%d.npy" % i), dtype=torch.float32) for i in indices]

    rnd = np.random.RandomState(5)

    def split_attributes(latents_original):
        """Store the attribute values in the sliders and return the residual.

        Takes the first style row of ``latents_original``, subtracts the
        average latent, writes its projection on every direction into the
        matching slider and returns the code with those components removed.
        """
        # Clone so the in-place subtraction does not touch latents_original.
        latents = latents_original[0, 0].clone()
        latents -= model.dlatent_avg.buff.data[0]

        # All projections are measured on the same centred code first ...
        for v, w in zip(attribute_values, W):
            v.value = (latents * w).sum()

        # ... and only then removed, one direction after the other.
        for v, w in zip(attribute_values, W):
            latents = latents - v.value * w

        return latents

    def loadNext():
        """Load and encode the next file of the queue.

        Returns:
            ``(latents, latents_original, img_src)``: the attribute-free
            residual code, the full encoded latents and the source image as a
            uint8 HWC array.
        """
        img = np.asarray(Image.open(path + '/' + paths[0]))
        current_file.value = paths[0]
        paths.pop(0)
        if len(paths) == 0:
            # Start over once every file has been shown.
            paths.extend(paths_backup)

        if img.shape[2] == 4:
            # Drop the alpha channel.
            img = img[:, :, :3]
        im = img.transpose((2, 0, 1))  # HWC -> CHW
        # No backward pass is ever run, so the input does not need
        # requires_grad; pixels are rescaled by 1/127.5 - 1.
        x = torch.tensor(np.asarray(im, dtype=np.float32), device='cpu').cuda() / 127.5 - 1.

        # Halve until no larger than the decoder's last resolution, then fix
        # up any remaining mismatch with adaptive pooling.
        needed_resolution = model.decoder.layer_to_resolution[-1]
        while x.shape[2] > needed_resolution:
            x = F.avg_pool2d(x, 2, 2)
        if x.shape[2] != needed_resolution:
            x = F.adaptive_avg_pool2d(x, (needed_resolution, needed_resolution))

        img_src = to_uint8_hwc(x).numpy()

        latents_original = encode(x[None, ...])
        latents = split_attributes(latents_original)

        return latents, latents_original, img_src

    def loadRandom():
        """Generate an image from a random latent; same return as loadNext."""
        latents = rnd.randn(1, cfg.MODEL.LATENT_SPACE_SIZE)
        lat = torch.tensor(latents).float().cuda()
        dlat = mapping_fl(lat)
        layer_idx = torch.arange(2 * layer_count)[np.newaxis, :, np.newaxis]
        ones = torch.ones(layer_idx.shape, dtype=torch.float32)
        # Both branches are ones, so every layer gets an interpolation
        # weight of 1.
        coefs = torch.where(layer_idx < model.truncation_cutoff, ones, ones)
        dlat = torch.lerp(model.dlatent_avg.buff.data, dlat, coefs)
        x = decode(dlat)[0]
        img_src = to_uint8_hwc(x).numpy()
        latents_original = dlat
        latents = split_attributes(latents_original)

        return latents, latents_original, img_src

    latents, latents_original, img_src = loadNext()

    ctx.init(1800, 1600, "Styles")

    def update_image(w, latents_original):
        """Decode the centred latent ``w`` into a uint8 HWC tensor."""
        with torch.no_grad():
            w = w + model.dlatent_avg.buff.data[0]
            w = w[None, None, ...].repeat(1, model.mapping_fl.num_layers, 1)

            layer_idx = torch.arange(model.mapping_fl.num_layers)[np.newaxis, :, np.newaxis]
            # Layers below the cutoff take the edited code, the remaining
            # ones keep the original latents.
            mixing_cutoff = (7 + 1) * 2
            styles = torch.where(layer_idx < mixing_cutoff, w, latents_original)

            x_rec = decode(styles)
            return to_uint8_hwc(x_rec[0, :, :, :])

    im_size = 2 ** (cfg.MODEL.LAYER_COUNT + 1)
    im = update_image(latents, latents_original)
    print(im.shape)
    im = bimpy.Image(im)

    display_original = True

    seed = 0

    while not ctx.should_close():
        with ctx:
            # Re-add the slider-controlled components to the residual code.
            new_latents = latents + sum([v.value * w for v, w in zip(attribute_values, W)])

            if display_original:
                im = bimpy.Image(img_src)
            else:
                im = bimpy.Image(update_image(new_latents, latents_original))

            bimpy.begin("Principal directions")
            bimpy.columns(2)
            bimpy.set_column_width(0, im_size + 20)
            bimpy.image(im)
            bimpy.next_column()

            for v, label in zip(attribute_values, labels):
                bimpy.slider_float(label, v, -40.0, 40.0)

            bimpy.checkbox("Randomize noise", randomize)

            # With the box unchecked the seed stays fixed, so every frame is
            # decoded from the same torch RNG state.
            if randomize.value:
                seed += 1

            torch.manual_seed(seed)

            if bimpy.button('Next'):
                latents, latents_original, img_src = loadNext()
                display_original = True
            if bimpy.button('Display Reconstruction'):
                display_original = False
            if bimpy.button('Generate random'):
                latents, latents_original, img_src = loadRandom()
                display_original = False

            # Typing an existing file name jumps straight to that file.
            if bimpy.input_text("Current file", current_file, 64) and os.path.exists(path + '/' + current_file.value):
                paths.insert(0, current_file.value)
                latents, latents_original, img_src = loadNext()

            bimpy.end()