Example #1
0
def generate_latents(n_latents,
                     ckpt,
                     G_res,
                     noconst=False,
                     latent_dim=512,
                     n_mlp=8,
                     channel_multiplier=2):
    """Sample random z vectors and map them through the generator's mapping net.

    Args:
        n_latents (int): How many mapped latents to produce.
        ckpt (str): Path of the generator checkpoint to load.
        G_res (int): Resolution the generator was trained at.
        noconst (bool, optional): True if the generator was trained without a
            constant input layer. Defaults to False.
        latent_dim (int, optional): Dimensionality of the latent vectors.
            Defaults to 512.
        n_mlp (int, optional): Number of layers in the mapping network.
            Defaults to 8.
        channel_multiplier (int, optional): Channel-width multiplier of the
            generator. Defaults to 2.

    Returns:
        th.tensor: Mapped latents, moved to the CPU.
    """
    # Build a throwaway generator purely to run its mapping network.
    mapper = Generator(
        G_res,
        latent_dim,
        n_mlp,
        channel_multiplier=channel_multiplier,
        constant_input=not noconst,
        checkpoint=ckpt,
    ).cuda()
    random_z = th.randn((n_latents, latent_dim), device="cuda")
    mapped = mapper(random_z, map_latents=True).cpu()
    # Free GPU memory right away -- the generator is no longer needed.
    del mapper, random_z
    gc.collect()
    th.cuda.empty_cache()
    return mapped
def load_generator(ckpt, is_stylegan1, G_res, out_size, noconst, latent_dim,
                   n_mlp, channel_multiplier, dataparallel):
    """Load a StyleGAN 1 or StyleGAN 2 generator onto the GPU.

    Args:
        ckpt (str): Checkpoint path to restore weights from.
        is_stylegan1 (bool): Choose the StyleGAN 1 architecture when True.
        G_res (int): Training resolution of the (StyleGAN 2) generator.
        out_size (int): Output resolution of the generator.
        noconst (bool): True if trained without a constant input layer.
        latent_dim (int): Dimensionality of the latent vectors.
        n_mlp (int): Number of layers in the mapping network.
        channel_multiplier (int): Channel-width multiplier.
        dataparallel (bool): Wrap the model in th.nn.DataParallel when True.

    Returns:
        The generator module (possibly wrapped in DataParallel).
    """
    if is_stylegan1:
        net = G_style(output_size=out_size, checkpoint=ckpt)
    else:
        net = Generator(
            G_res,
            latent_dim,
            n_mlp,
            channel_multiplier=channel_multiplier,
            constant_input=not noconst,
            checkpoint=ckpt,
            output_size=out_size,
        )
    net = net.cuda()
    if dataparallel:
        net = th.nn.DataParallel(net)
    return net
Example #3
0
    # Derive a run name from the input path (basename without extension).
    args.name = os.path.splitext(os.path.basename(args.path))[0]

    # WORLD_SIZE is set by torch.distributed launchers; absent means 1 GPU.
    args.num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    th.backends.cudnn.benchmark = args.cudnn_benchmark
    args.distributed = args.num_gpus > 1

    if args.distributed:
        # One process per GPU: pin this process to its device, then join the
        # NCCL process group and wait for all ranks to arrive.
        th.cuda.set_device(args.local_rank)
        th.distributed.init_process_group(backend="nccl", init_method="env://")
        synchronize()

    generator = Generator(
        args.size,
        args.latent_size,
        args.n_mlp,
        channel_multiplier=args.channel_multiplier,
        constant_input=args.constant_input,
    ).to(device)
    discriminator = Discriminator(args.size, channel_multiplier=args.channel_multiplier).to(device)

    if args.log_spec_norm:
        # Attach spectral-norm tracking to every weight tensor with more than
        # one non-singleton dimension. The dotted parameter name is walked
        # attribute-by-attribute to reach the module that owns the weight.
        for name, parameter in generator.named_parameters():
            if "weight" in name and parameter.squeeze().dim() > 1:
                mod = generator
                for attr in name.replace(".weight", "").split("."):
                    mod = getattr(mod, attr)
                validation.track_spectral_norm(mod)
        for name, parameter in discriminator.named_parameters():
            if "weight" in name and parameter.squeeze().dim() > 1:
                mod = discriminator
Example #4
0
    parser.add_argument("--batch", type=int, default=32)
    parser.add_argument("--num_frames", type=int, default=150)
    parser.add_argument("--duration", type=int, default=5)
    # BUG FIX: type=bool treats ANY non-empty string -- including "False" --
    # as True. Parse common truthy spellings explicitly instead; the default
    # stays False and "--const True" keeps working as before.
    parser.add_argument("--const",
                        type=lambda s: str(s).lower() in ("true", "1", "yes"),
                        default=False)
    parser.add_argument("--channel_multiplier", type=int, default=2)
    # BUG FIX: was type=int with default=1.5 -- passing "--truncation 1.5"
    # raised ValueError (int("1.5")). Truncation is a float; cf. the sibling
    # script that declares it as type=float.
    parser.add_argument("--truncation", type=float, default=1.5)

    args = parser.parse_args()
    # Fixed architecture hyper-parameters for this config.
    args.latent = 512
    args.n_mlp = 8

    # Build the generator from the parsed settings and restore its checkpoint.
    generator = Generator(
        args.G_res,
        args.latent,
        args.n_mlp,
        channel_multiplier=args.channel_multiplier,
        constant_input=args.const,
        checkpoint=args.ckpt,
        output_size=args.out_size,
    )
    # Wrap in DataParallel so forward passes are spread across all GPUs.
    generator = th.nn.DataParallel(generator.cuda())

    with th.no_grad():
        # One random z vector per output frame, mapped into latent (W) space.
        latents = th.randn((args.num_frames, 512)).cuda()
        latents = generator(latents, map_latents=True).cpu().numpy()
        # Gaussian-smooth along the frame axis only (sigma = 5 frames, 0 on
        # the other axes) so the latent walk changes gradually over time.
        latents = ndi.gaussian_filter(latents, [5, 0, 0])
        latents = th.from_numpy(latents).cuda()
        print("latent shape: ", latents.shape)

        noise = [
            np.random.normal(
Example #5
0
    parser.add_argument("--size", type=int, default=1024)
    parser.add_argument("--sample", type=int, default=1)
    parser.add_argument("--pics", type=int, default=20)
    parser.add_argument("--truncation", type=float, default=1)
    parser.add_argument("--truncation_mean", type=int, default=4096)
    parser.add_argument("--ckpt",
                        type=str,
                        default="stylegan2-ffhq-config-f.pt")
    parser.add_argument("--channel_multiplier", type=int, default=2)

    args = parser.parse_args()

    # Fixed architecture hyper-parameters for this config.
    args.latent = 512
    args.n_mlp = 8

    g_ema = Generator(args.size,
                      args.latent,
                      args.n_mlp,
                      channel_multiplier=args.channel_multiplier).to(device)
    # NOTE(review): torch.load without map_location assumes the checkpoint's
    # original device is available here -- consider map_location=device.
    checkpoint = torch.load(args.ckpt)

    # Sample with the exponential-moving-average generator weights.
    g_ema.load_state_dict(checkpoint["g_ema"])

    if args.truncation < 1:
        # Mean latent (over truncation_mean samples) to truncate towards.
        with torch.no_grad():
            mean_latent = g_ema.mean_latent(args.truncation_mean)
    else:
        mean_latent = None

    generate(args, g_ema, device, mean_latent)
Example #6
0
            # WORLD_SIZE is set by torch.distributed launchers; absent -> 1 GPU.
            args.num_gpus = int(
                os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
            th.backends.cudnn.benchmark = args.cudnn_benchmark
            args.distributed = args.num_gpus > 1

            if args.distributed:
                # Pin this process to its GPU, join the NCCL process group,
                # and wait for all ranks to arrive.
                th.cuda.set_device(args.local_rank)
                th.distributed.init_process_group(backend="nccl",
                                                  init_method="env://")
                synchronize()

            generator = Generator(
                args.size,
                args.latent_size,
                args.n_mlp,
                channel_multiplier=args.channel_multiplier,
                constant_input=args.constant_input,
                min_rgb_size=args.min_rgb_size,
            ).to(device, non_blocking=True)
            discriminator = Discriminator(
                args.size,
                channel_multiplier=args.channel_multiplier,
                use_skip=args.d_skip).to(device, non_blocking=True)

            if args.log_spec_norm:
                # Attach spectral-norm tracking to every weight tensor with
                # more than one non-singleton dimension; the dotted parameter
                # name is walked attribute-by-attribute to the owning module.
                for name, parameter in generator.named_parameters():
                    if "weight" in name and parameter.squeeze().dim() > 1:
                        mod = generator
                        for attr in name.replace(".weight", "").split("."):
                            mod = getattr(mod, attr)
                        validation.track_spectral_norm(mod)