Example #1
def main(cfg, logger):
    torch.cuda.set_device(0)
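    # Note: TRUNCATIOM_PSI / TRUNCATIOM_CUTOFF below reproduce the misspelled
    # key names of the upstream ALAE config, so they are left as-is.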
    model = Model(startf=cfg.MODEL.START_CHANNEL_COUNT,
                  layer_count=cfg.MODEL.LAYER_COUNT,
                  maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
                  latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
                  truncation_psi=cfg.MODEL.TRUNCATIOM_PSI,
                  truncation_cutoff=cfg.MODEL.TRUNCATIOM_CUTOFF,
                  mapping_layers=cfg.MODEL.MAPPING_LAYERS,
                  channels=cfg.MODEL.CHANNELS,
                  generator=cfg.MODEL.GENERATOR,
                  encoder=cfg.MODEL.ENCODER)

    model.cuda(0)
    model.eval()
    model.requires_grad_(False)

    decoder = model.decoder
    encoder = model.encoder
    mapping_tl = model.mapping_tl
    mapping_fl = model.mapping_fl
    dlatent_avg = model.dlatent_avg

    logger.info("Trainable parameters generator:")
    count_parameters(decoder)

    logger.info("Trainable parameters discriminator:")
    count_parameters(encoder)

    arguments = dict()
    arguments["iteration"] = 0

    model_dict = {
        'discriminator_s': encoder,
        'generator_s': decoder,
        'mapping_tl_s': mapping_tl,
        'mapping_fl_s': mapping_fl,
        'dlatent_avg': dlatent_avg
    }

    checkpointer = Checkpointer(cfg, model_dict, {}, logger=logger, save=False)

    checkpointer.load()

    model.eval()

    layer_count = cfg.MODEL.LAYER_COUNT

    logger.info("Extracting attributes")

    decoder = nn.DataParallel(decoder)

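    # Attribute indices to evaluate; Example #14 maps these same ids to labels
    # ("gender", "smile", ..., "glasses").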
    indices = [0, 1, 2, 3, 4, 10, 11, 17, 19]
    with torch.no_grad():
        p = Predictions(cfg, minibatch_gpu=4)
        for i in indices:
            p.evaluate(logger, mapping_fl, decoder,
                       cfg.DATASET.MAX_RESOLUTION_LEVEL - 2, i)
Example #2
def sample(cfg, logger):
    torch.cuda.set_device(0)
    model = Model(startf=cfg.MODEL.START_CHANNEL_COUNT,
                  layer_count=cfg.MODEL.LAYER_COUNT,
                  maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
                  latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
                  truncation_psi=cfg.MODEL.TRUNCATIOM_PSI,
                  truncation_cutoff=cfg.MODEL.TRUNCATIOM_CUTOFF,
                  mapping_layers=cfg.MODEL.MAPPING_LAYERS,
                  channels=cfg.MODEL.CHANNELS,
                  generator=cfg.MODEL.GENERATOR,
                  encoder=cfg.MODEL.ENCODER)

    model.cuda()
    model.eval()
    model.requires_grad_(False)

    decoder = model.decoder
    encoder = model.encoder
    mapping_tl = model.mapping_d
    mapping_fl = model.mapping_f
    dlatent_avg = model.dlatent_avg

    logger.info("Trainable parameters generator:")
    count_parameters(decoder)

    logger.info("Trainable parameters discriminator:")
    count_parameters(encoder)

    arguments = dict()
    arguments["iteration"] = 0

    model_dict = {
        'discriminator_s': encoder,
        'generator_s': decoder,
        'mapping_tl_s': mapping_tl,
        'mapping_fl_s': mapping_fl,
        'dlatent_avg': dlatent_avg
    }

    checkpointer = Checkpointer(cfg, model_dict, {}, logger=logger, save=False)

    checkpointer.load()

    model.eval()

    layer_count = cfg.MODEL.LAYER_COUNT

    logger.info("Generating...")

    decoder = nn.DataParallel(decoder)
    mapping_fl = nn.DataParallel(mapping_fl)

    with torch.no_grad():
        gen = ImageGenerator(cfg, num_samples=60000, minibatch_gpu=8)
        gen.evaluate(logger, mapping_fl, decoder,
                     cfg.DATASET.MAX_RESOLUTION_LEVEL - 2)
Example #3
    def create_model(self, config_path="/configs/ffhq.yaml"):
        self.cfg = get_cfg_defaults()
        self.cfg.merge_from_file(REPO_PATH + config_path)
        logger = logging.getLogger("logger")
        logger.setLevel(logging.DEBUG)
        logging.basicConfig(
            stream=sys.stdout,
            level=logging.INFO)  # print out all the info messages

        # print("cfg", cfg)

        model = Model(startf=self.cfg.MODEL.START_CHANNEL_COUNT,
                      layer_count=self.cfg.MODEL.LAYER_COUNT,
                      maxf=self.cfg.MODEL.MAX_CHANNEL_COUNT,
                      latent_size=self.cfg.MODEL.LATENT_SPACE_SIZE,
                      truncation_psi=self.cfg.MODEL.TRUNCATIOM_PSI,
                      truncation_cutoff=self.cfg.MODEL.TRUNCATIOM_CUTOFF,
                      mapping_layers=self.cfg.MODEL.MAPPING_LAYERS,
                      channels=self.cfg.MODEL.CHANNELS,
                      generator=self.cfg.MODEL.GENERATOR,
                      encoder=self.cfg.MODEL.ENCODER)

        model.cuda()
        model.eval()
        model.requires_grad_(False)

        decoder = model.decoder
        encoder = model.encoder
        mapping_tl = model.mapping_tl
        mapping_fl = model.mapping_fl
        dlatent_avg = model.dlatent_avg

        logger.info("Trainable parameters generator:")
        print(count_parameters(decoder))
        logger.info("Trainable parameters discriminator:")
        print(count_parameters(encoder))

        model_dict = {
            'discriminator_s': encoder,
            'generator_s': decoder,
            'mapping_tl_s': mapping_tl,
            'mapping_fl_s': mapping_fl,
            'dlatent_avg': dlatent_avg
        }

        checkpointer = Checkpointer(self.cfg,
                                    model_dict, {},
                                    logger=logger,
                                    save=False)

        self.checkpointer = checkpointer
        self.model = model
        self.layer_count = self.cfg.MODEL.LAYER_COUNT
        self.encoder = encoder
        self.decoder = decoder
Example #4
def sample(cfg, logger):
    model = Model(startf=cfg.MODEL.START_CHANNEL_COUNT,
                  layer_count=cfg.MODEL.LAYER_COUNT,
                  maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
                  latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
                  truncation_psi=cfg.MODEL.TRUNCATIOM_PSI,
                  truncation_cutoff=cfg.MODEL.TRUNCATIOM_CUTOFF,
                  mapping_layers=cfg.MODEL.MAPPING_LAYERS,
                  channels=3)
    del model.discriminator
    model.eval()

    logger.info("Trainable parameters generator:")
    count_parameters(model.generator)

    model_dict = {
        'generator_s': model.generator,
        'mapping_s': model.mapping,
        'dlatent_avg': model.dlatent_avg,
    }

    checkpointer = Checkpointer(cfg, model_dict, logger=logger, save=True)

    file_name = 'karras2019stylegan-ffhq'

    checkpointer.load(file_name=file_name + '.pth')

    rnd = np.random.RandomState(5)
    latents = rnd.randn(1, cfg.MODEL.LATENT_SPACE_SIZE)
    sample = torch.tensor(latents).float().cuda()

    with torch.no_grad():
        model.eval()
        images = []
        for i in range(100):
            image = model.generate(model.generator.layer_count - 1,
                                   1,
                                   z=sample)

            resultsample = (image * 0.5 + 0.5)
            images.append(resultsample)

    resultsample = torch.stack(images).mean(0)

    save_image(images[0], 'test_individual.png')
    save_image(resultsample, 'test_average.png')
Example #5
def load_ae(cfg, logger):
    torch.cuda.set_device(0)
    model = Model(startf=cfg.MODEL.START_CHANNEL_COUNT,
                  layer_count=cfg.MODEL.LAYER_COUNT,
                  maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
                  latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
                  truncation_psi=cfg.MODEL.TRUNCATIOM_PSI,
                  truncation_cutoff=cfg.MODEL.TRUNCATIOM_CUTOFF,
                  mapping_layers=cfg.MODEL.MAPPING_LAYERS,
                  channels=cfg.MODEL.CHANNELS,
                  generator=cfg.MODEL.GENERATOR,
                  encoder=cfg.MODEL.ENCODER)
    model.cuda()
    model.eval()
    model.requires_grad_(False)

    decoder = model.decoder
    encoder = model.encoder
    mapping_tl = model.mapping_tl
    mapping_fl = model.mapping_fl
    dlatent_avg = model.dlatent_avg

    logger.info("Trainable parameters generator:")
    count_parameters(decoder)

    logger.info("Trainable parameters discriminator:")
    count_parameters(encoder)

    arguments = dict()
    arguments["iteration"] = 0

    model_dict = {
        'discriminator_s': encoder,
        'generator_s': decoder,
        'mapping_tl_s': mapping_tl,
        'mapping_fl_s': mapping_fl,
        'dlatent_avg': dlatent_avg
    }

    checkpointer = Checkpointer(cfg, model_dict, {}, logger=logger, save=False)

    extra_checkpoint_data = checkpointer.load()

    model.eval()

    layer_count = cfg.MODEL.LAYER_COUNT

    path = cfg.DATASET.SAMPLES_PATH
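    # Output resolution doubles per layer: starting from 4x4, a model with
    # LAYER_COUNT layers produces images of size 2**(LAYER_COUNT + 1).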
    im_size = 2**(cfg.MODEL.LAYER_COUNT + 1)
    return model
Example #6
def sample(cfg, logger):
    torch.cuda.set_device(0)
    model = Model(
        startf=cfg.MODEL.START_CHANNEL_COUNT,
        layer_count=cfg.MODEL.LAYER_COUNT,
        maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
        latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
        truncation_psi=cfg.MODEL.TRUNCATIOM_PSI,
        truncation_cutoff=cfg.MODEL.TRUNCATIOM_CUTOFF,
        mapping_layers=cfg.MODEL.MAPPING_LAYERS,
        channels=cfg.MODEL.CHANNELS,
        generator=cfg.MODEL.GENERATOR,
        encoder=cfg.MODEL.ENCODER)
    model.cuda(0)
    model.eval()
    model.requires_grad_(False)

    decoder = model.decoder
    encoder = model.encoder
    mapping_tl = model.mapping_d
    mapping_fl = model.mapping_f
    dlatent_avg = model.dlatent_avg

    logger.info("Trainable parameters generator:")
    count_parameters(decoder)

    logger.info("Trainable parameters discriminator:")
    count_parameters(encoder)

    arguments = dict()
    arguments["iteration"] = 0

    model_dict = {
        'discriminator_s': encoder,
        'generator_s': decoder,
        'mapping_tl_s': mapping_tl,
        'mapping_fl_s': mapping_fl,
        'dlatent_avg': dlatent_avg
    }

    checkpointer = Checkpointer(cfg,
                                model_dict,
                                {},
                                logger=logger,
                                save=False)

    extra_checkpoint_data = checkpointer.load()

    model.eval()

    layer_count = cfg.MODEL.LAYER_COUNT

    def encode(x):
        Z, _ = model.encode(x, layer_count - 1, 1)
        Z = Z.repeat(1, model.mapping_f.num_layers, 1)
        return Z

    def decode(x):
        layer_idx = torch.arange(2 * cfg.MODEL.LAYER_COUNT)[np.newaxis, :, np.newaxis]
        ones = torch.ones(layer_idx.shape, dtype=torch.float32)
        coefs = torch.where(layer_idx < model.truncation_cutoff, ones, ones)
        # x = torch.lerp(model.dlatent_avg.buff.data, x, coefs)
        return model.decoder(x, layer_count - 1, 1, noise=True)

    rnd = np.random.RandomState(4)
    latents = rnd.randn(1, cfg.MODEL.LATENT_SPACE_SIZE)

    path = cfg.DATASET.SAMPLES_PATH
    im_size = 2 ** (cfg.MODEL.LAYER_COUNT + 1)

    pathA = '00001.png'
    pathB = '00022.png'
    pathC = '00077.png'
    pathD = '00016.png'

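    # Load an image, downscale it to the model resolution, and encode it to a
    # single w-space latent (row 0 of the repeated per-layer styles).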
    def open_image(filename):
        img = np.asarray(Image.open(path + '/' + filename))
        if img.shape[2] == 4:
            img = img[:, :, :3]
        im = img.transpose((2, 0, 1))
        x = torch.tensor(np.asarray(im, dtype=np.float32), device='cpu', requires_grad=True).cuda() / 127.5 - 1.
        if x.shape[0] == 4:
            x = x[:3]
        factor = x.shape[2] // im_size
        if factor != 1:
            x = torch.nn.functional.avg_pool2d(x[None, ...], factor, factor)[0]
        assert x.shape[2] == im_size
        _latents = encode(x[None, ...].cuda())
        latents = _latents[0, 0]
        return latents

    def make(w):
        with torch.no_grad():
            w = w[None, None, ...].repeat(1, model.mapping_f.num_layers, 1)
            x_rec = decode(w)
            return x_rec

    wa = open_image(pathA)
    wb = open_image(pathB)
    wc = open_image(pathC)
    wd = open_image(pathD)

    height = 7
    width = 7

    images = []

    for i in range(height):
        for j in range(width):
            kv = i / (height - 1.0)
            kh = j / (width - 1.0)

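            # Bilinear weights: each generated image blends the four corner
            # latents wa, wb, wc, wd according to its grid position (kh, kv).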
            ka = (1.0 - kh) * (1.0 - kv)
            kb = kh * (1.0 - kv)
            kc = (1.0 - kh) * kv
            kd = kh * kv

            w = ka * wa + kb * wb + kc * wc + kd * wd

            interpolated = make(w)
            images.append(interpolated)

    images = torch.cat(images)

    save_image(images * 0.5 + 0.5, 'make_figures/output/%s/interpolations.png' % cfg.NAME, nrow=width)
    save_image(images * 0.5 + 0.5, 'make_figures/output/%s/interpolations.jpg' % cfg.NAME, nrow=width)
Example #7
def sample(cfg, logger):
    torch.cuda.set_device(0)
    model = Model(
        startf=cfg.MODEL.START_CHANNEL_COUNT,
        layer_count=cfg.MODEL.LAYER_COUNT,
        maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
        latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
        truncation_psi=cfg.MODEL.TRUNCATIOM_PSI,
        truncation_cutoff=cfg.MODEL.TRUNCATIOM_CUTOFF,
        mapping_layers=cfg.MODEL.MAPPING_LAYERS,
        channels=cfg.MODEL.CHANNELS,
        generator=cfg.MODEL.GENERATOR,
        encoder=cfg.MODEL.ENCODER)
    model.cuda(0)
    model.eval()
    model.requires_grad_(False)

    decoder = model.decoder
    encoder = model.encoder
    mapping_tl = model.mapping_tl
    mapping_fl = model.mapping_fl
    dlatent_avg = model.dlatent_avg

    logger.info("Trainable parameters generator:")
    count_parameters(decoder)

    logger.info("Trainable parameters discriminator:")
    count_parameters(encoder)

    arguments = dict()
    arguments["iteration"] = 0

    model_dict = {
        'discriminator_s': encoder,
        'generator_s': decoder,
        'mapping_tl_s': mapping_tl,
        'mapping_fl_s': mapping_fl,
        'dlatent_avg': dlatent_avg
    }

    checkpointer = Checkpointer(cfg,
                                model_dict,
                                {},
                                logger=logger,
                                save=False)

    extra_checkpoint_data = checkpointer.load()

    model.eval()

    layer_count = cfg.MODEL.LAYER_COUNT

    def encode(x):
        Z, _ = model.encode(x, layer_count - 1, 1)
        Z = Z.repeat(1, model.mapping_fl.num_layers, 1)
        return Z

    def decode(x):
        layer_idx = torch.arange(2 * cfg.MODEL.LAYER_COUNT)[np.newaxis, :, np.newaxis]
        ones = torch.ones(layer_idx.shape, dtype=torch.float32)
        coefs = torch.where(layer_idx < model.truncation_cutoff, ones, ones)
        # x = torch.lerp(model.dlatent_avg.buff.data, x, coefs)
        return model.decoder(x, layer_count - 1, 1, noise=True)

    path = cfg.DATASET.SAMPLES_PATH
    im_size = 2 ** (cfg.MODEL.LAYER_COUNT + 1)

    paths = list(os.listdir(path))

    paths = sorted(paths)
    random.seed(1)
    random.shuffle(paths)

    def make(paths):
        canvas = []
        with torch.no_grad():
            for filename in paths:
                img = np.asarray(Image.open(path + '/' + filename))
                if img.shape[2] == 4:
                    img = img[:, :, :3]
                im = img.transpose((2, 0, 1))
                x = torch.tensor(np.asarray(im, dtype=np.float32), device='cpu', requires_grad=True).cuda() / 127.5 - 1.
                if x.shape[0] == 4:
                    x = x[:3]
                factor = x.shape[2] // im_size
                if factor != 1:
                    x = torch.nn.functional.avg_pool2d(x[None, ...], factor, factor)[0]
                assert x.shape[2] == im_size
                latents = encode(x[None, ...].cuda())
                f = decode(latents)
                r = torch.cat([x[None, ...].detach().cpu(), f.detach().cpu()], dim=3)
                canvas.append(r)
        return canvas

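    # Process the files in chunks of 24 source/reconstruction pairs, saved
    # 3 pairs per row (8 rows per figure).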
    def chunker_list(seq, n):
        return [seq[i * n:(i + 1) * n] for i in range((len(seq) + n - 1) // n)]

    paths = chunker_list(paths, 8 * 3)

    for i, chunk in enumerate(paths):
        canvas = make(chunk)
        canvas = torch.cat(canvas, dim=0)

        save_path = 'make_figures/output/%s/reconstructions_%d.png' % (cfg.NAME, i)
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        save_image(canvas * 0.5 + 0.5, save_path,
                   nrow=3,
                   pad_value=1.0)
Example #8
def sample(cfg, logger):
    model = Model(
        startf=cfg.MODEL.START_CHANNEL_COUNT,
        layer_count=cfg.MODEL.LAYER_COUNT,
        maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
        latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
        truncation_psi=cfg.MODEL.TRUNCATIOM_PSI,
        truncation_cutoff=cfg.MODEL.TRUNCATIOM_CUTOFF,
        mapping_layers=cfg.MODEL.MAPPING_LAYERS,
        channels=3)
    model.eval()

    logger.info("Trainable parameters generator:")
    count_parameters(model.generator)

    model_dict = {
        'generator_s': model.generator,
        'mapping_fl_s': model.mapping,
        'dlatent_avg': model.dlatent_avg,
    }

    checkpointer = Checkpointer(cfg,
                                model_dict,
                                logger=logger,
                                save=True)

    checkpointer.load()

    ctx = bimpy.Context()
    remove = bimpy.Bool(False)
    layers = bimpy.Int(8)

    ctx.init(1800, 1600, "Styles")

    rnd = np.random.RandomState(5)
    latents = rnd.randn(1, cfg.MODEL.LATENT_SPACE_SIZE)
    sample = torch.tensor(latents).float().cuda()

    def update_image(sample):
        with torch.no_grad():
            torch.manual_seed(0)
            model.eval()
            x_rec = model.generate(layers.value, remove.value, z=sample)
            #model.generator.set(l.value, c.value)
            resultsample = ((x_rec * 0.5 + 0.5) * 255).type(torch.long).clamp(0, 255)
            resultsample = resultsample.cpu()[0, :, :, :]

            return resultsample.type(torch.uint8).transpose(0, 2).transpose(0, 1)

    with torch.no_grad():
        save_image(model.generate(8, True, z=sample) * 0.5 + 0.5, 'sample.png')

    im = bimpy.Image(update_image(sample))
    while not ctx.should_close():
        with ctx:

            bimpy.set_window_font_scale(2.0)

            if bimpy.checkbox('REMOVE BLOB', remove):
                im = bimpy.Image(update_image(sample))
            if bimpy.button('NEXT'):
                latents = rnd.randn(1, cfg.MODEL.LATENT_SPACE_SIZE)
                sample = torch.tensor(latents).float().cuda()
                im = bimpy.Image(update_image(sample))
            if bimpy.slider_int("Layers", layers, 0, 8):
                im = bimpy.Image(update_image(sample))
            bimpy.image(im, bimpy.Vec2(1024, 1024))
Example #9
def sample(cfg, logger):
    torch.cuda.set_device(0)
    model = Model(startf=cfg.MODEL.START_CHANNEL_COUNT,
                  layer_count=cfg.MODEL.LAYER_COUNT,
                  maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
                  latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
                  truncation_psi=cfg.MODEL.TRUNCATIOM_PSI,
                  truncation_cutoff=cfg.MODEL.TRUNCATIOM_CUTOFF,
                  mapping_layers=cfg.MODEL.MAPPING_LAYERS,
                  channels=cfg.MODEL.CHANNELS,
                  generator=cfg.MODEL.GENERATOR,
                  encoder=cfg.MODEL.ENCODER)
    model.cuda(0)
    model.eval()
    model.requires_grad_(False)

    decoder = model.decoder
    encoder = model.encoder
    mapping_tl = model.mapping_d
    mapping_fl = model.mapping_f
    dlatent_avg = model.dlatent_avg

    logger.info("Trainable parameters generator:")
    count_parameters(decoder)

    logger.info("Trainable parameters discriminator:")
    count_parameters(encoder)

    arguments = dict()
    arguments["iteration"] = 0

    model_dict = {
        'discriminator_s': encoder,
        'generator_s': decoder,
        'mapping_tl_s': mapping_tl,
        'mapping_fl_s': mapping_fl,
        'dlatent_avg': dlatent_avg
    }

    checkpointer = Checkpointer(cfg, model_dict, {}, logger=logger, save=False)

    extra_checkpoint_data = checkpointer.load()

    model.eval()

    layer_count = cfg.MODEL.LAYER_COUNT

    def encode(x):
        Z, _ = model.encode(x, layer_count - 1, 1)
        Z = Z.repeat(1, model.mapping_f.num_layers, 1)
        return Z

    def decode(x):
        layer_idx = torch.arange(2 * layer_count)[np.newaxis, :, np.newaxis]
        ones = torch.ones(layer_idx.shape, dtype=torch.float32)
        coefs = torch.where(layer_idx < model.truncation_cutoff, ones, ones)
        # x = torch.lerp(model.dlatent_avg.buff.data, x, coefs)
        return model.decoder(x, layer_count - 1, 1, noise=True)

    path = 'dataset_samples/faces/realign1024x1024'

    paths = list(os.listdir(path))
    paths.sort()
    paths_backup = paths[:]
    randomize = bimpy.Bool(True)
    current_file = bimpy.String("")

    ctx = bimpy.Context()

    # indices/labels are not defined in this snippet; they mirror the
    # attribute list used in Example #14.
    indices = [0, 1, 2, 3, 4, 10, 11, 17, 19]
    labels = ["gender", "smile", "attractive", "wavy-hair", "young",
              "big_lips", "big_nose", "chubby", "glasses"]
    attribute_values = [bimpy.Float(0) for i in indices]

    W = [
        torch.tensor(np.load("principal_directions/direction_%d.npy" % i),
                     dtype=torch.float32) for i in indices
    ]

    rnd = np.random.RandomState(5)

    def loadNext():
        img = np.asarray(Image.open(path + '/' + paths[0]))
        current_file.value = paths[0]
        paths.pop(0)
        if len(paths) == 0:
            paths.extend(paths_backup)

        if img.shape[2] == 4:
            img = img[:, :, :3]
        im = img.transpose((2, 0, 1))
        x = torch.tensor(np.asarray(im, dtype=np.float32),
                         device='cpu',
                         requires_grad=True).cuda() / 127.5 - 1.
        if x.shape[0] == 4:
            x = x[:3]

        needed_resolution = model.decoder.layer_to_resolution[-1]
        while x.shape[2] > needed_resolution:
            x = F.avg_pool2d(x, 2, 2)
        if x.shape[2] != needed_resolution:
            x = F.adaptive_avg_pool2d(x,
                                      (needed_resolution, needed_resolution))

        img_src = ((x * 0.5 + 0.5) * 255).type(torch.long).clamp(
            0, 255).cpu().type(torch.uint8).transpose(0,
                                                      2).transpose(0,
                                                                   1).numpy()

        latents_original = encode(x[None, ...].cuda())
        latents = latents_original[0, 0].clone()
        latents -= model.dlatent_avg.buff.data[0]

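        # Project the latent onto each principal direction to initialise the
        # sliders, then subtract those components so the sliders re-add them.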
        for v, w in zip(attribute_values, W):
            v.value = (latents * w).sum()

        for v, w in zip(attribute_values, W):
            latents = latents - v.value * w

        return latents, latents_original, img_src

    def loadRandom():
        latents = rnd.randn(1, cfg.MODEL.LATENT_SPACE_SIZE)
        lat = torch.tensor(latents).float().cuda()
        dlat = mapping_fl(lat)
        layer_idx = torch.arange(2 * layer_count)[np.newaxis, :, np.newaxis]
        ones = torch.ones(layer_idx.shape, dtype=torch.float32)
        coefs = torch.where(layer_idx < model.truncation_cutoff, ones, ones)
        dlat = torch.lerp(model.dlatent_avg.buff.data, dlat, coefs)
        x = decode(dlat)[0]
        img_src = ((x * 0.5 + 0.5) * 255).type(torch.long).clamp(
            0, 255).cpu().type(torch.uint8).transpose(0,
                                                      2).transpose(0,
                                                                   1).numpy()
        latents_original = dlat
        latents = latents_original[0, 0].clone()
        latents -= model.dlatent_avg.buff.data[0]

        for v, w in zip(attribute_values, W):
            v.value = (latents * w).sum()

        for v, w in zip(attribute_values, W):
            latents = latents - v.value * w

        return latents, latents_original, img_src

    latents, latents_original, img_src = loadNext()

    ctx.init(1800, 1600, "Styles")

    def update_image(w, latents_original):
        with torch.no_grad():
            w = w + model.dlatent_avg.buff.data[0]
            w = w[None, None, ...].repeat(1, model.mapping_f.num_layers, 1)

            layer_idx = torch.arange(model.mapping_f.num_layers)[np.newaxis, :,
                                                                 np.newaxis]
            cur_layers = (7 + 1) * 2
            mixing_cutoff = cur_layers
            styles = torch.where(layer_idx < mixing_cutoff, w,
                                 latents_original)

            x_rec = decode(styles)
            resultsample = ((x_rec * 0.5 + 0.5) * 255).type(torch.long).clamp(
                0, 255)
            resultsample = resultsample.cpu()[0, :, :, :]
            return resultsample.type(torch.uint8).transpose(0,
                                                            2).transpose(0, 1)

    im_size = 2**(cfg.MODEL.LAYER_COUNT + 1)
    im = update_image(latents, latents_original)
    print(im.shape)
    im = bimpy.Image(im)

    display_original = True

    seed = 0

    while not ctx.should_close():
        with ctx:
            new_latents = latents + sum(
                [v.value * w for v, w in zip(attribute_values, W)])

            if display_original:
                im = bimpy.Image(img_src)
            else:
                im = bimpy.Image(update_image(new_latents, latents_original))

            bimpy.begin("Principal directions")
            bimpy.columns(2)
            bimpy.set_column_width(0, im_size + 20)
            bimpy.image(im)
            bimpy.next_column()

            for v, label in zip(attribute_values, labels):
                bimpy.slider_float(label, v, -40.0, 40.0)

            bimpy.checkbox("Randomize noise", randomize)

            if randomize.value:
                seed += 1

            torch.manual_seed(seed)

            if bimpy.button('Next'):
                latents, latents_original, img_src = loadNext()
                display_original = True
            if bimpy.button('Display Reconstruction'):
                display_original = False
            if bimpy.button('Generate random'):
                latents, latents_original, img_src = loadRandom()
                display_original = False

            if bimpy.input_text(
                    "Current file", current_file,
                    64) and os.path.exists(path + '/' + current_file.value):
                paths.insert(0, current_file.value)
                latents, latents_original, img_src = loadNext()

            bimpy.end()
Example #10
def train(cfg, logger, local_rank, world_size, folding_id=0, inliner_classes=None):
    torch.cuda.set_device(local_rank)
    model = Model(
        startf=cfg.MODEL.START_CHANNEL_COUNT,
        layer_count=cfg.MODEL.LAYER_COUNT,
        maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
        latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
        channels=cfg.MODEL.INPUT_IMAGE_CHANNELS,
        generator=cfg.MODEL.GENERATOR,
        encoder=cfg.MODEL.ENCODER,
    )
    model.cuda(local_rank)
    model.train()

    model_s = Model(
        startf=cfg.MODEL.START_CHANNEL_COUNT,
        layer_count=cfg.MODEL.LAYER_COUNT,
        maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
        latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
        channels=cfg.MODEL.INPUT_IMAGE_CHANNELS,
        generator=cfg.MODEL.GENERATOR,
        encoder=cfg.MODEL.ENCODER,
    )
    model_s.cuda(local_rank)
    model_s.eval()
    model_s.requires_grad_(False)

    generator = model.generator
    encoder = model.encoder
    discriminator = model.discriminator
    z_discriminator = model.z_discriminator

    count_param_override.print = lambda a: logger.info(a)

    logger.info("Trainable parameters generator:")
    count_parameters(generator)

    logger.info("Trainable parameters discriminator:")
    count_parameters(encoder)

    arguments = dict()
    arguments["iteration"] = 0

    generator_optimizer = LREQAdam([
        {'params': generator.parameters()},
    ], lr=cfg.TRAIN.BASE_LEARNING_RATE, betas=(cfg.TRAIN.ADAM_BETA_0, cfg.TRAIN.ADAM_BETA_1), weight_decay=0)

    z_discriminator_optimizer = LREQAdam([
        {'params': z_discriminator.parameters()},
    ], lr=cfg.TRAIN.BASE_LEARNING_RATE, betas=(cfg.TRAIN.ADAM_BETA_0, cfg.TRAIN.ADAM_BETA_1), weight_decay=0)

    encoder_optimizer = LREQAdam([
        {'params': encoder.parameters()},
        {'params': discriminator.parameters()},
    ], lr=cfg.TRAIN.BASE_LEARNING_RATE, betas=(cfg.TRAIN.ADAM_BETA_0, cfg.TRAIN.ADAM_BETA_1), weight_decay=0)

    scheduler = ComboMultiStepLR(optimizers=
                                 {
                                    'encoder_optimizer': encoder_optimizer,
                                    'generator_optimizer': generator_optimizer,
                                    'z_discriminator_optimizer': z_discriminator_optimizer
                                 },
                                 milestones=cfg.TRAIN.LEARNING_DECAY_STEPS,
                                 gamma=cfg.TRAIN.LEARNING_DECAY_RATE,
                                 reference_batch_size=32, base_lr=cfg.TRAIN.LEARNING_RATES)

    model_dict = {
        'encoder': encoder,
        'generator': generator,
        'discriminator': discriminator,
        'z_discriminator': z_discriminator,
        'encoder_s': model_s.encoder,
        'generator_s': model_s.generator,
        'discriminator_s': model_s.discriminator,
        'z_discriminator_s': model_s.z_discriminator,
    }

    output_folder = 'results_' + str(folding_id) + "_" + "_".join(str(x) for x in inliner_classes)
    output_folder = os.path.join(cfg.OUTPUT_DIR, output_folder)
    os.makedirs(output_folder, exist_ok=True)

    tracker = LossTracker(output_folder)

    checkpointer = Checkpointer(output_folder,
                                model_dict,
                                {
                                    'encoder_optimizer': encoder_optimizer,
                                    'decoder_optimizer': generator_optimizer,
                                    'scheduler': scheduler,
                                    'tracker': tracker
                                },
                                logger=logger,
                                save=True)

    extra_checkpoint_data = checkpointer.load()
    save_file = os.path.join(checkpointer.folder, "last_checkpoint")
    try:
        with open(save_file, "r") as last_checkpoint:
            f = last_checkpoint.read().strip()
            f = os.path.basename(f)
            checkpointer.tag_last_checkpoint(f)
        extra_checkpoint_data = checkpointer.load()
    except FileNotFoundError:
        pass

    logger.info("Starting from epoch: %d" % (scheduler.start_epoch()))

    arguments.update(extra_checkpoint_data)

    layer_to_resolution = generator.layer_to_resolution

    train_set, _, _ = make_datasets(cfg, logger, folding_id, inliner_classes)

    rnd = np.random.RandomState(3456)
    latents = rnd.randn(32, cfg.MODEL.LATENT_SPACE_SIZE)
    samplez = torch.tensor(latents).float().cuda()

    lod2batch = driver.Driver(cfg, logger, world_size, dataset_size=len(train_set))

    sample = next(make_dataloader(train_set, cfg.TRAIN.BATCH_1GPU, torch.cuda.current_device()))
    sample = sample[1]
    sample = sample.view(-1, cfg.MODEL.INPUT_IMAGE_CHANNELS, cfg.MODEL.INPUT_IMAGE_SIZE, cfg.MODEL.INPUT_IMAGE_SIZE)
    # sample = (sample / 127.5 - 1.)

    lod2batch.set_epoch(scheduler.start_epoch(), [encoder_optimizer, generator_optimizer])

    scores_list = []

    try:
        with open(os.path.join(output_folder, "scores.txt"), "r") as f:
            lines = f.readlines()
            lines = [l.strip() for l in lines]
            lines = [l.split(' ') for l in lines]
            lines = [l for l in lines if len(l) == 2]
            scores_list = [(x[0], float(x[1])) for x in lines]
            # for l in scores_list:
            #     print("%s: %f" % l)
    except FileNotFoundError:
        pass

    def save(epoch):
        score = eval_model_on_valid(cfg, logger, model_s, folding_id, inliner_classes)
        filename = "model_%d" % epoch
        checkpointer.save(filename).wait()
        scores_list.append((filename, score))
        with open(os.path.join(output_folder, "scores.txt"), "w") as f:
            f.writelines([x[0] + " " + str(x[1]) + "\n" for x in scores_list])

    def last_score():
        return 0 if len(scores_list) == 0 else scores_list[-1][1]

    epoch = None
    for epoch in range(scheduler.start_epoch(), cfg.TRAIN.TRAIN_EPOCHS):
        model.train()
        lod2batch.set_epoch(epoch, [encoder_optimizer, generator_optimizer])

        logger.info("Batch size: %d, Batch size per GPU: %d, dataset size: %d" % (
                                                                lod2batch.get_batch_size(),
                                                                lod2batch.get_per_GPU_batch_size(),
                                                                len(train_set) * world_size))

        data_loader = make_dataloader(train_set, lod2batch.get_per_GPU_batch_size(), torch.cuda.current_device())
        train_set.shuffle()

        scheduler.set_batch_size(lod2batch.get_batch_size())

        model.train()

        epoch_start_time = time.time()

        i = 0
        for y, x in data_loader:
            x = x.view(x.shape[0], cfg.MODEL.INPUT_IMAGE_CHANNELS, cfg.MODEL.INPUT_IMAGE_SIZE, cfg.MODEL.INPUT_IMAGE_SIZE)
            i += 1
            with torch.no_grad():
                if x.shape[0] != lod2batch.get_per_GPU_batch_size():
                    continue
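            # Three updates per batch: (1) encoder/discriminator on the
            # adversarial losses, (2) generator and z-discriminator, then
            # (3) encoder and generator jointly on the autoencoder loss (lae).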
            encoder_optimizer.zero_grad()
            loss_d, loss_zg = model(x, d_train=True, ae=False)
            tracker.update(dict(loss_d=loss_d, loss_zg=loss_zg))
            (loss_zg + loss_d).backward()
            encoder_optimizer.step()

            generator_optimizer.zero_grad()
            z_discriminator_optimizer.zero_grad()
            loss_g, loss_zd = model(x, d_train=False, ae=False)
            tracker.update(dict(loss_g=loss_g, loss_zd=loss_zd))
            (loss_g + loss_zd).backward()
            generator_optimizer.step()
            z_discriminator_optimizer.step()

            encoder_optimizer.zero_grad()
            generator_optimizer.zero_grad()
            lae = model(x, d_train=True, ae=True)
            tracker.update(dict(lae=lae))
            (lae).backward()
            encoder_optimizer.step()
            generator_optimizer.step()

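            # Keep model_s as an exponential moving average of model's
            # weights; the decay gives the average a half-life of ~1000 images.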
            betta = 0.5 ** (lod2batch.get_batch_size() / (1000.0))
            model_s.lerp(model, betta)

            epoch_end_time = time.time()
            per_epoch_ptime = epoch_end_time - epoch_start_time

            # tracker.update(dict(score_a=score_a, score_b=score_b, score_c=score_c))
            tracker.update(dict(score=last_score()))

            lod2batch.step()
            # if lod2batch.is_time_to_save():
            #     checkpointer.save("model_tmp_intermediate_lod%d" % lod_for_saving_model)
            if lod2batch.is_time_to_report():
                save_sample(lod2batch, tracker, sample, samplez, x, logger, model_s, cfg, encoder_optimizer,
                            generator_optimizer, output_folder)

        scheduler.step()

        if epoch % 20 == 0:
            save(epoch)

        save_sample(lod2batch, tracker, sample, samplez, x, logger, model_s, cfg, encoder_optimizer, generator_optimizer, output_folder)

    logger.info("Training finish!... save training results")
    if epoch is not None:
        save(epoch)

    best_model_name, best_model_score = scores_list[0]
    for model_name, model_score in scores_list:
        if model_score >= best_model_score:
            best_model_name, best_model_score = model_name, model_score

    checkpointer.tag_best_checkpoint(best_model_name)
Example #11
              channels=cfg.MODEL.CHANNELS,
              generator=cfg.MODEL.GENERATOR,
              encoder=cfg.MODEL.ENCODER)

model.cuda()
model.eval()
model.requires_grad_(False)

decoder = model.decoder
encoder = model.encoder
mapping_tl = model.mapping_tl
mapping_fl = model.mapping_fl
dlatent_avg = model.dlatent_avg

logger.info("Trainable parameters generator:")
print(count_parameters(decoder))
logger.info("Trainable parameters discriminator:")
print(count_parameters(encoder))

model_dict = {
    'discriminator_s': encoder,
    'generator_s': decoder,
    'mapping_tl_s': mapping_tl,
    'mapping_fl_s': mapping_fl,
    'dlatent_avg': dlatent_avg
}

### LOAD
checkpointer = Checkpointer(cfg, model_dict, {}, logger=logger, save=False)

Example #12
def sample(cfg, logger):
    torch.cuda.set_device(0)
    model = Model(startf=cfg.MODEL.START_CHANNEL_COUNT,
                  layer_count=cfg.MODEL.LAYER_COUNT,
                  maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
                  latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
                  truncation_psi=cfg.MODEL.TRUNCATIOM_PSI,
                  truncation_cutoff=cfg.MODEL.TRUNCATIOM_CUTOFF,
                  style_mixing_prob=cfg.MODEL.STYLE_MIXING_PROB,
                  mapping_layers=cfg.MODEL.MAPPING_LAYERS,
                  channels=cfg.MODEL.CHANNELS,
                  generator=cfg.MODEL.GENERATOR,
                  encoder=cfg.MODEL.ENCODER)

    model.cuda(0)
    model.eval()
    model.requires_grad_(False)

    decoder = model.decoder
    encoder = model.encoder
    mapping_tl = model.mapping_tl
    mapping_fl = model.mapping_fl

    dlatent_avg = model.dlatent_avg

    logger.info("Trainable parameters generator:")
    count_parameters(decoder)

    logger.info("Trainable parameters discriminator:")
    count_parameters(encoder)

    arguments = dict()
    arguments["iteration"] = 0

    model_dict = {
        'discriminator_s': encoder,
        'generator_s': decoder,
        'mapping_tl_s': mapping_tl,
        'mapping_fl_s': mapping_fl,
        'dlatent_avg': dlatent_avg
    }

    checkpointer = Checkpointer(cfg, model_dict, {}, logger=logger, save=False)

    checkpointer.load()

    model.eval()

    layer_count = cfg.MODEL.LAYER_COUNT

    decoder = nn.DataParallel(decoder)

    im_size = 2**(cfg.MODEL.LAYER_COUNT + 1)
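    # Assumed layout: rows and lods control the uncurated-figure grid, each
    # lods entry giving a downscale level for its block (0 = full resolution).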
    with torch.no_grad():
        draw_uncurated_result_figure(cfg,
                                     'make_figures/output/%s/generations.jpg' %
                                     cfg.NAME,
                                     model,
                                     cx=0,
                                     cy=0,
                                     cw=im_size,
                                     ch=im_size,
                                     rows=6,
                                     lods=[0, 0, 0, 1, 1, 2],
                                     seed=5)
Example #13
def sample(cfg, logger):
    torch.cuda.set_device(0)
    model = Model(startf=cfg.MODEL.START_CHANNEL_COUNT,
                  layer_count=cfg.MODEL.LAYER_COUNT,
                  maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
                  latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
                  truncation_psi=cfg.MODEL.TRUNCATIOM_PSI,
                  truncation_cutoff=cfg.MODEL.TRUNCATIOM_CUTOFF,
                  mapping_layers=cfg.MODEL.MAPPING_LAYERS,
                  channels=cfg.MODEL.CHANNELS,
                  generator=cfg.MODEL.GENERATOR,
                  encoder=cfg.MODEL.ENCODER)

    model.cuda(0)
    model.eval()
    model.requires_grad_(False)

    decoder = model.decoder
    encoder = model.encoder
    mapping_tl = model.mapping_tl
    mapping_fl = model.mapping_fl
    dlatent_avg = model.dlatent_avg

    logger.info("Trainable parameters generator:")
    count_parameters(decoder)

    logger.info("Trainable parameters discriminator:")
    count_parameters(encoder)

    arguments = dict()
    arguments["iteration"] = 0

    model_dict = {
        'discriminator_s': encoder,
        'generator_s': decoder,
        'mapping_tl_s': mapping_tl,
        'mapping_fl_s': mapping_fl,
        'dlatent_avg': dlatent_avg
    }

    checkpointer = Checkpointer(cfg, model_dict, {}, logger=logger, save=False)

    checkpointer.load()

    model.eval()

    layer_count = cfg.MODEL.LAYER_COUNT

    def encode(x):
        Z, _ = model.encode(x, layer_count - 1, 1)
        Z = Z.repeat(1, model.mapping_fl.num_layers, 1)
        return Z

    def decode(x):
        layer_idx = torch.arange(2 * cfg.MODEL.LAYER_COUNT)[np.newaxis, :,
                                                            np.newaxis]
        ones = torch.ones(layer_idx.shape, dtype=torch.float32)
        coefs = torch.where(layer_idx < model.truncation_cutoff, 1.2 * ones,
                            ones)
        x = torch.lerp(model.dlatent_avg.buff.data, x, coefs)
        return model.decoder(x, layer_count - 1, 1, noise=True)

    logger.info("Evaluating PPL metric")

    decoder = nn.DataParallel(decoder)

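    # PPL is computed twice: sampling='full' measures perceptual change along
    # the whole interpolation path in W, sampling='end' only at its endpoints.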
    with torch.no_grad():
        ppl = PPL(cfg,
                  num_samples=50000,
                  epsilon=1e-4,
                  space='w',
                  sampling='full',
                  minibatch_size=16 * torch.cuda.device_count())
        ppl.evaluate(logger,
                     mapping_fl,
                     decoder,
                     cfg.DATASET.MAX_RESOLUTION_LEVEL - 2,
                     celeba_style=cfg.PPL_CELEBA_ADJUSTMENT)

    with torch.no_grad():
        ppl = PPL(cfg,
                  num_samples=50000,
                  epsilon=1e-4,
                  space='w',
                  sampling='end',
                  minibatch_size=16 * torch.cuda.device_count())
        ppl.evaluate(logger,
                     mapping_fl,
                     decoder,
                     cfg.DATASET.MAX_RESOLUTION_LEVEL - 2,
                     celeba_style=cfg.PPL_CELEBA_ADJUSTMENT)
Example #14
def sample(cfg, logger):
    torch.cuda.set_device(0)
    model = Model(startf=cfg.MODEL.START_CHANNEL_COUNT,
                  layer_count=cfg.MODEL.LAYER_COUNT,
                  maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
                  latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
                  truncation_psi=cfg.MODEL.TRUNCATIOM_PSI,
                  truncation_cutoff=cfg.MODEL.TRUNCATIOM_CUTOFF,
                  mapping_layers=cfg.MODEL.MAPPING_LAYERS,
                  channels=cfg.MODEL.CHANNELS,
                  generator=cfg.MODEL.GENERATOR,
                  encoder=cfg.MODEL.ENCODER)
    model.cuda(0)
    model.eval()
    model.requires_grad_(False)

    decoder = model.decoder
    encoder = model.encoder
    mapping_tl = model.mapping_tl
    mapping_fl = model.mapping_fl
    dlatent_avg = model.dlatent_avg

    logger.info("Trainable parameters generator:")
    count_parameters(decoder)

    logger.info("Trainable parameters discriminator:")
    count_parameters(encoder)

    arguments = dict()
    arguments["iteration"] = 0

    model_dict = {
        'discriminator_s': encoder,
        'generator_s': decoder,
        'mapping_tl_s': mapping_tl,
        'mapping_fl_s': mapping_fl,
        'dlatent_avg': dlatent_avg
    }

    checkpointer = Checkpointer(cfg, model_dict, {}, logger=logger, save=False)

    extra_checkpoint_data = checkpointer.load()

    model.eval()

    layer_count = cfg.MODEL.LAYER_COUNT

    def encode(x):
        Z, _ = model.encode(x, layer_count - 1, 1)
        Z = Z.repeat(1, model.mapping_fl.num_layers, 1)
        return Z

    def decode(x):
        layer_idx = torch.arange(2 * cfg.MODEL.LAYER_COUNT)[np.newaxis, :,
                                                            np.newaxis]
        ones = torch.ones(layer_idx.shape, dtype=torch.float32)
        coefs = torch.where(layer_idx < model.truncation_cutoff, ones, ones)
        # x = torch.lerp(model.dlatent_avg.buff.data, x, coefs)
        return model.decoder(x, layer_count - 1, 1, noise=True)

    path = cfg.DATASET.SAMPLES_PATH
    im_size = 2**(cfg.MODEL.LAYER_COUNT + 1)

    def do_attribute_traversal(path, attrib_idx, start, end):
        img = np.asarray(Image.open(path))
        if img.shape[2] == 4:
            img = img[:, :, :3]
        im = img.transpose((2, 0, 1))
        x = torch.tensor(np.asarray(im, dtype=np.float32),
                         device='cpu',
                         requires_grad=True).cuda() / 127.5 - 1.
        if x.shape[0] == 4:
            x = x[:3]
        factor = x.shape[2] // im_size
        if factor != 1:
            x = torch.nn.functional.avg_pool2d(x[None, ...], factor, factor)[0]
        assert x.shape[2] == im_size
        _latents = encode(x[None, ...].cuda())
        latents = _latents[0, 0]

        latents -= model.dlatent_avg.buff.data[0]

        w0 = torch.tensor(np.load("principal_directions/direction_%d.npy" %
                                  attrib_idx),
                          dtype=torch.float32)

        attr0 = (latents * w0).sum()

        latents = latents - attr0 * w0

        def update_image(w):
            with torch.no_grad():
                w = w + model.dlatent_avg.buff.data[0]
                w = w[None, None, ...].repeat(1, model.mapping_fl.num_layers,
                                              1)

                layer_idx = torch.arange(
                    model.mapping_fl.num_layers)[np.newaxis, :, np.newaxis]
                cur_layers = (7 + 1) * 2
                mixing_cutoff = cur_layers
                styles = torch.where(layer_idx < mixing_cutoff, w, _latents[0])

                x_rec = decode(styles)
                return x_rec

        traversal = []

        r = 7
        inc = (end - start) / (r - 1)

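        # Sweep the attribute coefficient in r steps from (attr0 + start) to
        # (attr0 + end), decoding one image per step.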
        for i in range(r):
            W = latents + w0 * (attr0 + start)
            im = update_image(W)

            traversal.append(im)
            attr0 += inc
        res = torch.cat(traversal)

        indices = [0, 1, 2, 3, 4, 10, 11, 17, 19]
        labels = [
            "gender",
            "smile",
            "attractive",
            "wavy-hair",
            "young",
            "big_lips",
            "big_nose",
            "chubby",
            "glasses",
        ]
        save_image(res * 0.5 + 0.5,
                   "make_figures/output/%s/traversal_%s.jpg" %
                   (cfg.NAME, labels[indices.index(attrib_idx)]),
                   pad_value=1)

    do_attribute_traversal(path + '/00049.png', 0, 0.6, -34)
    do_attribute_traversal(path + '/00125.png', 1, -3, 15.0)
    do_attribute_traversal(path + '/00057.png', 3, -2, 30.0)
    do_attribute_traversal(path + '/00031.png', 4, -10, 30.0)
    do_attribute_traversal(path + '/00088.png', 10, -0.3, 30.0)
    do_attribute_traversal(path + '/00004.png', 11, -25, 20.0)
    do_attribute_traversal(path + '/00012.png', 17, -40, 40.0)
    do_attribute_traversal(path + '/00017.png', 19, 0, 30.0)
Example #15
def sample(cfg, logger):
    torch.cuda.set_device(0)
    model = Model(startf=cfg.MODEL.START_CHANNEL_COUNT,
                  layer_count=cfg.MODEL.LAYER_COUNT,
                  maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
                  latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
                  truncation_psi=cfg.MODEL.TRUNCATIOM_PSI,
                  truncation_cutoff=cfg.MODEL.TRUNCATIOM_CUTOFF,
                  mapping_layers=cfg.MODEL.MAPPING_LAYERS,
                  channels=cfg.MODEL.CHANNELS,
                  generator=cfg.MODEL.GENERATOR,
                  encoder=cfg.MODEL.ENCODER)
    model.cuda(0)
    model.eval()
    model.requires_grad_(False)

    decoder = model.decoder
    encoder = model.encoder
    mapping_tl = model.mapping_tl
    mapping_fl = model.mapping_fl
    dlatent_avg = model.dlatent_avg

    logger.info("Trainable parameters generator:")
    count_parameters(decoder)

    logger.info("Trainable parameters discriminator:")
    count_parameters(encoder)

    arguments = dict()
    arguments["iteration"] = 0

    model_dict = {
        'discriminator_s': encoder,
        'generator_s': decoder,
        'mapping_tl_s': mapping_tl,
        'mapping_fl_s': mapping_fl,
        'dlatent_avg': dlatent_avg
    }

    checkpointer = Checkpointer(cfg, model_dict, {}, logger=logger, save=False)

    extra_checkpoint_data = checkpointer.load()

    model.eval()

    layer_count = cfg.MODEL.LAYER_COUNT

    def encode(x):
        Z, _ = model.encode(x, layer_count - 1, 1)
        Z = Z.repeat(1, model.mapping_fl.num_layers, 1)
        return Z

    def decode(x):
        layer_idx = torch.arange(2 * cfg.MODEL.LAYER_COUNT)[np.newaxis, :,
                                                            np.newaxis]
        ones = torch.ones(layer_idx.shape, dtype=torch.float32)
        coefs = torch.where(layer_idx < model.truncation_cutoff, 1.0 * ones,
                            ones)
        # x = torch.lerp(model.dlatent_avg.buff.data, x, coefs)
        return model.decoder(x, layer_count - 1, 1, noise=True)

    path = cfg.DATASET.SAMPLES_PATH
    # path = 'dataset_samples/faces/realign1024x1024_paper'

    im_size = 2**(cfg.MODEL.LAYER_COUNT + 1)

    paths = list(os.listdir(path))

    paths = sorted(paths)
    random.seed(5)
    random.shuffle(paths)

    def move_to(lst, item, new_index):
        lst.remove(item)
        lst.insert(new_index, item)

    # move_to(paths, '00026.png', 0)
    # move_to(paths, '00074.png', 1)
    # move_to(paths, '00134.png', 2)
    # move_to(paths, '00036.png', 3)

    def make(paths):
        src = []
        for filename in paths:
            img = np.asarray(Image.open(path + '/' + filename))
            if img.shape[2] == 4:
                img = img[:, :, :3]
            im = img.transpose((2, 0, 1))
            x = torch.tensor(np.asarray(im, dtype=np.float32),
                             requires_grad=True).cuda() / 127.5 - 1.
            if x.shape[0] == 4:
                x = x[:3]
            factor = x.shape[2] // im_size
            if factor != 1:
                x = torch.nn.functional.avg_pool2d(x[None, ...], factor,
                                                   factor)[0]
            assert x.shape[2] == im_size
            src.append(x)

        with torch.no_grad():
            reconstructions = []
            for s in src:
                latents = encode(s[None, ...])
                reconstructions.append(decode(latents).cpu().detach().numpy())
        return src, reconstructions

    def chunker_list(seq, size):
        return list((seq[i::size] for i in range(size)))

    final = chunker_list(paths, 4)
    path0, path1, path2, path3 = final

    path0.reverse()
    path1.reverse()
    path2.reverse()
    path3.reverse()

    src0, rec0 = make(path0)
    src1, rec1 = make(path1)
    src2, rec2 = make(path2)
    src3, rec3 = make(path3)

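    # Lay out the multi-resolution grid: each LOD level down halves the tile
    # size, while make_part shrinks the padding by padding_step per level.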
    initial_resolution = im_size

    lods_down = 1
    padding_step = 4

    width = 0
    height = 0

    current_padding = 0

    final_resolution = initial_resolution
    for _ in range(lods_down):
        final_resolution /= 2

    for i in range(lods_down + 1):
        width += current_padding * 2**(lods_down - i)
        height += current_padding * 2**(lods_down - i)
        current_padding += padding_step

    width += 2**(lods_down + 1) * final_resolution
    height += (lods_down + 1) * initial_resolution

    width = int(width)
    height = int(height)

    def make_part(current_padding, src, rec):
        canvas = np.ones([3, height + 20, width + 10])

        padd = 0

        initial_padding = current_padding

        height_padding = 0

        for i in range(lods_down + 1):
            for x in range(2**i):
                for y in range(2**i):
                    try:
                        ims = src.pop()
                        imr = rec.pop()[0]
                        ims = ims.cpu().detach().numpy()

                        res = int(initial_resolution / 2**i)

                        ims = resize(ims, (3, res, res))
                        imr = resize(imr, (3, res, res))

                        place(
                            canvas, ims,
                            current_padding + x * (2 * res + current_padding),
                            i * initial_resolution + height_padding + y *
                            (res + current_padding))

                        place(
                            canvas, imr, current_padding + res + x *
                            (2 * res + current_padding),
                            i * initial_resolution + height_padding + y *
                            (res + current_padding))

                    except IndexError:
                        return canvas

            height_padding += initial_padding * 2

            current_padding -= padding_step
            padd += padding_step
        return canvas

    canvas = [
        make_part(current_padding, src0, rec0),
        make_part(current_padding, src1, rec1),
        make_part(current_padding, src2, rec2),
        make_part(current_padding, src3, rec3)
    ]

    canvas = np.concatenate(canvas, axis=2)

    print('Saving image')
    save_path = 'make_figures/output/%s/reconstructions_multiresolution.png' % cfg.NAME
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    save_image(torch.Tensor(canvas), save_path)
Example #16
def sample(cfg, logger):
    torch.cuda.set_device(0)
    model = Model(
        startf=cfg.MODEL.START_CHANNEL_COUNT,
        layer_count=cfg.MODEL.LAYER_COUNT,
        maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
        latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
        truncation_psi=cfg.MODEL.TRUNCATIOM_PSI,
        truncation_cutoff=cfg.MODEL.TRUNCATIOM_CUTOFF,
        style_mixing_prob=cfg.MODEL.STYLE_MIXING_PROB,
        mapping_layers=cfg.MODEL.MAPPING_LAYERS,
        channels=cfg.MODEL.CHANNELS,
        generator=cfg.MODEL.GENERATOR,
        encoder=cfg.MODEL.ENCODER)

    model.cuda(0)
    model.eval()
    model.requires_grad_(False)

    decoder = model.decoder
    encoder = model.encoder
    mapping_tl = model.mapping_d
    mapping_fl = model.mapping_f

    dlatent_avg = model.dlatent_avg

    logger.info("Trainable parameters generator:")
    count_parameters(decoder)

    logger.info("Trainable parameters discriminator:")
    count_parameters(encoder)

    arguments = dict()
    arguments["iteration"] = 0

    model_dict = {
        'discriminator_s': encoder,
        'generator_s': decoder,
        'mapping_tl_s': mapping_tl,
        'mapping_fl_s': mapping_fl,
        'dlatent_avg': dlatent_avg
    }

    checkpointer = Checkpointer(cfg,
                                model_dict,
                                {},
                                logger=logger,
                                save=False)

    checkpointer.load()

    model.eval()

    layer_count = cfg.MODEL.LAYER_COUNT

    decoder = nn.DataParallel(decoder)

    im_size = 2 ** (cfg.MODEL.LAYER_COUNT + 1)
    with torch.no_grad():
        rnd = np.random.RandomState(0)
        latents = rnd.randn(1, 512)
        samplez = torch.tensor(latents).float().cuda()
        # Generate from z, re-encode the image, then decode the recovered
        # styles (the original snippet was mis-indented and decoded an
        # unassigned variable).
        image = model.generate(8, 1, samplez, 1, mixing=True)
        ls, _ = model.encode(image, 8, 1)
        ls = ls.repeat(1, model.mapping_f.num_layers, 1)
        x = model.decoder(ls, 8, 1, noise=True)
        save_image(image * 0.5 + 0.5, '1.png')
        save_image(x * 0.5 + 0.5, '2.png')
Example #17
def sample(cfg, logger):
    torch.cuda.set_device(0)
    model = Model(startf=cfg.MODEL.START_CHANNEL_COUNT,
                  layer_count=cfg.MODEL.LAYER_COUNT,
                  maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
                  latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
                  truncation_psi=cfg.MODEL.TRUNCATIOM_PSI,
                  truncation_cutoff=cfg.MODEL.TRUNCATIOM_CUTOFF,
                  mapping_layers=cfg.MODEL.MAPPING_LAYERS,
                  channels=cfg.MODEL.CHANNELS,
                  generator=cfg.MODEL.GENERATOR,
                  encoder=cfg.MODEL.ENCODER)
    model.cuda(0)
    model.eval()
    model.requires_grad_(False)

    decoder = model.decoder
    encoder = model.encoder
    mapping_tl = model.mapping_tl
    mapping_fl = model.mapping_fl
    dlatent_avg = model.dlatent_avg

    logger.info("Trainable parameters generator:")
    count_parameters(decoder)

    logger.info("Trainable parameters discriminator:")
    count_parameters(encoder)

    arguments = dict()
    arguments["iteration"] = 0

    model_dict = {
        'discriminator_s': encoder,
        'generator_s': decoder,
        'mapping_tl_s': mapping_tl,
        'mapping_fl_s': mapping_fl,
        'dlatent_avg': dlatent_avg
    }

    checkpointer = Checkpointer(cfg, model_dict, {}, logger=logger, save=False)

    extra_checkpoint_data = checkpointer.load()

    model.eval()

    layer_count = cfg.MODEL.LAYER_COUNT

    def encode(x):
        Z, _ = model.encode(x, layer_count - 1, 1)
        Z = Z.repeat(1, model.mapping_fl.num_layers, 1)
        return Z

    def decode(x):
        layer_idx = torch.arange(2 * cfg.MODEL.LAYER_COUNT)[np.newaxis, :,
                                                            np.newaxis]
        ones = torch.ones(layer_idx.shape, dtype=torch.float32)
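        # both torch.where branches are ones, so coefs is all ones and
        # truncation is effectively a no-op (the lerp below stays commented out)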
        coefs = torch.where(layer_idx < model.truncation_cutoff, ones, ones)
        # x = torch.lerp(model.dlatent_avg.buff.data, x, coefs)
        return model.decoder(x, layer_count - 1, 1, noise=True)

    rnd = np.random.RandomState(5)
    latents = rnd.randn(1, cfg.MODEL.LATENT_SPACE_SIZE)

    dataset = TFRecordsDataset(cfg,
                               logger,
                               rank=0,
                               world_size=1,
                               buffer_size_mb=10,
                               channels=cfg.MODEL.CHANNELS,
                               train=False)

    dataset.reset(cfg.DATASET.MAX_RESOLUTION_LEVEL, 10)
    b = iter(make_dataloader(cfg, logger, dataset, 10, 0, numpy=True))

    def make(sample):
        canvas = []
        with torch.no_grad():
            for img in sample:
                x = torch.tensor(np.asarray(img, dtype=np.float32),
                                 device='cpu',
                                 requires_grad=True).cuda() / 127.5 - 1.
                if x.shape[0] == 4:
                    x = x[:3]
                latents = encode(x[None, ...].cuda())
                f = decode(latents)
                r = torch.cat([x[None, ...].detach().cpu(),
                               f.detach().cpu()],
                              dim=3)
                canvas.append(r)
        return canvas

    sample = next(b)
    canvas = make(sample)
    canvas = torch.cat(canvas, dim=0)

    save_image(canvas * 0.5 + 0.5,
               'make_figures/reconstructions_ffhq_real_1.png',
               nrow=2,
               pad_value=1.0)

    sample = next(b)
    canvas = make(sample)
    canvas = torch.cat(canvas, dim=0)

    save_image(canvas * 0.5 + 0.5,
               'make_figures/reconstructions_ffhq_real_2.png',
               nrow=2,
               pad_value=1.0)
Example #18
def sample(cfg, logger):
    torch.cuda.set_device(0)
    model = Model(startf=cfg.MODEL.START_CHANNEL_COUNT,
                  layer_count=cfg.MODEL.LAYER_COUNT,
                  maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
                  latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
                  truncation_psi=cfg.MODEL.TRUNCATIOM_PSI,
                  truncation_cutoff=cfg.MODEL.TRUNCATIOM_CUTOFF,
                  mapping_layers=cfg.MODEL.MAPPING_LAYERS,
                  channels=cfg.MODEL.CHANNELS,
                  generator=cfg.MODEL.GENERATOR,
                  encoder=cfg.MODEL.ENCODER)
    model.cuda(0)
    model.eval()
    model.requires_grad_(False)

    decoder = model.decoder
    encoder = model.encoder
    mapping_tl = model.mapping_d
    mapping_fl = model.mapping_f
    dlatent_avg = model.dlatent_avg

    logger.info("Trainable parameters generator:")
    count_parameters(decoder)

    logger.info("Trainable parameters discriminator:")
    count_parameters(encoder)

    arguments = dict()
    arguments["iteration"] = 0

    model_dict = {
        'discriminator_s': encoder,
        'generator_s': decoder,
        'mapping_tl_s': mapping_tl,
        'mapping_fl_s': mapping_fl,
        'dlatent_avg': dlatent_avg
    }

    checkpointer = Checkpointer(cfg, model_dict, {}, logger=logger, save=False)

    extra_checkpoint_data = checkpointer.load()

    model.eval()

    layer_count = cfg.MODEL.LAYER_COUNT

    def encode(x):
        Z, _ = model.encode(x, layer_count - 1, 1)
        Z = Z.repeat(1, model.mapping_f.num_layers, 1)
        return Z

    def decode(x):
        layer_idx = torch.arange(2 * cfg.MODEL.LAYER_COUNT)[np.newaxis, :,
                                                            np.newaxis]
        ones = torch.ones(layer_idx.shape, dtype=torch.float32)
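        # both torch.where branches are ones, so coefs is all ones and
        # truncation is effectively a no-op (the lerp below stays commented out)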
        coefs = torch.where(layer_idx < model.truncation_cutoff, ones, ones)
        # x = torch.lerp(model.dlatent_avg.buff.data, x, coefs)
        return model.decoder(x, layer_count - 1, 1, noise=True)

    path = 'dataset_samples/faces/pioneer256x256'

    paths = list(os.listdir(path))

    def make(paths):
        with torch.no_grad():
            for filename in paths:
                img = np.asarray(Image.open(path + '/' + filename))
                if img.shape[2] == 4:
                    img = img[:, :, :3]
                im = img.transpose((2, 0, 1))

                x = torch.tensor(np.asarray(im, dtype=np.float32),
                                 device='cpu',
                                 requires_grad=True).cuda() / 127.5 - 1.
                if x.shape[0] == 4:
                    x = x[:3]

                # halve the resolution by average pooling until the image
                # matches the decoder resolution at LOD 6 (256x256)
                while x.shape[2] != model.decoder.layer_to_resolution[6]:
                    x = F.avg_pool2d(x, 2, 2)

                latents = encode(x[None, ...].cuda())
                f = decode(latents)
                r = torch.cat([x[None, ...].detach().cpu(),
                               f.detach().cpu()],
                              dim=3)
                os.makedirs('make_figures/output/pioneer/', exist_ok=True)
                save_image(f.detach().cpu() * 0.5 + 0.5,
                           'make_figures/output/pioneer/%s_alae.png' %
                           filename[:-9],
                           nrow=1,
                           pad_value=1.0)

    make(paths)
Example #19
def _main(cfg, logger):
    torch.cuda.set_device(0)
    model = Model(
        startf=cfg.MODEL.START_CHANNEL_COUNT,
        layer_count=cfg.MODEL.LAYER_COUNT,
        maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
        latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
        truncation_psi=cfg.MODEL.TRUNCATIOM_PSI,
        truncation_cutoff=cfg.MODEL.TRUNCATIOM_CUTOFF,
        mapping_layers=cfg.MODEL.MAPPING_LAYERS,
        channels=cfg.MODEL.CHANNELS,
        generator=cfg.MODEL.GENERATOR,
        encoder=cfg.MODEL.ENCODER)
    model.cuda(0)
    model.eval()
    model.requires_grad_(False)

    decoder = model.decoder
    encoder = model.encoder
    mapping_tl = model.mapping_tl
    mapping_fl = model.mapping_fl
    dlatent_avg = model.dlatent_avg

    logger.info("Trainable parameters generator:")
    count_parameters(decoder)

    logger.info("Trainable parameters discriminator:")
    count_parameters(encoder)

    arguments = dict()
    arguments["iteration"] = 0

    model_dict = {
        'discriminator_s': encoder,
        'generator_s': decoder,
        'mapping_tl_s': mapping_tl,
        'mapping_fl_s': mapping_fl,
        'dlatent_avg': dlatent_avg
    }

    checkpointer = Checkpointer(cfg,
                                model_dict,
                                {},
                                logger=logger,
                                save=False)

    extra_checkpoint_data = checkpointer.load()
    last_epoch = list(extra_checkpoint_data['auxiliary']['scheduler'].values())[0]['last_epoch']
    logger.info("Model trained for %d epochs" % last_epoch)

    model.eval()

    layer_count = cfg.MODEL.LAYER_COUNT

    def encode(x):
        layer_count = cfg.MODEL.LAYER_COUNT

        zlist = []
        for i in range(x.shape[0]):
            Z, _ = model.encode(x[i][None, ...], layer_count - 1, 1)
            zlist.append(Z)
        Z = torch.cat(zlist)
        Z = Z.repeat(1, model.mapping_fl.num_layers, 1)
        return Z

    def decode(x):
        decoded = []
        for i in range(x.shape[0]):
            r = model.decoder(x[i][None, ...], layer_count - 1, 1, noise=True)
            decoded.append(r)
        return torch.cat(decoded)

    path = cfg.DATASET.STYLE_MIX_PATH
    im_size = 2 ** (cfg.MODEL.LAYER_COUNT + 1)

    # src_len and dst_len are not defined anywhere in this snippet; the six
    # entries in style_ranges below imply six destination rows, so the values
    # here are assumptions
    src_len = 5
    dst_len = 6

    src_originals = []
    for i in range(src_len):
        try:
            im = np.asarray(Image.open(os.path.join(path, 'src/%d.png' % i)).resize((1024,1024)))
        except FileNotFoundError:
            im = np.asarray(Image.open(os.path.join(path, 'src/%d.jpg' % i)).resize((1024,1024)))
        im = im.transpose((2, 0, 1))
        x = torch.tensor(np.asarray(im, dtype=np.float32), requires_grad=True).cuda() / 127.5 - 1.
        if x.shape[0] == 4:
            x = x[:3]
        factor = x.shape[2] // im_size
        if factor != 1:
            x = torch.nn.functional.avg_pool2d(x[None, ...], factor, factor)[0]
        assert x.shape[2] == im_size
        src_originals.append(x)
    src_originals = torch.stack([x for x in src_originals])
    dst_originals = []
    for i in range(dst_len):
        try:
            im = np.asarray(Image.open(os.path.join(path, 'dst/%d.png' % i)).resize((1024,1024)))
        except FileNotFoundError:
            im = np.asarray(Image.open(os.path.join(path, 'dst/%d.jpg' % i)).resize((1024,1024)))
        im = im.transpose((2, 0, 1))
        x = torch.tensor(np.asarray(im, dtype=np.float32), requires_grad=True).cuda() / 127.5 - 1.
        if x.shape[0] == 4:
            x = x[:3]
        factor = x.shape[2] // im_size
        if factor != 1:
            x = torch.nn.functional.avg_pool2d(x[None, ...], factor, factor)[0]
        assert x.shape[2] == im_size
        dst_originals.append(x)
    dst_originals = torch.stack([x for x in dst_originals])

    src_latents = encode(src_originals)
    src_images = decode(src_latents)

    dst_latents = encode(dst_originals)
    dst_images = decode(dst_latents)

    canvas = np.zeros([3, im_size * (dst_len + 1), im_size * (src_len + 1)])

    os.makedirs('style_mixing/output/%s/' % cfg.NAME, exist_ok=True)

    for i in range(src_len):
        save_image(src_originals[i] * 0.5 + 0.5, 'style_mixing/output/%s/source_%d.png' % (cfg.NAME, i))
        place(canvas, src_originals[i], 1 + i, 0)

    for i in range(dst_len):
        save_image(dst_originals[i] * 0.5 + 0.5, 'style_mixing/output/%s/dst_coarse_%d.png' % (cfg.NAME, i))
        place(canvas, dst_originals[i], 0, 1 + i)

    style_ranges = [range(0, 4)] * 3 + [range(4, 8)] * 2 + [range(8, layer_count * 2)]
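    # rows 0-2 take coarse styles (layers 0-3) from the sources, rows 3-4 take
    # middle styles (4-7), and the last row takes fine styles (8 and up)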

    def mix_styles(style_src, style_dst, r):
        style = style_dst.clone()
        style[:, r] = style_src[:, r]
        return style

    for row in range(dst_len):
        row_latents = torch.stack([dst_latents[row]] * src_len)
        style = mix_styles(src_latents, row_latents, style_ranges[row])
        rec = model.decoder(style, layer_count - 1, 1, noise=True)
        for j in range(rec.shape[0]):
            save_image(rec[j] * 0.5 + 0.5, 'style_mixing/output/%s/rec_coarse_%d_%d.png' % (cfg.NAME, row, j))
            place(canvas, rec[j], 1 + j, 1 + row)

    save_image(torch.Tensor(canvas), 'style_mixing/output/%s/stylemix.png' % cfg.NAME)
Example #20
def sample(cfg, logger):
    torch.cuda.set_device(0)
    model = Model(startf=cfg.MODEL.START_CHANNEL_COUNT,
                  layer_count=cfg.MODEL.LAYER_COUNT,
                  maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
                  latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
                  truncation_psi=cfg.MODEL.TRUNCATIOM_PSI,
                  truncation_cutoff=cfg.MODEL.TRUNCATIOM_CUTOFF,
                  mapping_layers=cfg.MODEL.MAPPING_LAYERS,
                  channels=cfg.MODEL.CHANNELS,
                  generator=cfg.MODEL.GENERATOR,
                  encoder=cfg.MODEL.ENCODER)
    model.cuda(0)
    model.eval()
    model.requires_grad_(False)

    decoder = model.decoder
    encoder = model.encoder
    mapping_tl = model.mapping_tl
    mapping_fl = model.mapping_fl
    dlatent_avg = model.dlatent_avg

    logger.info("Trainable parameters generator:")
    count_parameters(decoder)

    logger.info("Trainable parameters discriminator:")
    count_parameters(encoder)

    arguments = dict()
    arguments["iteration"] = 0

    model_dict = {
        'discriminator_s': encoder,
        'generator_s': decoder,
        'mapping_tl_s': mapping_tl,
        'mapping_fl_s': mapping_fl,
        'dlatent_avg': dlatent_avg
    }

    checkpointer = Checkpointer(cfg, model_dict, {}, logger=logger, save=False)

    extra_checkpoint_data = checkpointer.load()

    model.eval()

    layer_count = cfg.MODEL.LAYER_COUNT

    def encode(x):
        Z, _ = model.encode(x, layer_count - 1, 1)
        Z = Z.repeat(1, model.mapping_fl.num_layers, 1)
        return Z

    def decode(x):
        layer_idx = torch.arange(2 * cfg.MODEL.LAYER_COUNT)[np.newaxis, :,
                                                            np.newaxis]
        ones = torch.ones(layer_idx.shape, dtype=torch.float32)
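        # both torch.where branches are ones, so coefs is all ones and
        # truncation is effectively a no-op (the lerp below stays commented out)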
        coefs = torch.where(layer_idx < model.truncation_cutoff, ones, ones)
        # x = torch.lerp(model.dlatent_avg.buff.data, x, coefs)
        return model.decoder(x, layer_count - 1, 1, noise=True)

    rnd = np.random.RandomState(5)
    latents = rnd.randn(1, cfg.MODEL.LATENT_SPACE_SIZE)

    path = cfg.DATASET.SAMPLES_PATH

    paths = list(os.listdir(path))

    paths = sorted(paths)
    random.seed(3456)
    random.shuffle(paths)

    def make(paths):
        canvas = []
        with torch.no_grad():
            for filename in paths:
                img = np.asarray(Image.open(path + '/' + filename))
                if img.shape[2] == 4:
                    img = img[:, :, :3]
                im = img.transpose((2, 0, 1))
                x = torch.tensor(np.asarray(im, dtype=np.float32),
                                 device='cpu',
                                 requires_grad=True).cuda() / 127.5 - 1.
                if x.shape[0] == 4:
                    x = x[:3]
                latents = encode(x[None, ...].cuda())
                f = decode(latents)
                r = torch.cat([x[None, ...].detach().cpu(),
                               f.detach().cpu()],
                              dim=3)
                canvas.append(r)
        return canvas

    canvas = make(paths[:40])
    canvas = torch.cat(canvas, dim=0)

    save_image(canvas * 0.5 + 0.5,
               'make_figures/output/reconstructions_bed_1.png',
               nrow=4,
               pad_value=1.0)

    canvas = make(paths[40:80])
    canvas = torch.cat(canvas, dim=0)

    save_image(canvas * 0.5 + 0.5,
               'make_figures/output/reconstructions_bed_2.png',
               nrow=4,
               pad_value=1.0)
Example #21
def sample(cfg, logger):
    torch.cuda.set_device(0)
    model = Model(startf=cfg.MODEL.START_CHANNEL_COUNT,
                  layer_count=cfg.MODEL.LAYER_COUNT,
                  maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
                  latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
                  truncation_psi=None,
                  truncation_cutoff=None,
                  mapping_layers=cfg.MODEL.MAPPING_LAYERS,
                  channels=cfg.MODEL.CHANNELS,
                  generator=cfg.MODEL.GENERATOR,
                  encoder=cfg.MODEL.ENCODER)

    model.cuda(0)
    model.eval()
    model.requires_grad_(False)

    decoder = model.decoder
    encoder = model.encoder

    mapping_fl = model.mapping_fl
    dlatent_avg = model.dlatent_avg

    logger.info("Trainable parameters generator:")
    count_parameters(decoder)

    logger.info("Trainable parameters discriminator:")
    count_parameters(encoder)

    arguments = dict()
    arguments["iteration"] = 0

    model_dict = {
        'discriminator_s': encoder,
        # 'encoder_s': encoder,
        'generator_s': decoder,
        'mapping_fl_s': mapping_fl,
        'dlatent_avg_s': dlatent_avg
    }

    checkpointer = Checkpointer(cfg, model_dict, {}, logger=logger, save=False)

    extra_checkpoint_data = checkpointer.load()
    last_epoch = list(extra_checkpoint_data['auxiliary']
                      ['scheduler'].values())[0]['last_epoch']
    logger.info("Model trained for %d epochs" % last_epoch)

    model.eval()

    layer_count = cfg.MODEL.LAYER_COUNT

    logger.info("Evaluating LPIPS metric")

    decoder = nn.DataParallel(decoder)
    encoder = nn.DataParallel(encoder)

    with torch.no_grad():
        ppl = LPIPS(cfg,
                    num_images=10000,
                    minibatch_size=16 * torch.cuda.device_count())
        ppl.evaluate(logger, mapping_fl, decoder, encoder,
                     cfg.DATASET.MAX_RESOLUTION_LEVEL - 2)
Example #22
def sample(cfg, logger):
    model = Model(startf=cfg.MODEL.START_CHANNEL_COUNT,
                  layer_count=cfg.MODEL.LAYER_COUNT,
                  maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
                  latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
                  truncation_psi=cfg.MODEL.TRUNCATIOM_PSI,
                  truncation_cutoff=cfg.MODEL.TRUNCATIOM_CUTOFF,
                  mapping_layers=cfg.MODEL.MAPPING_LAYERS,
                  channels=3)
    del model.discriminator
    model.eval()

    #torch.cuda.manual_seed_all(110)

    logger.info("Trainable parameters generator:")
    count_parameters(model.generator)

    # flip to True to load the raw weights instead of the smoothed (EMA) copies
    if False:
        model_dict = {
            'generator': model.generator,
            'mapping': model.mapping,
            'dlatent_avg': model.dlatent_avg,
        }
    else:
        model_dict = {
            'generator_s': model.generator,
            'mapping_s': model.mapping,
            'dlatent_avg': model.dlatent_avg,
        }

    checkpointer = Checkpointer(cfg, model_dict, logger=logger, save=True)

    file_name = 'karras2019stylegan-ffhq'
    # file_name = 'results/model_final'

    checkpointer.load(file_name=file_name + '.pth')
    # checkpointer.save('final_stripped')

    #sample_b = torch.randn(1, cfg.MODEL.LATENT_SPACE_SIZE).view(-1, cfg.MODEL.LATENT_SPACE_SIZE)

    # for i in range(100):
    #     if i % 20 == 0:
    #         sample_a = sample_b
    #         sample_b = torch.randn(1, cfg.MODEL.LATENT_SPACE_SIZE).view(-1, cfg.MODEL.LATENT_SPACE_SIZE)
    #     x = (i % 20) / 20.0
    #     sample = sample_a * (1.0 - x) + sample_b * x
    #     save_sample(model, sample, i)

    print(model.generator.get_statistics(8))
    # print(model.discriminator.get_statistics(8))

    ctx = bimpy.Context()

    ctx.init(1800, 1600, "Styles")

    rnd = np.random.RandomState(5)
    latents = rnd.randn(1, cfg.MODEL.LATENT_SPACE_SIZE)
    sample = torch.tensor(latents).float().cuda()

    def update_image(sample):
        with torch.no_grad():
            model.eval()
            x_rec = model.generate(model.generator.layer_count - 1,
                                   1,
                                   z=sample)
            resultsample = ((x_rec * 0.5 + 0.5) * 255).type(torch.long).clamp(
                0, 255)
            resultsample = resultsample.cpu()[0, :, :, :]

            return resultsample.type(torch.uint8).transpose(0,
                                                            2).transpose(0, 1)

    im = update_image(sample)
    print(im.shape)
    im = bimpy.Image(im)

    while not ctx.should_close():
        with ctx:
            im = bimpy.Image(update_image(sample))
            bimpy.image(im)
            # if bimpy.button('Ok'):
            if bimpy.button('NEXT'):
                latents = rnd.randn(1, cfg.MODEL.LATENT_SPACE_SIZE)
                sample = torch.tensor(latents).float().cuda()
                # im = bimpy.Image(update_image(sample))
            #bimpy.set_window_font_scale(2.0)

    exit()
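    # NOTE: the code below is unreachable due to the exit() above, and im_size
    # is never defined in this function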

    rnd = np.random.RandomState(111011)
    latents = rnd.randn(1, cfg.MODEL.LATENT_SPACE_SIZE)
    sample = torch.tensor(latents).float().cuda()
    # torch.randn(16, cfg.MODEL.LATENT_SPACE_SIZE).view(-1, cfg.MODEL.LATENT_SPACE_SIZE)
    save_sample(model, sample, 0)

    im_count = 16
    canvas = np.zeros([3, im_size * (im_count + 2), im_size * (im_count + 2)])
    cut_layer_b = 0
    cut_layer_e = 2

    styles = model.mapping(sample)
    styles = list(styles.split(1, 1))

    for i in range(im_count):
        torch.cuda.manual_seed_all(110)
        style = [x[i] for x in styles]
        style = torch.cat(style, dim=0)[None, ...]
        rec = model.generator.decode(style, cfg.MODEL.LAYER_COUNT - 1, 0.7)
        place(canvas, rec[0], 1, 2 + i)

        place(canvas, rec[0], 2 + i, 1)

    for i in range(im_count):
        for j in range(im_count):
            style_a = [x[i] for x in styles[:cut_layer_b]]
            style_b = [x[j] for x in styles[cut_layer_b:cut_layer_e]]
            style_c = [x[i] for x in styles[cut_layer_e:]]
            style = style_a + style_b + style_c
            torch.cuda.manual_seed_all(110)
            style = torch.cat(style, dim=0)[None, ...]
            rec = model.generator.decode(style, cfg.MODEL.LAYER_COUNT - 1, 0.7)
            place(canvas, rec[0], 2 + i, 2 + j)

    save_image(torch.Tensor(canvas), 'reconstruction.png')
Example #23
def train(cfg, logger, local_rank, world_size, distributed):
    torch.cuda.set_device(local_rank)
    model = Model(startf=cfg.MODEL.START_CHANNEL_COUNT,
                  layer_count=cfg.MODEL.LAYER_COUNT,
                  maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
                  latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
                  dlatent_avg_beta=cfg.MODEL.DLATENT_AVG_BETA,
                  style_mixing_prob=cfg.MODEL.STYLE_MIXING_PROB,
                  mapping_layers=cfg.MODEL.MAPPING_LAYERS,
                  channels=cfg.MODEL.CHANNELS,
                  generator=cfg.MODEL.GENERATOR,
                  encoder=cfg.MODEL.ENCODER)
    model.cuda(local_rank)
    model.train()

    if local_rank == 0:
        model_s = Model(startf=cfg.MODEL.START_CHANNEL_COUNT,
                        layer_count=cfg.MODEL.LAYER_COUNT,
                        maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
                        latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
                        truncation_psi=cfg.MODEL.TRUNCATIOM_PSI,
                        truncation_cutoff=cfg.MODEL.TRUNCATIOM_CUTOFF,
                        mapping_layers=cfg.MODEL.MAPPING_LAYERS,
                        channels=cfg.MODEL.CHANNELS,
                        generator=cfg.MODEL.GENERATOR,
                        encoder=cfg.MODEL.ENCODER)
        model_s.cuda(local_rank)
        model_s.eval()
        model_s.requires_grad_(False)

    if distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            broadcast_buffers=False,
            bucket_cap_mb=25,
            find_unused_parameters=True)
        model.device_ids = None

        decoder = model.module.decoder
        encoder = model.module.encoder
        discriminator = model.module.discriminator
        mapping_fl = model.module.mapping_fl
        dlatent_avg = model.module.dlatent_avg
    else:
        decoder = model.decoder
        encoder = model.encoder
        discriminator = model.discriminator
        mapping_fl = model.mapping_fl
        dlatent_avg = model.dlatent_avg

    count_param_override.print = lambda a: logger.info(a)

    logger.info("Trainable parameters generator:")
    count_parameters(decoder)

    logger.info("Trainable parameters discriminator:")
    count_parameters(encoder)

    arguments = dict()
    arguments["iteration"] = 0

    decoder_optimizer = LREQAdam([{
        'params': decoder.parameters()
    }, {
        'params': mapping_fl.parameters()
    }],
                                 lr=cfg.TRAIN.BASE_LEARNING_RATE,
                                 betas=(cfg.TRAIN.ADAM_BETA_0,
                                        cfg.TRAIN.ADAM_BETA_1),
                                 weight_decay=0)

    encoder_optimizer = LREQAdam([
        {
            'params': encoder.parameters()
        },
    ],
                                 lr=cfg.TRAIN.BASE_LEARNING_RATE,
                                 betas=(cfg.TRAIN.ADAM_BETA_0,
                                        cfg.TRAIN.ADAM_BETA_1),
                                 weight_decay=0)

    discriminator_optimizer = LREQAdam([
        {
            'params': discriminator.parameters()
        },
    ],
                                       lr=cfg.TRAIN.BASE_LEARNING_RATE,
                                       betas=(cfg.TRAIN.ADAM_BETA_0,
                                              cfg.TRAIN.ADAM_BETA_1),
                                       weight_decay=0)

    scheduler = ComboMultiStepLR(optimizers={
        'encoder_optimizer': encoder_optimizer,
        'discriminator_optimizer': discriminator_optimizer,
        'decoder_optimizer': decoder_optimizer
    },
                                 milestones=cfg.TRAIN.LEARNING_DECAY_STEPS,
                                 gamma=cfg.TRAIN.LEARNING_DECAY_RATE,
                                 reference_batch_size=32,
                                 base_lr=cfg.TRAIN.LEARNING_RATES)

    model_dict = {
        'discriminator': discriminator,
        'encoder': encoder,
        'generator': decoder,
        'mapping_fl': mapping_fl,
        'dlatent_avg': dlatent_avg
    }
    if local_rank == 0:
        model_dict['discriminator_s'] = model_s.discriminator
        model_dict['encoder_s'] = model_s.encoder
        model_dict['generator_s'] = model_s.decoder
        model_dict['mapping_fl_s'] = model_s.mapping_fl

    tracker = LossTracker(cfg.OUTPUT_DIR)

    checkpointer = Checkpointer(
        cfg,
        model_dict, {
            'encoder_optimizer': encoder_optimizer,
            'discriminator_optimizer': discriminator_optimizer,
            'decoder_optimizer': decoder_optimizer,
            'scheduler': scheduler,
            'tracker': tracker
        },
        logger=logger,
        save=local_rank == 0)

    extra_checkpoint_data = checkpointer.load()
    logger.info("Starting from epoch: %d" % (scheduler.start_epoch()))

    arguments.update(extra_checkpoint_data)

    layer_to_resolution = decoder.layer_to_resolution

    dataset = TFRecordsDataset(cfg,
                               logger,
                               rank=local_rank,
                               world_size=world_size,
                               buffer_size_mb=1024,
                               channels=cfg.MODEL.CHANNELS)

    rnd = np.random.RandomState(3456)
    latents = rnd.randn(32, cfg.MODEL.LATENT_SPACE_SIZE)
    samplez = torch.tensor(latents).float().cuda()

    lod2batch = lod_driver.LODDriver(cfg,
                                     logger,
                                     world_size,
                                     dataset_size=len(dataset) * world_size)

    if cfg.DATASET.SAMPLES_PATH:
        path = cfg.DATASET.SAMPLES_PATH
        src = []
        with torch.no_grad():
            for filename in list(os.listdir(path))[:32]:
                img = np.asarray(Image.open(os.path.join(path, filename)))
                if img.shape[2] == 4:
                    img = img[:, :, :3]
                im = img.transpose((2, 0, 1))
                x = torch.tensor(np.asarray(im, dtype=np.float32),
                                 requires_grad=True).cuda() / 127.5 - 1.
                if x.shape[0] == 4:
                    x = x[:3]
                src.append(x)
            sample = torch.stack(src)
    else:
        dataset.reset(cfg.DATASET.MAX_RESOLUTION_LEVEL, 32)
        sample = next(make_dataloader(cfg, logger, dataset, 32, local_rank))
        sample = (sample / 127.5 - 1.)

    lod2batch.set_epoch(scheduler.start_epoch(),
                        [encoder_optimizer, decoder_optimizer])

    for epoch in range(scheduler.start_epoch(), cfg.TRAIN.TRAIN_EPOCHS):
        model.train()
        lod2batch.set_epoch(epoch, [encoder_optimizer, decoder_optimizer])

        logger.info(
            "Batch size: %d, Batch size per GPU: %d, LOD: %d - %dx%d, blend: %.3f, dataset size: %d"
            % (lod2batch.get_batch_size(), lod2batch.get_per_GPU_batch_size(),
               lod2batch.lod, 2**lod2batch.get_lod_power2(), 2**
               lod2batch.get_lod_power2(), lod2batch.get_blend_factor(),
               len(dataset) * world_size))

        dataset.reset(lod2batch.get_lod_power2(),
                      lod2batch.get_per_GPU_batch_size())
        batches = make_dataloader(cfg, logger, dataset,
                                  lod2batch.get_per_GPU_batch_size(),
                                  local_rank)

        scheduler.set_batch_size(lod2batch.get_batch_size(), lod2batch.lod)

        model.train()

        need_permute = False
        epoch_start_time = time.time()

        i = 0
        with torch.autograd.profiler.profile(use_cuda=True,
                                             enabled=False) as prof:
            for x_orig in tqdm(batches):
                i += 1
                with torch.no_grad():
                    if x_orig.shape[0] != lod2batch.get_per_GPU_batch_size():
                        continue
                    if need_permute:
                        x_orig = x_orig.permute(0, 3, 1, 2)
                    x_orig = (x_orig / 127.5 - 1.)

                    blend_factor = lod2batch.get_blend_factor()

                    needed_resolution = layer_to_resolution[lod2batch.lod]
                    x = x_orig

                    # during a LOD transition, blend the batch with an
                    # upsampled half-resolution copy of itself
                    if lod2batch.in_transition:
                        needed_resolution_prev = layer_to_resolution[
                            lod2batch.lod - 1]
                        x_prev = F.avg_pool2d(x_orig, 2, 2)
                        x_prev_2x = F.interpolate(x_prev, needed_resolution)
                        x = x * blend_factor + x_prev_2x * (1.0 - blend_factor)

                x.requires_grad = True

                # phase 1: discriminator update (encoder + mapping act as D)
                loss_d = model(x,
                               lod2batch.lod,
                               blend_factor,
                               d_train=True,
                               ae=False)
                tracker.update(dict(loss_d=loss_d))
                loss_d.backward()
                discriminator_optimizer.step()
                decoder_optimizer.zero_grad()
                discriminator_optimizer.zero_grad()

                # phase 2: generator update (mapping + decoder act as G)
                loss_g = model(x,
                               lod2batch.lod,
                               blend_factor,
                               d_train=False,
                               ae=False)
                tracker.update(dict(loss_g=loss_g))
                loss_g.backward()
                decoder_optimizer.step()
                decoder_optimizer.zero_grad()
                discriminator_optimizer.zero_grad()

                # phase 3: latent-autoencoder reconstruction pass, updating
                # both encoder and decoder
                lae = model(x,
                            lod2batch.lod,
                            blend_factor,
                            d_train=True,
                            ae=True)
                tracker.update(dict(lae=lae))
                (lae).backward()
                encoder_optimizer.step()
                decoder_optimizer.step()
                encoder_optimizer.zero_grad()
                decoder_optimizer.zero_grad()

                if local_rank == 0:
                    # maintain an exponential moving average of the weights in model_s
                    betta = 0.5**(lod2batch.get_batch_size() / (10 * 1000.0))
                    model_s.lerp(model, betta)

                epoch_end_time = time.time()
                per_epoch_ptime = epoch_end_time - epoch_start_time

                lod_for_saving_model = lod2batch.lod
                lod2batch.step()
                if local_rank == 0:
                    if lod2batch.is_time_to_save():
                        checkpointer.save("model_tmp_intermediate_lod%d" %
                                          lod_for_saving_model)
                    if lod2batch.is_time_to_report():
                        save_sample(lod2batch, tracker, sample, samplez, x,
                                    logger, model_s, cfg, encoder_optimizer,
                                    decoder_optimizer)

        scheduler.step()

        if local_rank == 0:
            checkpointer.save("model_tmp_lod%d" % lod_for_saving_model)
            save_sample(lod2batch, tracker, sample, samplez, x, logger,
                        model_s, cfg, encoder_optimizer, decoder_optimizer)

    logger.info("Training finish!... save training results")
    if local_rank == 0:
        checkpointer.save("model_final").wait()
Example #24
def train(cfg, logger, gpu_id=0):
    torch.cuda.set_device(gpu_id)
    model = Model(startf=cfg.MODEL.START_CHANNEL_COUNT,
                  layer_count=cfg.MODEL.LAYER_COUNT,
                  maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
                  latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
                  dlatent_avg_beta=cfg.MODEL.DLATENT_AVG_BETA,
                  style_mixing_prob=cfg.MODEL.STYLE_MIXING_PROB,
                  mapping_layers=cfg.MODEL.MAPPING_LAYERS,
                  channels=3,
                  truncation_psi=cfg.MODEL.TRUNCATIOM_PSI,
                  truncation_cutoff=cfg.MODEL.TRUNCATIOM_CUTOFF)
    model.cuda(gpu_id)
    model.train()

    generator = model.generator
    discriminator = model.discriminator
    mapping = model.mapping
    dlatent_avg = model.dlatent_avg

    count_parameters.print = lambda a: logger.info(
        a)  # route count_parameters' print output through logger.info

    logger.info("Trainable parameters generator:")
    count_parameters(generator)

    logger.info("Trainable parameters discriminator:")
    count_parameters(discriminator)

    generator_optimizer = LREQAdam([{
        'params': generator.parameters()
    }, {
        'params': mapping.parameters()
    }],
                                   lr=cfg.TRAIN.BASE_LEARNING_RATE,
                                   betas=(cfg.TRAIN.ADAM_BETA_0,
                                          cfg.TRAIN.ADAM_BETA_1),
                                   weight_decay=0)

    discriminator_optimizer = LREQAdam([{
        'params': discriminator.parameters()
    }],
                                       lr=cfg.TRAIN.BASE_LEARNING_RATE,
                                       betas=(cfg.TRAIN.ADAM_BETA_0,
                                              cfg.TRAIN.ADAM_BETA_1),
                                       weight_decay=0)

    scheduler = ComboMultiStepLR(
        optimizers={
            'generator': generator_optimizer,
            'discriminator': discriminator_optimizer
        },
        milestones=cfg.TRAIN.LEARNING_DECAY_STEPS,  # []
        gamma=cfg.TRAIN.LEARNING_DECAY_RATE,  # 0.1
        reference_batch_size=32,
        base_lr=cfg.TRAIN.LEARNING_RATES)  # 0.002

    model_dict = {
        'discriminator': discriminator,
        'generator': generator,
        'mapping': mapping,
        'dlatent_avg': dlatent_avg
    }

    tracker = LossTracker(cfg.OUTPUT_DIR)

    checkpointer = Checkpointer(
        cfg,
        model_dict, {
            'generator_optimizer': generator_optimizer,
            'discriminator_optimizer': discriminator_optimizer,
            'scheduler': scheduler,
            'tracker': tracker
        },
        logger=logger,
        save=gpu_id == 0)

    checkpointer.load()
    logger.info("Starting from epoch: %d" % (scheduler.start_epoch()))

    layer_to_resolution = generator.layer_to_resolution  #[4, 8, 16, 32, 64, 128]

    dataset = TFRecordsDataset(cfg, logger, buffer_size_mb=1024)

    rnd = np.random.RandomState(3456)
    latents = rnd.randn(32, cfg.MODEL.LATENT_SPACE_SIZE)
    sample = torch.tensor(latents).float().cuda()

    lod2batch = lod_driver.LODDriver(
        cfg, logger, dataset_size=len(dataset))  # helper that derives batch size and other parameters for each LOD

    for epoch in range(scheduler.start_epoch(), cfg.TRAIN.TRAIN_EPOCHS):
        model.train()
        lod2batch.set_epoch(epoch,
                            [generator_optimizer, discriminator_optimizer])

        logger.info(
            "Batch size: %d, LOD: %d - %dx%d, blend: %.3f, dataset size: %d" %
            (lod2batch.get_batch_size(), lod2batch.lod, 2**
             lod2batch.get_lod_power2(), 2**lod2batch.get_lod_power2(),
             lod2batch.get_blend_factor(), len(dataset)))

        dataset.reset(lod2batch.get_lod_power2(), lod2batch.get_batch_size())
        batches = make_dataloader(
            cfg, logger, dataset,
            lod2batch.get_batch_size())  # the dataset is split into batches of n images each
        scheduler.set_batch_size(lod2batch.get_batch_size(), lod2batch.lod)
        need_permute = False

        for x_orig in tqdm(batches):  # x_orig:[-1,c,w,h]
            with torch.no_grad():
                if x_orig.shape[0] != lod2batch.get_batch_size():
                    continue
                if need_permute:
                    x_orig = x_orig.permute(0, 3, 1, 2)
                x_orig = (x_orig / 127.5 - 1.)

                blend_factor = lod2batch.get_blend_factor()

                needed_resolution = layer_to_resolution[lod2batch.lod]
                x = x_orig  # image batch (resolution grows with the LOD)

                if lod2batch.in_transition:
                    needed_resolution_prev = layer_to_resolution[lod2batch.lod
                                                                 - 1]
                    x_prev = F.avg_pool2d(x_orig, 2, 2)
                    x_prev_2x = F.interpolate(x_prev, needed_resolution)
                    x = x * blend_factor + x_prev_2x * (1.0 - blend_factor)
            x.requires_grad = True

            discriminator_optimizer.zero_grad()
            loss_d = model(x, lod2batch.lod, blend_factor, d_train=True)
            tracker.update(dict(loss_d=loss_d))
            loss_d.backward()
            discriminator_optimizer.step()

            # EMA update; with no separate smoothed copy in this variant,
            # lerping the model with itself is a no-op
            betta = 0.5**(lod2batch.get_batch_size() / (10 * 1000.0))
            model.lerp(model, betta)

            generator_optimizer.zero_grad()
            loss_g = model(x, lod2batch.lod, blend_factor, d_train=False)
            tracker.update(dict(loss_g=loss_g))
            loss_g.backward()
            generator_optimizer.step()

            lod2batch.step()
            if lod2batch.is_time_to_save():
                checkpointer.save("model_tmp_intermediate")
            if lod2batch.is_time_to_report():
                save_sample(lod2batch, tracker, sample, x, logger, model, cfg,
                            discriminator_optimizer, generator_optimizer)
        scheduler.step()

        checkpointer.save("model_tmp")
        save_sample(lod2batch, tracker, sample, x, logger, model, cfg,
                    discriminator_optimizer, generator_optimizer)

    logger.info("Training finish!... save training results")
    checkpointer.save("model_final").wait()