def load_decoder(checkpoint_path):
    """Load the decoder portion from a model path"""

    checkpoint = torch.load(checkpoint_path)
    decoder = Generator(1024, 512, 8)
    decoder.load_state_dict(get_keys(checkpoint, "decoder"), strict=True)

    return decoder
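The load_decoder snippet above relies on a get_keys helper that is not defined in the excerpt. In pSp-style codebases it strips a sub-module prefix (here "decoder") from a combined checkpoint so the weights can be loaded with strict=True. A minimal sketch of such a helper, assuming that common checkpoint layout (not necessarily the exact implementation used by every example):

def get_keys(checkpoint, name):
    # Assumed helper: pull out the weights of one sub-module (e.g. "decoder")
    # from a combined checkpoint and drop the "<name>." prefix from the keys.
    if "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]
    prefix = name + "."
    return {k[len(prefix):]: v for k, v in checkpoint.items() if k.startswith(prefix)}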
Example #2
    def __init__(self, opts):
        super(pSp, self).__init__()
        self.opts = opts
        # Define architecture
        self.encoder = self.set_encoder()
        self.decoder = Generator(opts.stylegan_size, 512, 8, channel_multiplier=2)
        self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256))
        # Load weights if needed
        self.load_weights()
Example #3
    def __init__(self, opts):
        super(pSp, self).__init__()
        self.set_opts(opts)
        # Define architecture
        self.encoder = self.set_encoder()
        self.decoder = Generator(1024, 512, 8)
        self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256))
        # Load weights if needed
        self.load_weights()
Example #4
    def __init__(self, opts):
        super(StyleCLIPMapper, self).__init__()
        self.opts = opts
        # Define architecture
        self.mapper = self.set_mapper()
        self.decoder = Generator(self.opts.stylegan_size, 512, 8)
        self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256))
        # Load weights if needed
        self.load_weights()
Example #5
    def __init__(self, opts):
        super(pSp, self).__init__()
        self.set_opts(opts)
        self.n_styles = int(math.log(self.opts.output_size, 2)) * 2 - 2
        # Define architecture
        self.encoder = self.set_encoder()
        self.decoder = Generator(self.opts.output_size, 512, 8, channel_multiplier=2)
        self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256))
        # Load weights if needed
        self.load_weights()
Example #6
	def __init__(self, opts):
		super(pSp, self).__init__()
		self.set_opts(opts)
		# compute number of style inputs based on the output resolution
		self.opts.n_styles = int(math.log(self.opts.output_size, 2)) * 2 - 2
		# Define architecture
		self.encoder = self.set_encoder()
		self.decoder = Generator(self.opts.output_size, 512, 8)
		self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256))
		# Load weights if needed
		self.load_weights()
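The n_styles computation in Examples #5 and #6 follows the usual StyleGAN2 layer count: a generator with output resolution R has 2 * log2(R) - 2 per-layer style inputs, i.e. 18 for a 1024x1024 generator, which is why the optimization examples further down repeat the mean latent with .repeat(1, 18, 1). A quick sanity check:

import math

def n_styles(output_size):
    # Number of per-layer style vectors expected by a StyleGAN2 generator
    # of the given output resolution.
    return int(math.log(output_size, 2)) * 2 - 2

assert n_styles(1024) == 18  # matches mean_latent.repeat(1, 18, 1) below (1024-px generator)
assert n_styles(256) == 14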
Example #7
    def __init__(self, opts):
        super().__init__()
        self.set_opts(opts)
        # Define architecture
        self.encoder = self.set_encoder()
        self.decoder = Generator(512,
                                 512,
                                 8,
                                 channel_multiplier=2,
                                 c_dim=self.opts.c_dim)

        self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256))
        # Load weights if needed
        self.load_weights()
Example #8
    def __init__(self, opts, resize_factor=256):
        super(FingerGen, self).__init__()
        self.opts = None
        self.set_opts(opts)
        # Define architecture
        self.encoder = self.set_encoder()
        self.decoder = Generator(opts.generator_image_size,
                                 512,
                                 8,
                                 is_gray=opts.label_nc)
        self.image_pool = torch.nn.AdaptiveAvgPool2d(
            (resize_factor, resize_factor))
        # Load weights if needed
        self.load_weights()
Example #9
def main(args):
    ensure_checkpoint_exists(args.ckpt)
    text_inputs = torch.cat([clip.tokenize(args.description)]).cuda()
    os.makedirs(args.results_dir, exist_ok=True)

    g_ema = Generator(args.stylegan_size, 512, 8)
    g_ema.load_state_dict(torch.load(args.ckpt)["g_ema"], strict=False)
    g_ema.eval()
    g_ema = g_ema.cuda()
    mean_latent = g_ema.mean_latent(4096)

    if args.latent_path:
        latent_code_init = torch.load(args.latent_path).cuda()
    elif args.mode == "edit":
        latent_code_init_not_trunc = torch.randn(1, 512).cuda()
        with torch.no_grad():
            _, latent_code_init, _ = g_ema([latent_code_init_not_trunc],
                                           return_latents=True,
                                           truncation=args.truncation,
                                           truncation_latent=mean_latent)
    else:
        latent_code_init = mean_latent.detach().clone().repeat(1, 18, 1)

    with torch.no_grad():
        img_orig, _ = g_ema([latent_code_init],
                            input_is_latent=True,
                            randomize_noise=False)

    if args.work_in_stylespace:
        with torch.no_grad():
            _, _, latent_code_init = g_ema([latent_code_init],
                                           input_is_latent=True,
                                           return_latents=True)
        latent = [s.detach().clone() for s in latent_code_init]
        for c, s in enumerate(latent):
            if c in STYLESPACE_INDICES_WITHOUT_TORGB:
                s.requires_grad = True
    else:
        latent = latent_code_init.detach().clone()
        latent.requires_grad = True

    clip_loss = CLIPLoss(args)
    id_loss = IDLoss(args)

    if args.work_in_stylespace:
        optimizer = optim.Adam(latent, lr=args.lr)
    else:
        optimizer = optim.Adam([latent], lr=args.lr)

    pbar = tqdm(range(args.step))

    for i in pbar:
        t = i / args.step
        lr = get_lr(t, args.lr)
        optimizer.param_groups[0]["lr"] = lr

        img_gen, _ = g_ema([latent],
                           input_is_latent=True,
                           randomize_noise=False,
                           input_is_stylespace=args.work_in_stylespace)

        c_loss = clip_loss(img_gen, text_inputs)

        if args.id_lambda > 0:
            i_loss = id_loss(img_gen, img_orig)[0]
        else:
            i_loss = 0

        if args.mode == "edit":
            if args.work_in_stylespace:
                l2_loss = sum([((latent_code_init[c] - latent[c])**2).sum()
                               for c in range(len(latent_code_init))])
            else:
                l2_loss = ((latent_code_init - latent)**2).sum()
            loss = c_loss + args.l2_lambda * l2_loss + args.id_lambda * i_loss
        else:
            loss = c_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        pbar.set_description((f"loss: {loss.item():.4f};"))
        if args.save_intermediate_image_every > 0 and i % args.save_intermediate_image_every == 0:
            with torch.no_grad():
                img_gen, _ = g_ema([latent],
                                   input_is_latent=True,
                                   randomize_noise=False,
                                   input_is_stylespace=args.work_in_stylespace)

            torchvision.utils.save_image(img_gen,
                                         f"results/{str(i).zfill(5)}.jpg",
                                         normalize=True,
                                         range=(-1, 1))

    if args.mode == "edit":
        final_result = torch.cat([img_orig, img_gen])
    else:
        final_result = img_gen

    return final_result
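The main function in Example #9 reads all of its configuration from an argparse-style namespace; the fields it touches are visible in the code above (CLIPLoss and IDLoss may read additional fields). A hypothetical invocation, with illustrative values rather than the repository defaults:

from argparse import Namespace

args = Namespace(
    description="a person with purple hair",   # CLIP text prompt
    ckpt="stylegan2-ffhq-config-f.pt",          # checkpoint containing a "g_ema" state dict
    stylegan_size=1024,
    results_dir="results",
    latent_path=None,                           # optional path to a pre-computed latent code
    mode="edit",                                # "edit" keeps the result close to the source image
    truncation=0.7,
    work_in_stylespace=False,
    lr=0.1,
    step=300,
    l2_lambda=0.008,
    id_lambda=0.005,
    save_intermediate_image_every=20,
)
# final_result = main(args)  # requires a CUDA device plus the CLIP and ID-loss weights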
Example #10
def main(args):

    text_inputs = torch.cat([clip.tokenize(args.description)]).cuda()
    os.makedirs(args.results_dir, exist_ok=True)

    F = PerceptualModel(min_val=-1.0, max_val=1.0)

    g_ema = Generator(1024, 512, 8)
    g_ema.load_state_dict(torch.load(args.ckpt)["g_ema"], strict=False)
    g_ema.eval()
    g_ema = g_ema.cuda()
    z_mean = g_ema.mean_latent(4096)
    # z_load = np.load(args.latent_path)
    # z_init = torch.from_numpy(z_load).cuda()
    # print(np.shape(latent_load))
    F_OOM = args.f_oom

    if args.mode == "man":
        z_init = torch.load(args.latent_path).cuda()
    else:
        z_init_not_trunc = torch.randn(1, 512).cuda()
        with torch.no_grad():
            _, z_init = g_ema([z_init_not_trunc],
                              truncation_latent=z_mean,
                              return_latents=True,
                              truncation=0.7)

    x, _ = g_ema([z_init], input_is_latent=True, randomize_noise=False)

    # z = z_init.detach().clone()
    z = z_mean.detach().clone().repeat(1, 18, 1)

    z.requires_grad = True

    clip_loss = CLIPLoss()

    optimizer = optim.Adam([z], lr=args.lr)

    pbar = tqdm(range(args.step))

    for i in pbar:
        t = i / args.step
        lr = get_lr(t, args.lr)
        optimizer.param_groups[0]["lr"] = lr

        x_rec, _ = g_ema([z], input_is_latent=True, randomize_noise=False)
        if not F_OOM:
            loss = 0.0
            # Reconstruction loss.
            loss_pix = torch.mean((x - x_rec)**2)
            loss = loss + loss_pix * args.loss_pix_weight
            log_message = f'loss_pix: {_get_tensor_value(loss_pix):.3f}'

            # Perceptual loss.
            if args.loss_feat_weight:
                x_feat = F.net(x)
                x_rec_feat = F.net(x_rec)
                loss_feat = torch.mean((x_feat - x_rec_feat)**2)
                loss = loss + loss_feat * args.loss_feat_weight
                log_message += f', loss_feat: {_get_tensor_value(loss_feat):.3f}'

            # Regularization loss.
            if args.loss_reg_weight:
                loss_reg = torch.mean((z_init - z)**2)
                # loss_reg = ((z_init - z) ** 2).sum()
                loss = loss + loss_reg * args.loss_reg_weight
                log_message += f', loss_reg: {_get_tensor_value(loss_reg):.3f}'

            # CLIP loss.
            if args.loss_clip_weight:
                loss_clip = clip_loss(x_rec, text_inputs)
                loss = loss + loss_clip[0][0] * args.loss_clip_weight
                log_message += f', loss_clip: {_get_tensor_value(loss_clip[0][0]):.3f}'
        else:
            loss_reg = ((z_init - z)**2).sum()
            loss_clip = clip_loss(x_rec, text_inputs)
            loss = loss_reg + loss_clip[0][0] * args.loss_clip_weight  # loss_clip_weight was set to 200 in this case

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        pbar.set_description((f"loss: {loss.item():.4f};"))

    final_result = torch.cat([x, x_rec])
    return final_result
Example #11
def main(args):
    ensure_checkpoint_exists(args.ckpt)
    text_inputs = torch.cat([clip.tokenize(args.description)]).cuda()
    os.makedirs(args.results_dir, exist_ok=True)

    g_ema = Generator(args.stylegan_size, 512, 8)
    g_ema.load_state_dict(torch.load(args.ckpt)["g_ema"], strict=False)
    g_ema.eval()
    g_ema = g_ema.cuda()
    mean_latent = g_ema.mean_latent(4096)

    if args.latent_path:
        latent_code_init = torch.load(args.latent_path).cuda()
    elif args.mode == "edit":
        latent_code_init_not_trunc = torch.randn(1, 512).cuda()
        with torch.no_grad():
            _, latent_code_init = g_ema([latent_code_init_not_trunc], return_latents=True,
                                        truncation=args.truncation, truncation_latent=mean_latent)
    else:
        latent_code_init = mean_latent.detach().clone().repeat(1, 18, 1)

    # latent = latent_code_init.detach().clone()
    # latent.requires_grad = True

    # latent_inits = []
    # if args.mode == "edit":
    #     for i in range(200):
    #         latent_code_init_not_trunc = torch.randn(1, 512).cuda()
    #         with torch.no_grad():
    #             _, latent_code_init = g_ema([latent_code_init_not_trunc], return_latents=True,
    #                                         truncation=args.truncation, truncation_latent=mean_latent)
    #             latent_inits.append(latent_code_init)
    # else:
    #     raise NotImplementedError

    delta = torch.randn(latent_code_init.shape).cuda()
    delta.requires_grad = True


    clip_loss = CLIPLoss(args)

    optimizer = optim.Adam([delta], lr=args.lr)

    pbar = tqdm(range(args.step))

    for i in pbar:
        t = i / args.step
        lr = get_lr(t, args.lr)
        optimizer.param_groups[0]["lr"] = lr

        img_gen, _ = g_ema([delta + latent_code_init], input_is_latent=True, randomize_noise=False)

        c_loss = clip_loss(img_gen, text_inputs)

        if args.mode == "edit":
            l2_loss = ((delta) ** 2).sum()
            loss = c_loss + args.l2_lambda * l2_loss
        else:
            loss = c_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        pbar.set_description(f"loss: {loss.item():.4f};")
        if args.save_intermediate_image_every > 0 and i % args.save_intermediate_image_every == 0:
            with torch.no_grad():
                img_gen, _ = g_ema([delta + latent_code_init], input_is_latent=True, randomize_noise=False)

            torchvision.utils.save_image(img_gen, f"results/{str(i).zfill(5)}.png", normalize=True, range=(-1, 1))

    if args.mode == "edit":
        with torch.no_grad():
            img_orig, _ = g_ema([latent_code_init], input_is_latent=True, randomize_noise=False)

        final_result = torch.cat([img_orig, img_gen])
    else:
        final_result = img_gen

    return final_result
    n_gpu = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    args.distributed = n_gpu > 1

    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
        synchronize()

    args.latent = 512
    args.n_mlp = 8

    args.start_iter = 0

    generator = Generator(
        args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier, is_gray=args.convert_to_gray
    ).to(device)
    discriminator = Discriminator(
        args.size, channel_multiplier=args.channel_multiplier, is_gray=args.convert_to_gray
    ).to(device)
    g_ema = Generator(
        args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier, is_gray=args.convert_to_gray
    ).to(device)
    g_ema.eval()
    accumulate(g_ema, generator, 0)

    g_reg_ratio = args.g_reg_every / (args.g_reg_every + 1)
    d_reg_ratio = args.d_reg_every / (args.d_reg_every + 1)

    g_optim = optim.Adam(
        generator.parameters(),