Example #1
def create_blended_model(
    ctx: click.Context,
    lower_res_pkl: str,
    higher_res_pkl: str,
    model_res: Optional[int],
    resolution: Optional[int],
    out: Optional[str],
):

    G_kwargs = dnnlib.EasyDict()

    with dnnlib.util.open_url(lower_res_pkl) as f:
        lo = legacy.load_network_pkl(f, custom=False, **G_kwargs)  # type: ignore
        lo_G, lo_D, lo_G_ema = lo['G'], lo['D'], lo['G_ema']

    with dnnlib.util.open_url(higher_res_pkl) as f:
        hi = legacy.load_network_pkl(f, custom=False, **G_kwargs)['G_ema']  # type: ignore

    model_out = blend_models(lo_G_ema, hi, model_res, resolution)
    # for n in model_out.named_parameters():
    #     print(n[0])

    # Keep the low-res G and D alongside the blended G_ema so the saved
    # pickle keeps the usual {'G', 'D', 'G_ema'} layout.
    data = {'G': lo_G, 'D': lo_D, 'G_ema': model_out}
    with open(out, 'wb') as f:
        pickle.dump(data, f)
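
blend_models itself is not shown in this example. A minimal sketch of what such a helper typically does (an assumption, not the source implementation), using the "synthesis.b{res}" block naming seen elsewhere in these snippets: start from the high-res network and overwrite its coarse blocks with the low-res weights up to the cutoff resolution.

import copy

def blend_models_sketch(lo_G_ema, hi_G_ema, model_res, resolution):
    # Start from the high-res model and overwrite its coarse synthesis blocks.
    blended = copy.deepcopy(hi_G_ema)
    lo_state = lo_G_ema.state_dict()
    state = blended.state_dict()
    for name in state:
        if name.startswith('synthesis.b'):
            block_res = int(name.split('.')[1][1:])  # "synthesis.b32.conv0..." -> 32
            if block_res <= resolution and name in lo_state:
                state[name] = lo_state[name]
    blended.load_state_dict(state)
    return blended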
Example #2
    def __init__(self,
                 domain,
                 ckpt_path=None,
                 load_encoder=False,
                 device='cuda'):
        from . import perturb_settings
        import sys
        sys.path.append('resources/stylegan2-ada-pytorch')
        import legacy
        import dnnlib
        import click

        assert domain == 'cifar10'
        network_pkl = {
            'cifar10':
            'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/cifar10.pkl',
        }[domain]
        device = torch.device(device)
        with dnnlib.util.open_url(network_pkl) as f:
            G = legacy.load_network_pkl(f)['G_ema'].to(device)  # type: ignore
        G = G.eval()
        self.generator = G
        self.encoder = None  # no pretrained encoder for cifar10 model
        self.device = device
        self.perturb_settings = perturb_settings.stylegan2_cc_settings[domain]
        self.pca_stats = None
        self.cc_mean_w = None  # mean w latent per class
Example #3
def convert(network_pkl, output_file):
    with dnnlib.util.open_url(network_pkl) as f:
        G_nvidia = legacy.load_network_pkl(f)["G_ema"]

    state_nv = G_nvidia.state_dict()
    n_mapping, n_layers = determine_config(state_nv)

    state_ros = {}

    for i in range(n_mapping):
        state_ros[f"style.{i+1}.weight"] = state_nv[f"mapping.fc{i}.weight"]
        state_ros[f"style.{i+1}.bias"] = state_nv[f"mapping.fc{i}.bias"]

    for i in range(int(n_layers)):
        if i > 0:
            for conv_level in range(2):
                convert_conv(state_ros, state_nv, f"convs.{2*i-2+conv_level}", f"synthesis.b{4*(2**i)}.conv{conv_level}")
                state_ros[f"noises.noise_{2*i-1+conv_level}"] = state_nv[f"synthesis.b{4*(2**i)}.conv{conv_level}.noise_const"].unsqueeze(0).unsqueeze(0)

            convert_to_rgb(state_ros, state_nv, f"to_rgbs.{i-1}", f"synthesis.b{4*(2**i)}")
            convert_blur_kernel(state_ros, state_nv, i-1)
        
        else:
            state_ros[f"input.input"] = state_nv[f"synthesis.b{4*(2**i)}.const"].unsqueeze(0)
            convert_conv(state_ros, state_nv, "conv1", f"synthesis.b{4*(2**i)}.conv1")
            state_ros[f"noises.noise_{2*i}"] = state_nv[f"synthesis.b{4*(2**i)}.conv1.noise_const"].unsqueeze(0).unsqueeze(0)
            convert_to_rgb(state_ros, state_nv, "to_rgb1", f"synthesis.b{4*(2**i)}")

    state_dict = {"g_ema": state_ros}
    torch.save(state_dict, output_file)
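
determine_config is likewise not shown. A plausible sketch, assuming it probes the NVlabs state dict for the number of mapping layers and synthesis blocks:

import math

def determine_config(state_nv):
    # Count mapping layers by probing "mapping.fc{i}.weight" keys.
    n_mapping = 0
    while f"mapping.fc{n_mapping}.weight" in state_nv:
        n_mapping += 1
    # Infer synthesis depth from the largest "synthesis.b{res}" block present:
    # b4 is layer 0 in the loop above, b8 is layer 1, and so on.
    max_res = max(int(k.split(".")[1][1:]) for k in state_nv if k.startswith("synthesis.b"))
    n_layers = int(math.log2(max_res)) - 1
    return n_mapping, n_layers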
Example #4
    def __init__(self, model_path, stylegan_dir, truncation_psi=0.5):
        super().__init__()
        sys.path.insert(1, stylegan_dir)
        import legacy
        with open(model_path, 'rb') as f:
            self.G = legacy.load_network_pkl(f)['G_ema'].cuda().eval()
        self.truncation = truncation_psi
Example #5
    def _prepare(self) -> None:
        """Preparing :py:class:`StyleGAN` includes importing the
        required modules and loading model data.
        """
        super()._prepare()

        # Step 1: Import modules from stylegan repository
        print("importing stylegan2ada from "
              f"'{config.nvlabs_stylegan2ada_pytorch_directory}'")
        nvlabs_directory = str(config.nvlabs_stylegan2ada_pytorch_directory)
        if nvlabs_directory not in sys.path:
            sys.path.insert(0, nvlabs_directory)

        global dnnlib
        import legacy  # This requires torch 1.7.1
        import dnnlib

        sys.path.remove(nvlabs_directory)

        self._device = torch.device('cuda')
        network_pkl = \
            str(config.nvlabs_stylegan2ada_pytorch_directory / "metfaces.pkl")

        with dnnlib.util.open_url(network_pkl) as f:
            network_dict = legacy.load_network_pkl(f)

        print(list(network_dict))
        # ['G', 'D', 'G_ema', 'training_set_kwargs', 'augment_pipe']
        self._generator = network_dict['G_ema'].to(self._device)
        self._discriminator = network_dict['D'].to(self._device)
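
Once _prepare() has run, the loaded generator samples like any StyleGAN2-ADA G_ema; a minimal sketch (inside the class; MetFaces is unconditional, so the label vector is empty):

G = self._generator
z = torch.randn([1, G.z_dim], device=self._device)
c = torch.zeros([1, G.c_dim], device=self._device)
img = G(z, c, truncation_psi=0.7, noise_mode='const')  # NCHW float in roughly [-1, 1]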
Example #6
def run_projection(input_image):
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Load networks.
    print('Loading networks from "%s"...' % checkpoint_path)
    device = torch.device('cuda')
    with dnnlib.util.open_url(checkpoint_path) as fp:
        G = legacy.load_network_pkl(fp)['G_ema'].requires_grad_(False).to(device)  # type: ignore

    # Load target image.
    target_pil = PIL.Image.open(input_image).convert('RGB')
    w, h = target_pil.size
    s = min(w, h)
    target_pil = target_pil.crop(((w - s) // 2, (h - s) // 2, (w + s) // 2, (h + s) // 2))
    target_pil = target_pil.resize((G.img_resolution, G.img_resolution), PIL.Image.LANCZOS)
    target_uint8 = np.array(target_pil, dtype=np.uint8)

    # Optimize projection.
    projected_w_steps = project(
        G,
        target=torch.tensor(target_uint8.transpose([2, 0, 1]), device=device),  # pylint: disable=not-callable
        num_steps=num_steps,
        device=device,
        verbose=True
    )

    os.makedirs(output_dir, exist_ok=True)

    projected_w = projected_w_steps[-1]
    file_name = input_image.split('/')[-1].split('.')[0]
    np.savez(f'{output_dir}/{file_name}.npz', w=projected_w.unsqueeze(0).cpu().numpy())
Example #7
def generate_images(network_pkl: str, seeds: List[int], truncation_psi: float,
                    noise_mode: str, outdir: str, translate: Tuple[float, float],
                    rotate: float, class_idx: Optional[int]):
    """Generate images using pretrained network pickle.

    Examples:

    \b
    # Generate an image using pre-trained AFHQv2 model ("Ours" in Figure 1, left).
    python gen_images.py --outdir=out --trunc=1 --seeds=2 \\
        --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl

    \b
    # Generate uncurated images with truncation using the MetFaces-U dataset
    python gen_images.py --outdir=out --trunc=0.7 --seeds=600-605 \\
        --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-metfacesu-1024x1024.pkl
    """

    print('Loading networks from "%s"...' % network_pkl)
    device = torch.device('cpu')
    with dnnlib.util.open_url(network_pkl) as f:
        G = legacy.load_network_pkl(f)['G_ema'].to(device)  # type: ignore

    os.makedirs(outdir, exist_ok=True)

    # Labels.
    label = torch.zeros([1, G.c_dim], device=device)
    if G.c_dim != 0:
        if class_idx is None:
            raise click.ClickException(
                'Must specify class label with --class when using a conditional network'
            )
        label[:, class_idx] = 1
    else:
        if class_idx is not None:
            print(
                'warn: --class=lbl ignored when running on an unconditional network'
            )

    # Generate images.
    for seed_idx, seed in enumerate(seeds):
        print('Generating image for seed %d (%d/%d) ...' %
              (seed, seed_idx, len(seeds)))
        z = torch.from_numpy(np.random.RandomState(seed).randn(
            1, G.z_dim)).to(device).float()

        # Construct an inverse rotation/translation matrix and pass to the generator.  The
        # generator expects this matrix as an inverse to avoid potentially failing numerical
        # operations in the network.
        if hasattr(G.synthesis, 'input'):
            m = make_transform(translate, rotate)
            m = np.linalg.inv(m)
            G.synthesis.input.transform.copy_(torch.from_numpy(m))

        img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode)
        img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(
            torch.uint8)
        PIL.Image.fromarray(img[0].cpu().numpy(),
                            'RGB').save(f'{outdir}/seed{seed:04d}.png')
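
The expression (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255) recurs throughout these examples: the generator emits NCHW float tensors in roughly [-1, 1], and this maps them to NHWC uint8 images. As a standalone helper:

def to_uint8_hwc(img):
    # NCHW float in [-1, 1] -> NHWC uint8 in [0, 255]
    return (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)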
Example #8
def generate_gif(network_pkl: str, seed: int, num_rows: int, num_cols: int,
                 resolution: int, num_phases: int, transition_frames: int,
                 static_frames: int, truncation_psi: float, noise_mode: str,
                 output: str):
    """Generate gif using pretrained network pickle.

    Examples:

    \b
    python generate_gif.py --output=obama.gif --seed=0 --num-rows=1 --num-cols=8 \\
        --network=https://hanlab.mit.edu/projects/data-efficient-gans/models/DiffAugment-stylegan2-100-shot-obama.pkl
    """
    print('Loading networks from "%s"...' % network_pkl)
    device = torch.device('cuda')
    with dnnlib.util.open_url(network_pkl) as f:
        G = legacy.load_network_pkl(f)['G_ema'].to(device)  # type: ignore

    os.makedirs(os.path.dirname(output), exist_ok=True)

    np.random.seed(seed)

    output_seq = []
    batch_size = num_rows * num_cols
    latent_size = G.z_dim
    latents = [
        np.random.randn(batch_size, latent_size) for _ in range(num_phases)
    ]

    def to_image_grid(outputs):
        outputs = np.reshape(outputs, [num_rows, num_cols, *outputs.shape[1:]])
        outputs = np.concatenate(outputs, axis=1)
        outputs = np.concatenate(outputs, axis=1)
        return Image.fromarray(outputs).resize(
            (resolution * num_cols, resolution * num_rows), Image.ANTIALIAS)

    def generate(dlatents):
        images = G.synthesis(dlatents, noise_mode=noise_mode)
        images = (images.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(
            torch.uint8).cpu().numpy()
        return to_image_grid(images)

    for i in range(num_phases):
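        # At i == 0, latents[i - 1] wraps around to the last phase, so the gif loops seamlessly.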
        dlatents0 = G.mapping(
            torch.from_numpy(latents[i - 1]).to(device), None)
        dlatents1 = G.mapping(torch.from_numpy(latents[i]).to(device), None)
        for j in range(transition_frames):
            dlatents = (dlatents0 * (transition_frames - j) +
                        dlatents1 * j) / transition_frames
            output_seq.append(generate(dlatents))
        output_seq.extend([generate(dlatents1)] * static_frames)

    if not output.endswith('.gif'):
        output += '.gif'
    output_seq[0].save(output,
                       save_all=True,
                       append_images=output_seq[1:],
                       optimize=False,
                       duration=50,
                       loop=0)
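
Note that Image.ANTIALIAS was removed in Pillow 10; on current Pillow versions the resize in to_image_grid needs the equivalent LANCZOS filter:

        return Image.fromarray(outputs).resize(
            (resolution * num_cols, resolution * num_rows), Image.LANCZOS)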
Example #9
def run_projection(
    network_pkl: str,
    target_fname: str,
    outdir: str,
    save_video: bool,
    seed: int,
    num_steps: int
):
    """Project given image to the latent space of pretrained network pickle.

    Examples:

    \b
    python projector.py --outdir=out --target=~/mytargetimg.png \\
        --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl
    """

    np.random.seed(seed)
    torch.manual_seed(seed)

    # Load networks.
    print('Loading networks from "%s"...' % network_pkl)
    device = torch.device('cuda')
    with dnnlib.util.open_url(network_pkl) as fp:
        G = legacy.load_network_pkl(fp)['G_ema'].requires_grad_(False).to(device) # type: ignore

    # Load target image.
    target_pil = PIL.Image.open(os.path.expanduser(target_fname)).convert('RGB')
    w, h = target_pil.size
    s = min(w, h)
    target_pil = target_pil.crop(((w - s) // 2, (h - s) // 2, (w + s) // 2, (h + s) // 2))
    target_pil = target_pil.resize((G.img_resolution, G.img_resolution), PIL.Image.LANCZOS)
    target_uint8 = np.array(target_pil, dtype=np.uint8)

    # Optimize projection.
    start_time = perf_counter()
    projected_w_steps = project(
        G,
        target=torch.tensor(target_uint8.transpose([2, 0, 1]), device=device), # pylint: disable=not-callable
        num_steps=num_steps,
        device=device,
        verbose=True
    )
    print(f'Elapsed: {(perf_counter()-start_time):.1f} s')

    # Render debug output: optional video and projected image and W vector.
    os.makedirs(outdir, exist_ok=True)

    # Save final projected frame and W vector.
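    # The filename tag: text after the last '_', with the 4-character extension ('.png') stripped.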
    phoneme_characters = target_fname.split('_')[-1][:-4]
    # target_pil.save(f'{outdir}/target_{phoneme_characters}.png')
    projected_w = projected_w_steps[-1]
    synth_image = G.synthesis(projected_w.unsqueeze(0), noise_mode='const')
    synth_image = (synth_image + 1) * (255/2)
    synth_image = synth_image.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
    PIL.Image.fromarray(synth_image, 'RGB').save(f'{outdir}/proj_{phoneme_characters}.png')
    # np.savez(f'{outdir}/projected_w.npz', w=projected_w.unsqueeze(0).cpu().numpy())
    ls_v = projected_w.unsqueeze(0).cpu().numpy()[:, 1, :].flatten()
    return ls_v
Example #10
def generate_images(
    ctx: click.Context,
    network_pkl: str,
    truncation_psi: float,
    external_truncation_psi: float,
    noise_mode: str,
    num_images: int,
    batch_size: int,
    seed: int,
    outdir: str,
):
    print('Loading networks from "%s"...' % network_pkl)
    device = torch.device('cuda')
    with dnnlib.util.open_url(network_pkl) as f:
        G = legacy.load_network_pkl(f)['G_ema'].to(
            device).eval()  # type: ignore

    os.makedirs(outdir, exist_ok=True)

    if external_truncation_psi < 1:
        z = torch.randn(10000, G.z_dim, device=device)  # [10000, z_dim]
        w = G.mapping(z, None)  # [10000, num_ws, w_dim]
        w_avg_ext = w[:, 0].mean(dim=0,
                                 keepdim=True).unsqueeze(1)  # [1, 1, w_dim]

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Generate images.
    for batch_idx in tqdm(range((num_images + batch_size - 1) // batch_size)):
        z = torch.randn(batch_size, G.z_dim,
                        device=device)  # [batch_size, z_dim]
        w = G.mapping(
            z, None,
            truncation_psi=truncation_psi)  # [batch_size, num_ws, w_dim]
        if external_truncation_psi < 1:
            w = (1 - external_truncation_psi
                 ) * w_avg_ext + external_truncation_psi * w
        imgs = G.synthesis(w, noise_mode=noise_mode)

        # z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device)
        # img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode)
        imgs = (imgs.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(
            torch.uint8)
        for i, img in enumerate(imgs):
            image_num = batch_idx * batch_size + i
            if image_num >= num_images:
                break
            PIL.Image.fromarray(
                img.cpu().numpy(),
                'RGB').save(f'{outdir}/img_{image_num:06d}.png')
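
The "external truncation" above is the standard truncation trick applied a second time against a freshly estimated mean latent: truncation is just a lerp toward the average w. As a one-line helper:

def truncate(w, w_avg, psi):
    # psi = 1 leaves w unchanged; psi = 0 collapses to the mean latent.
    return w_avg + psi * (w - w_avg)  # identical to (1 - psi) * w_avg + psi * w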
Example #11
def generate_images(
    network_pkl: str,
    latent_vectors: Optional[List[int]],
    truncation_psi: float,
    noise_mode: str,
    outdir: str,
    class_idx: Optional[int],
):
    # Remove stale .png renders from previous runs.
    if os.path.isdir(outdir):
        for file in os.listdir(outdir):
            if file.endswith('.png'):
                os.remove(os.path.join(outdir, file))
    print('Loading networks from "%s"...' % network_pkl)
    device = torch.device('cuda')
    with dnnlib.util.open_url(network_pkl) as f:
        G = legacy.load_network_pkl(f)['G_ema'].to(device)  # type: ignore

    os.makedirs(outdir, exist_ok=True)

    label = torch.zeros([1, G.c_dim], device=device)
    if G.c_dim != 0:
        if class_idx is None:
            raise ValueError('Must specify a class index for a conditional network')
        label[:, class_idx] = 1
    else:
        if class_idx is not None:
            print(
                'warn: --class=lbl ignored when running on an unconditional network'
            )

    counter = 0
    for latent_vector in latent_vectors:
        # Generate images: repeat the single w vector across all 16 style layers.
        w_row = np.array(latent_vector).reshape(1, latent_vector.shape[0])

        v = w_row.flatten()
        for i in range(15):
            v = np.vstack([v, w_row.flatten()])

        v = v.reshape(1, v.shape[0], v.shape[1])  # [1, 16, w_dim]
        w = torch.from_numpy(v).to(device)
        img = G.synthesis(w, noise_mode=noise_mode)
        img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(
            torch.uint8)
        PIL.Image.fromarray(img[0].cpu().numpy(),
                            'RGB').save(f'{outdir}/seed{counter:04d}.png')

        # z = torch.from_numpy(input).to(device)
        # img = G(z, label, noise_mode=noise_mode) # truncation_psi=truncation_psi,
        # img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
        # PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/seed{counter}2.png')

        counter += 1
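
An equivalent construction with np.tile makes the intent of the vstack loop clearer (assuming 16 style layers, as hard-coded above):

w = np.tile(np.asarray(latent_vector).reshape(1, 1, -1), (1, 16, 1))  # [1, num_ws, w_dim]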
Example #12
def generate_images(
    network_pkl: str,
    seeds: List[int],
    shuffle_seed: Optional[int],
    truncation_psi: float,
    grid: Tuple[int, int],
    num_keyframes: Optional[int],
    w_frames: int,
    output: str,
    class_idx: Optional[int],
):
    """Render a latent vector interpolation video.

    Examples:

    \b
    # Render a 4x2 grid of interpolations for seeds 0 through 31.
    python gen_video.py --output=lerp.mp4 --trunc=1 --seeds=0-31 --grid=4x2 \\
        --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl

    Animation length and seed keyframes:

    The animation length is either determined based on the --seeds value or explicitly
    specified using the --num-keyframes option.

    When num keyframes is specified with --num-keyframes, the output video length
    will be 'num_keyframes*w_frames' frames.

    If --num-keyframes is not specified, the number of seeds given with
    --seeds must be divisible by grid size W*H (--grid).  In this case the
    output video length will be '# seeds/(w*h)*w_frames' frames.
    """

    print('Loading networks from "%s"...' % network_pkl)
    device = torch.device('cuda')
    with dnnlib.util.open_url(network_pkl) as f:
        G = legacy.load_network_pkl(f)['G_ema'].to(device)  # type: ignore

    gen_interp_video(G=G,
                     mp4=output,
                     bitrate='12M',
                     grid_dims=grid,
                     num_keyframes=num_keyframes,
                     w_frames=w_frames,
                     seeds=seeds,
                     shuffle_seed=shuffle_seed,
                     psi=truncation_psi,
                     class_idx=class_idx)
Example #13
def main(pkl: str,
         psi: float,
         radius_large: float,
         radius_small: float,
         step1: float,
         step2: float,
         seed: Optional[int],
         video_length: float = 1.0,
         size: Optional[int] = None,
         seeds: Optional[int] = None,
         scale_type: str = 'pad'):

    if size:
        print('render custom size:', size)
        print('padding method:', scale_type)
        custom = True
    else:
        custom = False

    G_kwargs = dnnlib.EasyDict()
    G_kwargs.size = size
    G_kwargs.scale_type = scale_type

    print('Loading networks from "%s"...' % pkl)
    device = torch.device('cuda')
    with dnnlib.util.open_url(pkl) as f:
        G = legacy.load_network_pkl(f, custom=custom, **G_kwargs)['G_ema'].to(
            device)  # type: ignore

    frames = generate_from_generator_adaptive(psi, radius_large, radius_small,
                                              step1, step2, video_length, seed,
                                              seeds, G, device)
    frames = moviepy.editor.ImageSequenceClip(frames, fps=30)

    # Generate video at the current date and timestamp
    timestamp = datetime.now().strftime("%d-%m-%Y-%I-%M-%S-%p")
    mp4_file = './circular-' + timestamp + '.mp4'
    mp4_codec = 'libx264'
    mp4_bitrate = '15M'
    mp4_fps = 24  # 20

    frames.write_videofile(mp4_file,
                           fps=mp4_fps,
                           codec=mp4_codec,
                           bitrate=mp4_bitrate)
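
generate_from_generator_adaptive is not shown; the radius_large/radius_small and step parameters suggest a circular walk through z-space. A hedged sketch of such a path (all names illustrative):

def circular_latents(G, n_frames, radius, seed=0):
    rnd = np.random.RandomState(seed)
    basis = rnd.randn(2, G.z_dim)  # two random directions span a plane in z-space
    angles = np.linspace(0, 2 * np.pi, n_frames)
    return np.stack([radius * (np.cos(t) * basis[0] + np.sin(t) * basis[1]) for t in angles])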
Example #14
def generate_video(network_pkl, seeds, steps, fps, output_filename, outdir):
    if len(seeds) < 2:
        print('Need more than one seed')
        return

    if steps < 1:
        print('At least one step required')
        return

    print('Loading networks from "%s"...' % network_pkl)
    device = torch.device('cuda')
    with dnnlib.util.open_url(network_pkl) as f:
        G = legacy.load_network_pkl(f)['G_ema'].to(device)  # type: ignore
    label = torch.zeros([1, G.c_dim], device=device)

    os.makedirs(outdir, exist_ok=True)

    # Generate the images for the video.
    idx = 0
    for i in range(len(seeds) - 1):
        v1 = seed2vec(G, seeds[i])
        v2 = seed2vec(G, seeds[i + 1])

        diff = v2 - v1
        step = diff / steps
        current = v1.copy()

        for j in tqdm(range(steps), desc=f"Seed {seeds[i]}"):
            current = current + step
            z = torch.from_numpy(current).to(device)
            img = G(z, label, noise_mode='const')
            img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(
                torch.uint8)

            PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(
                os.path.join(outdir, f'frame-{idx}.png'))
            idx += 1

    cmd = 'ffmpeg -y -i {}/frame-%d.png -r {} -c:v libx264 -crf 18 -preset medium -pix_fmt yuv420p {}'.format(
        outdir, fps, output_filename)
    os.system(cmd)
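
seed2vec is not defined in these snippets; it presumably maps a seed to a one-sample z batch, e.g.:

def seed2vec(G, seed):
    return np.random.RandomState(seed).randn(1, G.z_dim).astype(np.float32)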
Example #15
def generate_video(network_pkl, seed, fps, output_filename, wav_filename, outdir):
    print('Loading networks from "%s"...' % network_pkl)
    device = torch.device('cuda')
    with dnnlib.util.open_url(network_pkl) as f:
        G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
    label = torch.zeros([1, G.c_dim], device=device)

    os.makedirs(outdir, exist_ok=True)

    beats, total_len = get_beats(wav_filename)
    beats = [int(np.round(fps * x)) for x in beats]

    # Add the first and the last frames to beats array
    if 0 not in beats:
        beats = [0] + beats
    last = int(np.round(fps * total_len))
    if last not in beats:
        beats.append(last)

    idx = 0
    for i in tqdm(range(len(beats)-1)):
        v1 = seed2vec(G, seed)
        seed += 1
        v2 = seed2vec(G, seed)
        diff = v2 - v1

        n_frames = beats[i+1] - beats[i]

        x = np.linspace(0, np.pi, n_frames)
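        # Easing: dy/dx is proportional to 1 - sin(x), largest at the endpoints and
        # near zero mid-way, so motion bursts around each beat and settles in between.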
        y = np.cumsum(1 - np.sin(x))
        y /= y[-1]

        for j in range(n_frames):
            current = v1 + diff * y[j]
            z = torch.from_numpy(current).to(device)
            img = G(z, label, noise_mode='const')
            img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
            PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(os.path.join(outdir, 'frame-{}.png'.format(beats[i] + j)))

    cmd = 'ffmpeg -y -i {}/frame-%d.png -i {} -r {} -c:v libx264 -crf 18 -preset medium -pix_fmt yuv420p {}'.format(outdir, wav_filename, fps, output_filename)
    os.system(cmd)
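
get_beats is not shown either; a typical implementation with librosa (an assumption, not the source's code) returns beat times in seconds plus the clip length:

import librosa

def get_beats(wav_filename):
    y, sr = librosa.load(wav_filename)
    tempo, beat_frames = librosa.beat.beat_track(y=y, sr=sr)
    beats = librosa.frames_to_time(beat_frames, sr=sr)
    total_len = librosa.get_duration(y=y, sr=sr)
    return beats, total_len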
Example #16
def run_approach(
    network_pkl: str,
    outdir: str,
    save_video: bool,
    save_ws: bool,
    seed: int,
    num_steps: int,
    text: str,
    lr: float,
    inf: float,
    nf: float,
    w: str,
    psi: float,
    noise_opt: bool
):
    """Descend on StyleGAN2 w vector value using CLIP, tuning an image with given text prompt.

    Example:

    \b
    python3 approach.py --network network-snapshot-ffhq.pkl --outdir project --num-steps 100  \\
    --text 'an image of a girl with a face resembling Paul Krugman' --psi 0.8 --seed 12345

    """

    np.random.seed(1)
    torch.manual_seed(1)

    # Hash the call parameters so every run gets a reproducible output name.
    local_args = dict(locals())
    params = [{x: local_args[x]} for x in local_args]
    hashname = hashlib.sha1(json.dumps(params).encode('utf-16be')).hexdigest()
    print('run hash', hashname)

    ws = None
    if w is not None:
        print('loading w from file', w, 'ignoring seed and psi')
        ws = np.load(w)['w']

    # take off
    print('Loading networks from "%s"...' % network_pkl)
    device = torch.device('cuda')
    with dnnlib.util.open_url(network_pkl) as fp:
        G = legacy.load_network_pkl(fp)['G_ema'].requires_grad_(False).to(device) # type: ignore

    # approach
    projected_w_steps = approach(
        G,
        num_steps=num_steps,
        device=device,
        initial_learning_rate = lr,
        psi = psi,
        seed = seed,
        initial_noise_factor = inf,
        noise_floor = nf,
        text = text,
        ws = ws,
        noise_opt = noise_opt
    )

    # save video
    os.makedirs(outdir, exist_ok=True)
    if save_video:
        video = imageio.get_writer(f'{outdir}/out-{hashname}.mp4', mode='I', fps=10, codec='libx264', bitrate='16M')
        print(f'Saving optimization progress video "{outdir}/out-{hashname}.mp4"')
        for projected_w in projected_w_steps:
            synth_image = G.synthesis(projected_w.unsqueeze(0), noise_mode='const')
            synth_image = (synth_image + 1) * (255/2)
            synth_image = synth_image.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
            video.append_data(np.concatenate([synth_image], axis=1))
        video.close()

    # save ws
    if save_ws:
        print('Saving optimization progress ws')
        for step, projected_w in enumerate(projected_w_steps):
            np.savez(f'{outdir}/w-{hashname}-{step}.npz', w=projected_w.unsqueeze(0).cpu().numpy())

    # save the result and the final w
    print('Saving finals')
    projected_w = projected_w_steps[-1]
    synth_image = G.synthesis(projected_w.unsqueeze(0), noise_mode='const')
    synth_image = (synth_image + 1) * (255/2)
    synth_image = synth_image.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
    PIL.Image.fromarray(synth_image, 'RGB').save(f'{outdir}/out-{hashname}.png')
    np.savez(f'{outdir}/w-{hashname}-final.npz', w=projected_w.unsqueeze(0).cpu().numpy())


    # save params
    with open(f'{outdir}/params-{hashname}.txt', 'w') as outfile:
        json.dump(params, outfile)
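
approach() itself is not shown. Its core is presumably CLIP-guided gradient descent on w; a minimal sketch under that assumption (names are illustrative, and proper CLIP input normalization is omitted):

import clip
import torch.nn.functional as F

clip_model, _ = clip.load('ViT-B/32', device=device)
text_features = clip_model.encode_text(clip.tokenize([text]).to(device)).detach()

w_opt = w_init.clone().requires_grad_(True)        # [1, num_ws, w_dim] from seed/psi or loaded ws
opt = torch.optim.Adam([w_opt], lr=lr)
for _ in range(num_steps):
    img = G.synthesis(w_opt, noise_mode='const')   # NCHW float in [-1, 1]
    img = F.interpolate((img + 1) / 2, size=(224, 224), mode='bilinear', align_corners=False)
    loss = -F.cosine_similarity(clip_model.encode_image(img), text_features).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()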
Example #17
    device = torch.device('cuda')
    eigvec = torch.load(args.factor)["eigvec"].to(device)
    index = args.index
    seeds = args.seeds

    custom = False

    G_kwargs = dnnlib.EasyDict()
    G_kwargs.size = None
    G_kwargs.scale_type = 'symm'

    print('Loading networks from "%s"...' % args.ckpt)
    with dnnlib.util.open_url(args.ckpt) as f:
        G = legacy.load_network_pkl(f, custom=custom, **G_kwargs)['G_ema'].to(device)  # type: ignore


    os.makedirs(args.output, exist_ok=True)

    label = torch.zeros([1, G.c_dim], device=device) # assume no class label
    noise_mode = "const" # default
    truncation_psi = args.truncation

    latents = []
    mode = "random"
    log_str = ""

    index_list_of_eigenvalues = []
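
The snippet is truncated, but factor/eigvec/index point at SeFa-style editing: pushing a latent along an eigenvector recovered from the factored weights. A hedged sketch of the usual application (degree is illustrative):

degree = 5.0                                   # edit strength (illustrative)
direction = eigvec[:, index].unsqueeze(0)      # [1, z_dim]
z = torch.randn([1, G.z_dim], device=device)
img = G(z + degree * direction, label, truncation_psi=truncation_psi, noise_mode=noise_mode)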
Example #18
def generate():
    os.makedirs(a.out_dir, exist_ok=True)
    np.random.seed(seed=696)
    device = torch.device('cuda')

    # setup generator
    Gs_kwargs = dnnlib.EasyDict()
    Gs_kwargs.verbose = a.verbose
    Gs_kwargs.size = a.size
    Gs_kwargs.scale_type = a.scale_type

    # mask/blend latents with external latmask or by splitting the frame
    if a.latmask is None:
        nHW = [int(s) for s in a.nXY.split('-')][::-1]
        assert len(nHW) == 2, ' Wrong count nXY: %d (must be 2)' % len(nHW)
        n_mult = nHW[0] * nHW[1]
        if a.verbose is True and n_mult > 1:
            print(' Latent blending w/split frame %d x %d' % (nHW[1], nHW[0]))
        lmask = np.tile(np.asarray([[[[1]]]]), (1, n_mult, 1, 1))
        Gs_kwargs.countHW = nHW
        Gs_kwargs.splitfine = a.splitfine
    else:
        if a.verbose is True: print(' Latent blending with mask', a.latmask)
        n_mult = 2
        if os.path.isfile(a.latmask):  # single file
            lmask = np.asarray([[img_read(a.latmask)[:, :, 0] / 255.]
                                ])  # [h,w]
        elif os.path.isdir(a.latmask):  # directory with frame sequence
            lmask = np.asarray([[
                img_read(f)[:, :, 0] / 255. for f in img_list(a.latmask)
            ]])  # [h,w]
        else:
            print(' !! Blending mask not found:', a.latmask)
            exit(1)
        lmask = np.concatenate((lmask, 1 - lmask), 1)  # [frm,2,h,w]
    lmask = torch.from_numpy(lmask).to(device)

    # load base or custom network
    pkl_name = osp.splitext(a.model)[0]
    if '.pkl' in a.model.lower():
        custom = False
        print(' .. Gs from pkl ..', basename(a.model))
    else:
        custom = True
        print(' .. Gs custom ..', basename(a.model))
    with dnnlib.util.open_url(pkl_name + '.pkl') as f:
        Gs = legacy.load_network_pkl(f,
                                     custom=custom, **Gs_kwargs)['G_ema'].to(
                                         device)  # type: ignore

    if a.verbose is True: print(' out shape', Gs.output_shape[1:])

    if a.verbose is True: print(' making timeline..')
    lats = []  # list of [frm,1,512]
    for i in range(n_mult):
        lat_tmp = latent_anima((1, Gs.z_dim),
                               a.frames,
                               a.fstep,
                               cubic=a.cubic,
                               gauss=a.gauss,
                               verbose=False)  # [frm,1,512]
        lats.append(lat_tmp)  # list of [frm,1,512]
    latents = np.concatenate(lats, 1)  # [frm,X,512]
    print(' latents', latents.shape)
    latents = torch.from_numpy(latents).to(device)
    frame_count = latents.shape[0]

    # distort image by tweaking initial const layer
    if a.digress > 0:
        try:
            init_res = Gs.init_res
        except Exception:
            init_res = (4, 4)  # default initial layer size
        dconst = []
        for i in range(n_mult):
            dc_tmp = a.digress * latent_anima([1, Gs.z_dim, *init_res],
                                              a.frames,
                                              a.fstep,
                                              cubic=True,
                                              verbose=False)
            dconst.append(dc_tmp)
        dconst = np.concatenate(dconst, 1)
    else:
        dconst = np.zeros([frame_count, 1, 1, 1, 1])
    dconst = torch.from_numpy(dconst).to(device)

    # labels / conditions
    label_size = Gs.c_dim
    if label_size > 0:
        labels = torch.zeros((frame_count, n_mult, label_size),
                             device=device)  # [frm,X,lbl]
        if a.labels is None:
            label_ids = []
            for i in range(n_mult):
                label_ids.append(random.randint(0, label_size - 1))
        else:
            label_ids = [int(x) for x in a.labels.split('-')]
            label_ids = label_ids[:n_mult]  # ensure we have enough labels
        for i, l in enumerate(label_ids):
            labels[:, i, l] = 1
    else:
        labels = [None]

    # generate images from latent timeline
    pbar = ProgressBar(frame_count)
    for i in range(frame_count):

        latent = latents[i]  # [X,512]
        label = labels[i % len(labels)]
        latmask = lmask[i % len(lmask)] if lmask is not None else [None]  # [X,h,w]
        dc = dconst[i % len(dconst)]  # [X,512,4,4]

        # generate multi-latent result
        if custom:
            output = Gs(latent,
                        label,
                        latmask,
                        dc,
                        truncation_psi=a.trunc,
                        noise_mode='const')
        else:
            output = Gs(latent,
                        label,
                        truncation_psi=a.trunc,
                        noise_mode='const')
        output = (output.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(
            torch.uint8).cpu().numpy()

        # save image
        ext = 'png' if output.shape[3] == 4 else 'jpg'
        filename = osp.join(a.out_dir, "%06d.%s" % (i, ext))
        imsave(filename, output[0])
        pbar.upd()

    # convert latents to dlatents, save them
    if a.save_lat is True:
        latents = latents.squeeze(1)  # [frm,512]
        dlatents = Gs.mapping(latents, label)  # [frm,18,512]
        if a.size is None: a.size = [''] * 2
        filename = '{}-{}-{}.npy'.format(basename(a.model), a.size[1],
                                         a.size[0])
        filename = osp.join(osp.dirname(a.out_dir), filename)
        dlatents = dlatents.cpu().numpy()
        np.save(filename, dlatents)
        print('saved dlatents', dlatents.shape, 'to', filename)
Example #19
def main():
    os.makedirs(a.out_dir, exist_ok=True)
    device = torch.device('cuda')

    # setup generator
    Gs_kwargs = dnnlib.EasyDict()
    Gs_kwargs.verbose = a.verbose
    Gs_kwargs.size = a.size
    Gs_kwargs.scale_type = a.scale_type

    # load base or custom network
    pkl_name = osp.splitext(a.model)[0]
    if '.pkl' in a.model.lower():
        custom = False
        print(' .. Gs from pkl ..', basename(a.model))
    else:
        custom = True
        print(' .. Gs custom ..', basename(a.model))
    with dnnlib.util.open_url(pkl_name + '.pkl') as f:
        Gs = legacy.load_network_pkl(f,
                                     custom=custom, **Gs_kwargs)['G_ema'].to(
                                         device)  # type: ignore

    dlat_shape = (1, Gs.num_ws, Gs.w_dim)  # [1,18,512]

    # read saved latents
    if a.dlatents is not None and osp.isfile(a.dlatents):
        key_dlatents = load_latents(a.dlatents)
        if len(key_dlatents.shape) == 2:
            key_dlatents = np.expand_dims(key_dlatents, 0)
    elif a.dlatents is not None and osp.isdir(a.dlatents):
        # if a.dlatents.endswith('/') or a.dlatents.endswith('\\'): a.dlatents = a.dlatents[:-1]
        key_dlatents = []
        npy_list = file_list(a.dlatents, 'npy')
        for npy in npy_list:
            key_dlatent = load_latents(npy)
            if len(key_dlatent.shape) == 2:
                key_dlatent = np.expand_dims(key_dlatent, 0)
            key_dlatents.append(key_dlatent)
        key_dlatents = np.concatenate(key_dlatents)  # [frm,18,512]
    else:
        print(' No input dlatents found')
        exit()
    key_dlatents = key_dlatents[:, np.newaxis]  # [frm,1,18,512]
    print(' key dlatents', key_dlatents.shape)

    # replace higher layers with single (style) latent
    if a.style_dlat is not None:
        print(' styling with dlatent', a.style_dlat)
        style_dlatent = load_latents(a.style_dlat)
        while len(style_dlatent.shape) < 4:
            style_dlatent = np.expand_dims(style_dlatent, 0)
        # try replacing 5 by other value, less than Gs.num_ws
        key_dlatents[:, :, range(5, Gs.num_ws), :] = \
            style_dlatent[:, :, range(5, Gs.num_ws), :]

    frames = key_dlatents.shape[0] * a.fstep

    dlatents = latent_anima(dlat_shape,
                            frames,
                            a.fstep,
                            key_latents=key_dlatents,
                            cubic=a.cubic,
                            verbose=True)  # [frm,1,18,512]
    print(' dlatents', dlatents.shape)
    frame_count = dlatents.shape[0]
    dlatents = torch.from_numpy(dlatents).to(device)

    # distort image by tweaking initial const layer
    if a.digress > 0:
        try:
            init_res = Gs.init_res
        except Exception:
            init_res = (4, 4)  # default initial layer size
        dconst = a.digress * latent_anima([1, Gs.z_dim, *init_res],
                                          frame_count,
                                          a.fstep,
                                          cubic=True,
                                          verbose=False)
    else:
        dconst = np.zeros([frame_count, 1, 1, 1, 1])
    dconst = torch.from_numpy(dconst).to(device)

    # generate images from latent timeline
    pbar = ProgressBar(frame_count)
    for i in range(frame_count):

        # generate multi-latent result
        if custom:
            output = Gs.synthesis(dlatents[i],
                                  None,
                                  dconst[i],
                                  noise_mode='const')
        else:
            output = Gs.synthesis(dlatents[i], noise_mode='const')
        output = (output.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(
            torch.uint8).cpu().numpy()

        ext = 'png' if output.shape[3] == 4 else 'jpg'
        filename = osp.join(a.out_dir, "%06d.%s" % (i, ext))
        imsave(filename, output[0])
        pbar.upd()
Example #20
def generate_images(
    ctx: click.Context,
    easing: str,
    interpolation: str,
    increment: Optional[float],
    network_pkl: str,
    process: str,
    random_seed: Optional[int],
    diameter: Optional[float],
    seeds: Optional[List[int]],
    space: str,
    fps: Optional[int],
    frames: Optional[int],
    truncation_psi: float,
    noise_mode: str,
    outdir: str,
    class_idx: Optional[int],
    projected_w: Optional[str],
    start: Optional[float],
    stop: Optional[float],
):
    """Generate images using pretrained network pickle.

    Examples:

    \b
    # Generate curated MetFaces images without truncation (Fig.10 left)
    python generate.py --outdir=out --trunc=1 --seeds=85,265,297,849 \\
        --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl

    \b
    # Generate uncurated MetFaces images with truncation (Fig.12 upper left)
    python generate.py --outdir=out --trunc=0.7 --seeds=600-605 \\
        --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl

    \b
    # Generate class conditional CIFAR-10 images (Fig.17 left, Car)
    python generate.py --outdir=out --seeds=0-35 --class=1 \\
        --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/cifar10.pkl

    \b
    # Render an image from projected W
    python generate.py --outdir=out --projected_w=projected_w.npz \\
        --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
    """

    print('Loading networks from "%s"...' % network_pkl)
    device = torch.device('cuda')
    with dnnlib.util.open_url(network_pkl) as f:
        G = legacy.load_network_pkl(f)['G_ema'].to(device)  # type: ignore

    os.makedirs(outdir, exist_ok=True)

    # Synthesize the result of a W projection.
    if (process == 'image') and projected_w is not None:
        if seeds is not None:
            print('Warning: --seeds is ignored when using --projected-w')
        print(f'Generating images from projected W "{projected_w}"')
        ws = np.load(projected_w)['w']
        ws = torch.tensor(ws, device=device)  # pylint: disable=not-callable
        assert ws.shape[1:] == (G.num_ws, G.w_dim)
        for idx, w in enumerate(ws):
            img = G.synthesis(w.unsqueeze(0), noise_mode=noise_mode)
            img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(
                torch.uint8)
            PIL.Image.fromarray(img[0].cpu().numpy(),
                                'RGB').save(f'{outdir}/proj{idx:02d}.png')
        return

    # Labels.
    label = torch.zeros([1, G.c_dim], device=device)
    if G.c_dim != 0:
        if class_idx is None:
            ctx.fail(
                'Must specify class label with --class when using a conditional network'
            )
        label[:, class_idx] = 1
    else:
        if class_idx is not None:
            print(
                'warn: --class=lbl ignored when running on an unconditional network'
            )

    if (process == 'image'):
        if seeds is None:
            ctx.fail('--seeds option is required when not using --projected-w')

        # Generate images.
        for seed_idx, seed in enumerate(seeds):
            print('Generating image for seed %d (%d/%d) ...' %
                  (seed, seed_idx, len(seeds)))
            z = torch.from_numpy(
                np.random.RandomState(seed).randn(1, G.z_dim)).to(device)
            img = G(z,
                    label,
                    truncation_psi=truncation_psi,
                    noise_mode=noise_mode)
            img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(
                torch.uint8)
            PIL.Image.fromarray(img[0].cpu().numpy(),
                                'RGB').save(f'{outdir}/seed{seed:04d}.png')

    elif (process == 'interpolation'):
        # create path for frames
        dirpath = os.path.join(outdir, 'frames')
        os.makedirs(dirpath, exist_ok=True)

        # autogenerate video name: not great!
        if seeds is not None:
            seedstr = '_'.join([str(seed) for seed in seeds])
            vidname = f'{process}-{interpolation}-seeds_{seedstr}-{fps}fps'
        elif interpolation in ('noiseloop', 'circularloop'):
            vidname = f'{process}-{interpolation}-{diameter}dia-seed_{random_seed}-{fps}fps'

        interpolate(G, device, projected_w, seeds, random_seed, space,
                    truncation_psi, label, frames, noise_mode, dirpath,
                    interpolation, easing, diameter)

        # convert to video
        cmd = f'ffmpeg -y -r {fps} -i {dirpath}/frame%04d.png -vcodec libx264 -pix_fmt yuv420p {outdir}/{vidname}.mp4'
        subprocess.call(cmd, shell=True)

    elif (process == 'truncation'):
        if seeds is None or (len(seeds) > 1):
            ctx.fail('truncation requires a single seed value')

        # create path for frames
        dirpath = os.path.join(outdir, 'frames')
        os.makedirs(dirpath, exist_ok=True)

        #vidname
        seed = seeds[0]
        vidname = f'{process}-seed_{seed}-start_{start}-stop_{stop}-inc_{increment}-{fps}fps'

        # generate frames
        truncation_traversal(G, device, seeds, label, start, stop, increment,
                             noise_mode, dirpath)

        # convert to video
        cmd = f'ffmpeg -y -r {fps} -i {dirpath}/frame%04d.png -vcodec libx264 -pix_fmt yuv420p {outdir}/{vidname}.mp4'
        subprocess.call(cmd, shell=True)
Example #21
def run_projection(
    network_pkl: str,
    target_fname: str,
    outdir: str,
    save_video: bool,
    seed: int,
    num_steps: int
):
    """Project given image to the latent space of pretrained network pickle.

    Examples:

    \b
    python projector.py --outdir=out --target=~/mytargetimg.png \\
        --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl
    """
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Load networks.
    print('Loading networks from "%s"...' % network_pkl)
    device = torch.device('cuda')
    with dnnlib.util.open_url(network_pkl) as fp:
        G = legacy.load_network_pkl(fp)['G_ema'].requires_grad_(False).to(device) # type: ignore

    # Load target image.
    target_pil = PIL.Image.open(target_fname).convert('RGB')
    w, h = target_pil.size
    s = min(w, h)
    target_pil = target_pil.crop(((w - s) // 2, (h - s) // 2, (w + s) // 2, (h + s) // 2))
    target_pil = target_pil.resize((G.img_resolution_w, G.img_resolution_h), PIL.Image.LANCZOS)
    target_uint8 = np.array(target_pil, dtype=np.uint8)

    # Optimize projection.
    start_time = perf_counter()
    projected_w_steps = project(
        G,
        target=torch.tensor(target_uint8.transpose([2, 0, 1]), device=device), # pylint: disable=not-callable
        num_steps=num_steps,
        device=device,
        verbose=True
    )
    print(f'Elapsed: {(perf_counter()-start_time):.1f} s')

    # Render debug output: optional video and projected image and W vector.
    os.makedirs(outdir, exist_ok=True)
    if save_video:
        video = imageio.get_writer(f'{outdir}/proj.mp4', mode='I', fps=10, codec='libx264', bitrate='16M')
        print(f'Saving optimization progress video "{outdir}/proj.mp4"')
        for projected_w in projected_w_steps:
            synth_image = G.synthesis(projected_w.unsqueeze(0), noise_mode='const')
            synth_image = (synth_image + 1) * (255/2)
            synth_image = synth_image.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
            video.append_data(np.concatenate([target_uint8, synth_image], axis=1))
        video.close()

    # Save final projected frame and W vector.
    target_pil.save(f'{outdir}/target.png')
    projected_w = projected_w_steps[-1]
    synth_image = G.synthesis(projected_w.unsqueeze(0), noise_mode='const')
    synth_image = (synth_image + 1) * (255/2)
    synth_image = synth_image.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
    PIL.Image.fromarray(synth_image, 'RGB').save(f'{outdir}/proj.png')
    np.savez(f'{outdir}/projected_w.npz', w=projected_w.unsqueeze(0).cpu().numpy())
Example #22
def calc_metrics(ctx, network_pkl, metrics, data, mirror, gpus, verbose):
    """Calculate quality metrics for previous training run or pretrained network pickle.

    Examples:

    \b
    # Previous training run: look up options automatically, save result to JSONL file.
    python calc_metrics.py --metrics=pr50k3_full \\
        --network=~/training-runs/00000-ffhq10k-res64-auto1/network-snapshot-000000.pkl

    \b
    # Pre-trained network pickle: specify dataset explicitly, print result to stdout.
    python calc_metrics.py --metrics=fid50k_full --data=~/datasets/ffhq.zip --mirror=1 \\
        --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl

    Available metrics:

    \b
      ADA paper:
        fid50k_full  Frechet inception distance against the full dataset.
        kid50k_full  Kernel inception distance against the full dataset.
        pr50k3_full  Precision and recall against the full dataset.
        is50k        Inception score for CIFAR-10.

    \b
      StyleGAN and StyleGAN2 papers:
        fid50k       Frechet inception distance against 50k real images.
        kid50k       Kernel inception distance against 50k real images.
        pr50k3       Precision and recall against 50k real images.
        ppl2_wend    Perceptual path length in W at path endpoints against full image.
        ppl_zfull    Perceptual path length in Z for full paths against cropped image.
        ppl_wfull    Perceptual path length in W for full paths against cropped image.
        ppl_zend     Perceptual path length in Z at path endpoints against cropped image.
        ppl_wend     Perceptual path length in W at path endpoints against cropped image.
    """
    dnnlib.util.Logger(should_flush=True)

    # Validate arguments.
    args = dnnlib.EasyDict(metrics=metrics,
                           num_gpus=gpus,
                           network_pkl=network_pkl,
                           verbose=verbose)
    if not all(metric_main.is_valid_metric(metric) for metric in args.metrics):
        ctx.fail(
            '\n'.join(['--metrics can only contain the following values:'] +
                      metric_main.list_valid_metrics()))
    if not args.num_gpus >= 1:
        ctx.fail('--gpus must be at least 1')

    # Load network.
    if not dnnlib.util.is_url(
            network_pkl,
            allow_file_urls=True) and not os.path.isfile(network_pkl):
        ctx.fail('--network must point to a file or URL')
    if args.verbose:
        print(f'Loading network from "{network_pkl}"...')
    with dnnlib.util.open_url(network_pkl, verbose=args.verbose) as f:
        network_dict = legacy.load_network_pkl(f)
        args.G = network_dict['G_ema']  # subclass of torch.nn.Module

    # Initialize dataset options.
    if data is not None:
        args.dataset_kwargs = dnnlib.EasyDict(
            class_name='training.dataset.ImageFolderDataset', path=data)
    elif network_dict['training_set_kwargs'] is not None:
        args.dataset_kwargs = dnnlib.EasyDict(
            network_dict['training_set_kwargs'])
    else:
        ctx.fail('Could not look up dataset options; please specify --data')

    # Finalize dataset options.
    args.dataset_kwargs.resolution_h = args.G.img_resolution_h
    args.dataset_kwargs.resolution_w = args.G.img_resolution_w
    args.dataset_kwargs.use_labels = (args.G.c_dim != 0)
    if mirror is not None:
        args.dataset_kwargs.xflip = mirror

    # Print dataset options.
    if args.verbose:
        print('Dataset options:')
        print(json.dumps(args.dataset_kwargs, indent=2))

    # Locate run dir.
    args.run_dir = None
    if os.path.isfile(network_pkl):
        pkl_dir = os.path.dirname(network_pkl)
        if os.path.isfile(os.path.join(pkl_dir, 'training_options.json')):
            args.run_dir = pkl_dir

    # Launch processes.
    if args.verbose:
        print('Launching processes...')
    torch.multiprocessing.set_start_method('spawn')
    with tempfile.TemporaryDirectory() as temp_dir:
        if args.num_gpus == 1:
            subprocess_fn(rank=0, args=args, temp_dir=temp_dir)
        else:
            torch.multiprocessing.spawn(fn=subprocess_fn,
                                        args=(args, temp_dir),
                                        nprocs=args.num_gpus)
Example #23
def generate_interpolation_video(network_pkl=None,
                                 grid_size=[1, 1],
                                 png_sequence=False,
                                 image_zoom=1,
                                 duration_sec=60.0,
                                 smoothing_sec=1.0,
                                 truncation_psi=1,
                                 noise_mode='const',
                                 filename=None,
                                 mp4_fps=30,
                                 mp4_codec='libx264',
                                 mp4_bitrate='16M',
                                 random_seed=1000,
                                 outdir='./videos'):

    if network_pkl is None:
        print('ERROR: Please enter pkl path.')
        sys.exit(1)
    num_frames = int(np.rint(duration_sec * mp4_fps))
    random_state = np.random.RandomState(random_seed)
    if filename is None:
        filename = get_id_string_for_network_pkl(network_pkl) + '-seed-' + str(
            random_seed)

    print('Loading networks from "%s"...' % network_pkl)
    device = torch.device('cuda')
    with dnnlib.util.open_url(network_pkl) as f:
        G = legacy.load_network_pkl(f)['G_ema'].to(device)  # type: ignore

    print('Generating latent vectors...')
    shape = [num_frames, np.prod(grid_size), G.z_dim]  # [frame, image, component]
    print(shape)
    all_latents = random_state.randn(*shape).astype(np.float32)
    all_latents = scipy.ndimage.gaussian_filter(
        all_latents, [smoothing_sec * mp4_fps] + [0] * (len(shape) - 1),
        mode='wrap')
    all_latents /= np.sqrt(np.mean(np.square(all_latents)))

    print("Rendering...\ntruncation_psi =", truncation_psi, ", noise_mode =",
          noise_mode)
    os.makedirs(outdir, exist_ok=True)

    ###
    ### this is the moviepy implementation of rendering
    ### it has a nice progress bar
    ### there is an imageio implementation commented out below as well
    ###

    # Frame generation func for moviepy.
    # def make_frame(t):
    #     frame_idx = int(np.clip(np.round(t * mp4_fps), 0, num_frames - 1))
    #     z = torch.from_numpy(all_latents[frame_idx]).to(device)
    #     label = np.zeros([z.shape[0], 0], np.float32)
    #     images = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode)
    #     images = (images * 127.5 + 128).clamp(0, 255).to(torch.uint8).cpu().numpy()
    #     grid = create_image_grid(images, grid_size)
    #     if image_zoom > 1:
    #         grid = scipy.ndimage.zoom(grid, [image_zoom, image_zoom, 1], order=0)
    #     if grid.shape[2] == 1:
    #         grid = grid.repeat(3, 2) # grayscale => RGB
    #     return grid

    # # Generate video.
    # import moviepy.editor # pip install moviepy
    # moviepy.editor.VideoClip(make_frame, duration=duration_sec).write_videofile(os.path.join(outdir, filename + ".mp4"), fps=mp4_fps, codec=mp4_codec, bitrate=mp4_bitrate)

    ###
    ### this is an alternative imageio implementation of rendering
    ### I like moviepy more cause it has a nice progress bar
    ### not sure which one is more "performant", feel free to experiment
    ###

    import imageio  # pip install imageio
    video = imageio.get_writer(f'{outdir}/seed{random_seed:04d}.mp4',
                               mode='I',
                               fps=mp4_fps,
                               codec=mp4_codec,
                               bitrate=mp4_bitrate)
    for frame_idx in range(num_frames):
        z = torch.from_numpy(all_latents[frame_idx]).to(device)
        label = torch.zeros([1, G.c_dim], device=device)
        img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode)
        img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8).cpu().numpy()
        grid = create_image_grid(img, grid_size)
        if image_zoom > 1:
            grid = scipy.ndimage.zoom(grid, [image_zoom, image_zoom, 1],
                                      order=0)
        video.append_data(grid)
    video.close()
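
Why the RMS renormalization of all_latents above matters: Gaussian-filtering i.i.d. normal latents along time shrinks their variance, which would act like unintended truncation; dividing by the RMS restores unit scale. A standalone illustration:

import numpy as np
import scipy.ndimage

x = np.random.randn(300, 512).astype(np.float32)
xs = scipy.ndimage.gaussian_filter(x, [30, 0], mode='wrap')
print(xs.std())                            # well below 1.0 after smoothing
xs /= np.sqrt(np.mean(np.square(xs)))
print(xs.std())                            # ~1.0 again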
Example #24
def training_loop(
        run_dir='.',  # Output directory.
        training_set_kwargs={},  # Options for training set.
        data_loader_kwargs={},  # Options for torch.utils.data.DataLoader.
        G_kwargs={},  # Options for generator network.
        D_kwargs={},  # Options for discriminator network.
        G_opt_kwargs={},  # Options for generator optimizer.
        D_opt_kwargs={},  # Options for discriminator optimizer.
        augment_kwargs=None,  # Options for augmentation pipeline. None = disable.
        loss_kwargs={},  # Options for loss function.
        savenames=None,
        random_seed=0,  # Global random seed.
        num_gpus=1,  # Number of GPUs participating in the training.
        rank=0,  # Rank of the current process in [0, num_gpus[.
        batch_size=4,  # Total batch size for one training iteration. Can be larger than batch_gpu * num_gpus.
        batch_gpu=4,  # Number of samples processed at a time by one GPU.
        ema_kimg=10,  # Half-life of the exponential moving average (EMA) of generator weights.
        ema_rampup=None,  # EMA ramp-up coefficient.
        G_reg_interval=4,  # How often to perform regularization for G? None = disable lazy regularization.
        D_reg_interval=16,  # How often to perform regularization for D? None = disable lazy regularization.
        augment_p=0,  # Initial value of augmentation probability.
        ada_target=None,  # ADA target value. None = fixed p.
        ada_interval=4,  # How often to perform ADA adjustment?
        ada_kimg=500,  # ADA adjustment speed: how many kimg it takes for p to increase/decrease by one unit.
        total_kimg=25000,  # Total length of the training, measured in thousands of real images.
        kimg_per_tick=4,  # Progress snapshot interval.
        image_snapshot_ticks=50,  # How often to save image snapshots? None = disable.
        network_snapshot_ticks=50,  # How often to save network snapshots? None = disable.
        resume_pkl=None,  # Network pickle to resume training from.
        resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
        cudnn_benchmark=True,  # Enable torch.backends.cudnn.benchmark?
        abort_fn=None,  # Callback for aborting training. Must return consistent results across ranks.
        progress_fn=None,  # Callback function for updating training progress. Called for all ranks.
        **kwargs
):
    # Initialize.
    start_time = time.time()
    device = torch.device('cuda', rank)
    np.random.seed(random_seed * num_gpus + rank)
    torch.manual_seed(random_seed * num_gpus + rank)
    torch.backends.cudnn.benchmark = cudnn_benchmark  # Improves training speed.
    conv2d_gradfix.enabled = True  # Improves training speed.
    grid_sample_gradfix.enabled = True  # Avoids errors with the augmentation pipe.
    run_id = str(int(time.time()))[3:-2]
    hostname = str(socket.gethostname())

    # Load training set.
    # if rank == 0:
    # print('Loading training set...')
    training_set = dnnlib.util.construct_class_by_name(**training_set_kwargs)  # subclass of training.dataset.Dataset
    training_set_sampler = misc.InfiniteSampler(dataset=training_set, rank=rank, num_replicas=num_gpus,
                                                seed=random_seed)
    training_set_iterator = iter(torch.utils.data.DataLoader(dataset=training_set, sampler=training_set_sampler,
                                                             batch_size=batch_size // num_gpus, **data_loader_kwargs))
    if rank == 0:
        # print('Num images: ', len(training_set))
        print('Image shape:', training_set.image_shape)
        print('Label shape:', training_set.label_shape)
    ext = 'png' if training_set.image_shape[0] == 4 else 'jpg'  # PNG keeps the alpha channel for 4-channel datasets

    # Construct networks.
    if rank == 0:
        print('Constructing networks...')
    common_kwargs = dict(c_dim=training_set.label_dim, img_resolution=training_set.resolution,
                         img_channels=training_set.num_channels)
    G = dnnlib.util.construct_class_by_name(**G_kwargs, **common_kwargs).train().requires_grad_(False).to(
        device)  # subclass of torch.nn.Module
    D = dnnlib.util.construct_class_by_name(**D_kwargs, **common_kwargs).train().requires_grad_(False).to(
        device)  # subclass of torch.nn.Module
    G_ema = copy.deepcopy(G).eval()

    # Resume from existing pickle.
    if (resume_pkl is not None) and (rank == 0):
        # accept a run directory: resume from the latest pickle inside it
        if os.path.isdir(resume_pkl):
            resume_pkl, resume_kimg = locate_latest_pkl(resume_pkl)
        print(' Resuming from "%s", kimg %.3g' % (resume_pkl, resume_kimg))
        with dnnlib.util.open_url(resume_pkl) as f:
            resume_data = legacy.load_network_pkl(f)
        for name, module in [('G', G), ('D', D), ('G_ema', G_ema)]:
            misc.copy_params_and_buffers(resume_data[name], module, require_all=False)

    # Print network summary tables.
    if rank == 0:
        # z = torch.empty([batch_gpu, G.z_dim], device=device)
        # c = torch.empty([batch_gpu, G.c_dim], device=device)
        # img = misc.print_module_summary(G, [z, c])
        # misc.print_module_summary(D, [img, c])
        pass

    # Setup augmentation.
    # if rank == 0:
    # print('Setting up augmentation...')
    augment_pipe = None
    ada_stats = None
    if (augment_kwargs is not None) and (augment_p > 0 or ada_target is not None):
        augment_pipe = dnnlib.util.construct_class_by_name(**augment_kwargs).train().requires_grad_(False).to(
            device)  # subclass of torch.nn.Module
        augment_pipe.p.copy_(torch.as_tensor(augment_p))
        if ada_target is not None:
            ada_stats = training_stats.Collector(regex='Loss/signs/real')

    # Distribute across GPUs.
    # if rank == 0:
    # print(f'Distributing across {num_gpus} GPUs...')
    ddp_modules = dict()
    for name, module in [('G_mapping', G.mapping), ('G_synthesis', G.synthesis), ('D', D), (None, G_ema),
                         ('augment_pipe', augment_pipe)]:
        if (num_gpus > 1) and (module is not None) and len(list(module.parameters())) != 0:
            module.requires_grad_(True)
            module = torch.nn.parallel.DistributedDataParallel(module, device_ids=[device], broadcast_buffers=False)
            module.requires_grad_(False)
        if name is not None:
            ddp_modules[name] = module

    # Setup training phases.
    # if rank == 0:
    # print('Setting up training phases...')
    loss = dnnlib.util.construct_class_by_name(device=device, **ddp_modules,
                                               **loss_kwargs)  # subclass of training.loss.Loss
    phases = []
    for name, module, opt_kwargs, reg_interval in [('G', G, G_opt_kwargs, G_reg_interval),
                                                   ('D', D, D_opt_kwargs, D_reg_interval)]:
        if reg_interval is None:
            opt = dnnlib.util.construct_class_by_name(params=module.parameters(),
                                                      **opt_kwargs)  # subclass of torch.optim.Optimizer
            phases += [dnnlib.EasyDict(name=name + 'both', module=module, opt=opt, interval=1)]
        else:  # Lazy regularization.
            mb_ratio = reg_interval / (reg_interval + 1)
            opt_kwargs = dnnlib.EasyDict(opt_kwargs)
            opt_kwargs.lr = opt_kwargs.lr * mb_ratio
            opt_kwargs.betas = [beta ** mb_ratio for beta in opt_kwargs.betas]
            opt = dnnlib.util.construct_class_by_name(module.parameters(),
                                                      **opt_kwargs)  # subclass of torch.optim.Optimizer
            phases += [dnnlib.EasyDict(name=name + 'main', module=module, opt=opt, interval=1)]
            phases += [dnnlib.EasyDict(name=name + 'reg', module=module, opt=opt, interval=reg_interval)]
    for phase in phases:
        phase.start_event = None
        phase.end_event = None
        if rank == 0:
            phase.start_event = torch.cuda.Event(enable_timing=True)
            phase.end_event = torch.cuda.Event(enable_timing=True)

    # Export sample images.
    grid_size = None
    grid_z = None
    grid_c = None
    if rank == 0:
        # print('Exporting sample images...')
        grid_size, images, labels = setup_snapshot_image_grid(training_set=training_set)
        save_image_grid(images, os.path.join(run_dir, '_reals.' + ext), drange=[0, 255], grid_size=grid_size)
        grid_z = torch.randn([labels.shape[0], G.z_dim], device=device).split(batch_gpu)
        grid_c = torch.from_numpy(labels).to(device).split(batch_gpu)

    # Initialize logs.
    # if rank == 0:
    # print('Initializing logs...')
    stats_collector = training_stats.Collector(regex='.*')

    # Only upload data once, not for every process
    if rank == 0:
        url = f'https://mnr6yzqr22jgywm-adw2.adb.eu-frankfurt-1.oraclecloudapps.com/ords/thesisproject/sg_d/' \
              f'report/{run_id}/{hostname}'
        metadata = dict(training_set_kwargs)  # copy: the updates below must not mutate training_set_kwargs
        more_metadata = [
            {'run_dir': run_dir},
            {'data_loader_kwargs': data_loader_kwargs},
            {'G_kwargs': G_kwargs},
            {'D_kwargs': D_kwargs},
            {'G_opt_kwargs': G_opt_kwargs},
            {'D_opt_kwargs': D_opt_kwargs},
            {'augment_kwargs': augment_kwargs},
            {'loss_kwargs': loss_kwargs},
            {'savenames': savenames},
            {'random_seed': random_seed},
            {'num_gpus': num_gpus},
            {'batch_size': batch_size},
            {'batch_gpu': batch_gpu},
            {'ema_kimg': ema_kimg},
            {'ema_rampup': ema_rampup},
            {'G_reg_interval': G_reg_interval},
            {'D_reg_interval': D_reg_interval},
            {'augment_p': augment_p},
            {'ada_target': ada_target},
            {'ada_interval': ada_interval},
            {'ada_kimg': ada_kimg},
            {'total_kimg': total_kimg},
            {'kimg_per_tick': kimg_per_tick},
            {'image_snapshot_ticks': image_snapshot_ticks},
            {'network_snapshot_ticks': network_snapshot_ticks},
            {'resume_pkl': resume_pkl},
            {'resume_kimg': resume_kimg},
            {'cudnn_benchmark': cudnn_benchmark}
        ]
        for d in more_metadata:
            metadata.update(d)
        requests.post(url, data=json.dumps(metadata))

    stats_jsonl = None
    # stats_tfevents = None
    if rank == 0:
        stats_jsonl = open(os.path.join(run_dir, 'stats.jsonl'), 'wt')

    # Train.
    if rank == 0:
        print(f'Training for {total_kimg} kimg...')
        print()
    cur_nimg = 0
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    # maintenance_time = tick_start_time - start_time
    batch_idx = 0
    if progress_fn is not None:
        progress_fn(0, total_kimg)
    while True:

        # Fetch training data.
        with torch.autograd.profiler.record_function('data_fetch'):
            phase_real_img, phase_real_c = next(training_set_iterator)
            phase_real_img = (phase_real_img.to(device).to(torch.float32) / 127.5 - 1).split(batch_gpu)
            phase_real_c = phase_real_c.to(device).split(batch_gpu)
            all_gen_z = torch.randn([len(phases) * batch_size, G.z_dim], device=device)
            all_gen_z = [phase_gen_z.split(batch_gpu) for phase_gen_z in all_gen_z.split(batch_size)]
            all_gen_c = [training_set.get_label(np.random.randint(len(training_set))) for _ in
                         range(len(phases) * batch_size)]
            all_gen_c = torch.from_numpy(np.stack(all_gen_c)).pin_memory().to(device)
            all_gen_c = [phase_gen_c.split(batch_gpu) for phase_gen_c in all_gen_c.split(batch_size)]

        # Execute training phases.
        for phase, phase_gen_z, phase_gen_c in zip(phases, all_gen_z, all_gen_c):
            if batch_idx % phase.interval != 0:
                continue

            # Initialize gradient accumulation.
            if phase.start_event is not None:
                phase.start_event.record(torch.cuda.current_stream(device))
            phase.opt.zero_grad(set_to_none=True)
            phase.module.requires_grad_(True)

            # Accumulate gradients over multiple rounds.
            for round_idx, (real_img, real_c, gen_z, gen_c) in enumerate(
                    zip(phase_real_img, phase_real_c, phase_gen_z, phase_gen_c)):
                sync = (round_idx == batch_size // (batch_gpu * num_gpus) - 1)
                gain = phase.interval
                loss.accumulate_gradients(phase=phase.name, real_img=real_img, real_c=real_c, gen_z=gen_z, gen_c=gen_c,
                                          sync=sync, gain=gain)

            # Update weights.
            phase.module.requires_grad_(False)
            with torch.autograd.profiler.record_function(phase.name + '_opt'):
                for param in phase.module.parameters():
                    if param.grad is not None:
                        misc.nan_to_num(param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad)
                phase.opt.step()
            if phase.end_event is not None:
                phase.end_event.record(torch.cuda.current_stream(device))

        # Update G_ema.
        with torch.autograd.profiler.record_function('Gema'):
            ema_nimg = ema_kimg * 1000
            if ema_rampup is not None:
                ema_nimg = min(ema_nimg, cur_nimg * ema_rampup)
            ema_beta = 0.5 ** (batch_size / max(ema_nimg, 1e-8))
            for p_ema, p in zip(G_ema.parameters(), G.parameters()):
                p_ema.copy_(p.lerp(p_ema, ema_beta))
            for b_ema, b in zip(G_ema.buffers(), G.buffers()):
                b_ema.copy_(b)

        # Update state.
        cur_nimg += batch_size
        batch_idx += 1

        # Execute ADA heuristic.
        if (ada_stats is not None) and (batch_idx % ada_interval == 0):
            ada_stats.update()
            adjust = np.sign(ada_stats['Loss/signs/real'] - ada_target) * (batch_size * ada_interval) / (
                        ada_kimg * 1000)
            augment_pipe.p.copy_((augment_pipe.p + adjust).max(misc.constant(0, device=device)))

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if (not done) and (cur_tick != 0) and (cur_nimg < tick_start_nimg + kimg_per_tick * 1000):
            continue

        tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
        tick_end_time = time.time()
        total_time = tick_end_time - start_time
        tick_time = tick_end_time - tick_start_time
        left_kimg = total_kimg - cur_nimg / 1000
        left_sec = left_kimg * tick_time / tick_kimg
        finaltime = time.asctime(time.localtime(tick_end_time + left_sec))
        msg_final = ' %ss left till %s ' % (shortime(left_sec), finaltime[11:16])

        # Print status line, accumulating the same information in stats_collector.
        # tick_end_time = time.time()
        fields = []
        fields += [f"tick {training_stats.report0('Progress/tick', cur_tick):<4d}"]
        fields += [f"kimg {training_stats.report0('Progress/kimg', cur_nimg / 1e3):<6.1f}"]
        fields += [f"time {dnnlib.util.format_time(training_stats.report0('Timing/total_sec', total_time)):<8s}"]
        fields += [msg_final]
        fields += [f"min/tick {training_stats.report0('Timing/sec_per_tick', tick_time / 60):<6.3g}"]
        fields += [f"sec/kimg {training_stats.report0('Timing/sec_per_kimg', tick_time / tick_kimg):<7.3g}"]
        fields += [f"cpumem {training_stats.report0('Resources/cpu_mem_gb', psutil.Process(os.getpid()).memory_info().rss / 2 ** 30):<4.1f}"]
        fields += [f"gpumem {training_stats.report0('Resources/peak_gpu_mem_gb', torch.cuda.max_memory_allocated(device) / 2 ** 30):<4.1f}"]
        torch.cuda.reset_peak_memory_stats()
        fields += [f"aug {training_stats.report0('Progress/augment', float(augment_pipe.p.cpu()) if augment_pipe is not None else 0):.3f}"]
        # report the configured learning rates alongside progress
        fields += ["lr %.2g" % training_stats.report0('Progress/G_lrate', G_opt_kwargs.lr)]
        fields += ["%.2g" % training_stats.report0('Progress/D_lrate', D_opt_kwargs.lr)]
        training_stats.report0('Timing/total_hours', total_time / (60 * 60))
        training_stats.report0('Timing/total_days', total_time / (24 * 60 * 60))
        if rank == 0:
            print(' '.join(fields))

        # Check for abort.
        if (not done) and (abort_fn is not None) and abort_fn():
            done = True
            if rank == 0:
                print()
                print('Aborting...')

        # Save image snapshot.
        if (rank == 0) and (image_snapshot_ticks is not None) and (done or cur_tick % image_snapshot_ticks == 0):
            images = torch.cat([G_ema(z=z, c=c, noise_mode='const').cpu() for z, c in zip(grid_z, grid_c)]).numpy()
            save_image_grid(images, os.path.join(run_dir, 'fake-%04d.%s' % (cur_nimg // 1000, ext)), drange=[-1, 1],
                            grid_size=grid_size)

        # Save network snapshot.
        # snapshot_pkl = None
        if (network_snapshot_ticks is not None) and (done or cur_tick % network_snapshot_ticks == 0):
            snapshot_data = dict(training_set_kwargs=dict(training_set_kwargs))
            compact_data = dict(training_set_kwargs=dict(training_set_kwargs))
            for name, module in [('G', G), ('D', D), ('G_ema', G_ema), ('augment_pipe', augment_pipe)]:
                if module is not None:
                    if num_gpus > 1:
                        misc.check_ddp_consistency(module, ignore_regex=r'.*\.w_avg')
                    module = copy.deepcopy(module).eval().requires_grad_(False).cpu()
                snapshot_data[name] = module
                if name == 'G_ema':
                    compact_data[name] = module  # the compact pickle keeps only G_ema
                del module  # conserve memory
            # two pickles per snapshot: full training state and compact G_ema-only
            pkl_snap = os.path.join(run_dir, '%s-%04d.pkl' % (savenames[0], cur_nimg // 1000))
            pkl_run = os.path.join(run_dir, '%s-%04d.pkl' % (savenames[1], cur_nimg // 1000))
            if rank == 0:
                with open(pkl_snap, 'wb') as f:
                    pickle.dump(snapshot_data, f)
                with open(pkl_run, 'wb') as f:
                    pickle.dump(compact_data, f)

        # Evaluate metrics.
        '''
        if (snapshot_data is not None) and (len(metrics) > 0):
            if rank == 0:
                print('Evaluating metrics...')
            for metric in metrics:
                result_dict = metric_main.calc_metric(metric=metric, G=snapshot_data['G_ema'],
                    dataset_kwargs=training_set_kwargs, num_gpus=num_gpus, rank=rank, device=device)
                if rank == 0:
                    metric_main.report_metric(result_dict, run_dir=run_dir, snapshot_pkl=snapshot_pkl)
                stats_metrics.update(result_dict.results)
        '''
        try:
            del snapshot_data  # conserve memory
        except UnboundLocalError:
            pass
        try:
            del compact_data  # conserve memory
        except UnboundLocalError:
            pass

        # Collect statistics.
        for phase in phases:
            value = []
            if (phase.start_event is not None) and (phase.end_event is not None):
                phase.end_event.synchronize()
                value = phase.start_event.elapsed_time(phase.end_event)
            training_stats.report0('Timing/' + phase.name, value)
        stats_collector.update()
        stats_dict = stats_collector.as_dict()

        # Report stats to database
        if rank == 0:
            url = f'https://mnr6yzqr22jgywm-adw2.adb.eu-frankfurt-1.oraclecloudapps.com/ords/thesisproject/sg2/' \
                  f'report/{run_id}/{hostname}'
            requests.post(url, data=json.dumps(stats_dict))

        # Update logs.
        timestamp = time.time()
        if stats_jsonl is not None:
            fields = dict(stats_dict, timestamp=timestamp)
            json_data = json.dumps(fields)

            # Write stats to file
            stats_jsonl.write(json_data + '\n')
            stats_jsonl.flush()
        # if stats_tfevents is not None:
        #     global_step = int(cur_nimg / 1e3)
        #     walltime = timestamp - start_time
        #     for name, value in stats_dict.items():
        #         stats_tfevents.add_scalar(name, value.mean, global_step=global_step, walltime=walltime)
            # for name, value in stats_metrics.items():
            # stats_tfevents.add_scalar(f'Metrics/{name}', value, global_step=global_step, walltime=walltime)
            # stats_tfevents.flush()
        if progress_fn is not None:
            progress_fn(cur_nimg // 1000, total_kimg)

        # Update state.
        cur_tick += 1
        tick_start_nimg = cur_nimg
        tick_start_time = time.time()
        # maintenance_time = tick_start_time - tick_end_time
        if done:
            break

    # Done.
    if rank == 0:
        print()
        print('Exiting...')
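The lazy-regularization branch in the phase setup above runs the regularizer only every reg_interval batches and rescales the shared optimizer to compensate. A standalone sketch of that arithmetic; the lr/betas values are illustrative Adam defaults, not read from this config:

# Standalone check of the lazy-regularization rescaling; values below are
# illustrative assumptions, not taken from a real run config.
reg_interval = 16                              # D regularizer cadence
mb_ratio = reg_interval / (reg_interval + 1)   # 16/17 ~ 0.941
lr, betas = 0.0025, (0.0, 0.99)
lr_adj = lr * mb_ratio                         # ~0.00235
betas_adj = [beta ** mb_ratio for beta in betas]
print(f'lr {lr_adj:.5f}, betas {betas_adj}')   # one optimizer serves 'main' and 'reg'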
def expand_seed(seeds, vector_size):
    result = []
    for seed in seeds:
        rnd = np.random.RandomState(seed)
        result.append(rnd.randn(1, vector_size))
    return result

print("\nLATENT VECTOR WALK VIDEO GENERATOR made by JUSTIN GALLAGHER")
print("DO NOT INCLUDE QUOTATIONS IN ANY FILE OR DIRECTORY PATH NAMES!\n")
pkl_file = input('What is the .pkl file path: ')
save_path = input('Save video to which directory: ')

print(f'Loading networks from "{pkl_file}"...')
device = torch.device('cuda')
with dnnlib.util.open_url(pkl_file) as f:
    G = legacy.load_network_pkl(f)['G_ema'].to(device)

vector_size = G.z_dim
seeds = expand_seed([8192 + 1, 8192 + 9], vector_size)

seed_list = []
seed_num = int(input("How many seeds would you like to input: "))
print('Begin inputting seeds now...')
for i in range(0, seed_num):
    seed = int(input(f'Seed #{i+1}: '))
    seed_list.append(seed)

steps = int(input("Number of steps between seeds: "))

print('Creating temp directory for images...')
temp = os.path.join(save_path, "temp-images")
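The script is cut off after creating the temp directory. Purely as an illustration (not the original author's continuation), interpolating between the collected seeds and writing frames there could look like this, reusing expand_seed and the loaded G; truncation_psi, the file naming, and the PIL import are assumptions:

# Illustrative continuation only; truncation_psi=0.7 and the frame naming
# are assumptions, and PIL is assumed to be imported.
os.makedirs(temp, exist_ok=True)
vectors = expand_seed(seed_list, vector_size)   # list of [1, z_dim] arrays
frame = 0
for v0, v1 in zip(vectors[:-1], vectors[1:]):
    for t in np.linspace(0, 1, steps, endpoint=False):
        z = torch.from_numpy((1 - t) * v0 + t * v1).float().to(device)
        label = torch.zeros([1, G.c_dim], device=device)
        img = G(z, label, truncation_psi=0.7, noise_mode='const')
        img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
        PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(
            os.path.join(temp, f'frame-{frame:05d}.png'))
        frame += 1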
Example #26
def training_loop(
        run_dir='.',  # Output directory.
        training_set_kwargs={},  # Options for training set.
        data_loader_kwargs={},  # Options for torch.utils.data.DataLoader.
        G_kwargs={},  # Options for generator network.
        D_kwargs={},  # Options for discriminator network.
        G_opt_kwargs={},  # Options for generator optimizer.
        D_opt_kwargs={},  # Options for discriminator optimizer.
        loss_kwargs={},  # Options for loss function.
        metrics=[],  # Metrics to evaluate during training.
        random_seed=0,  # Global random seed.
        num_gpus=1,  # Number of GPUs participating in the training.
        rank=0,  # Rank of the current process in [0, num_gpus[.
        batch_size=4,  # Total batch size for one training iteration. Can be larger than batch_gpu * num_gpus.
        batch_gpu=4,  # Number of samples processed at a time by one GPU.
        ema_kimg=10,  # Half-life of the exponential moving average (EMA) of generator weights.
        ema_rampup=0.05,  # EMA ramp-up coefficient. None = no rampup.
        G_reg_interval=None,  # How often to perform regularization for G? None = disable lazy regularization.
        D_reg_interval=16,  # How often to perform regularization for D? None = disable lazy regularization.
        total_kimg=25000,  # Total length of the training, measured in thousands of real images.
        kimg_per_tick=4,  # Progress snapshot interval.
        image_snapshot_ticks=50,  # How often to save image snapshots? None = disable.
        network_snapshot_ticks=50,  # How often to save network snapshots? None = disable.
        resume_pkl=None,  # Network pickle to resume training from.
        resume_kimg=0,  # First kimg to report when resuming training.
        cudnn_benchmark=True,  # Enable torch.backends.cudnn.benchmark?
        abort_fn=None,  # Callback function for determining whether to abort training. Must return consistent results across ranks.
        progress_fn=None,  # Callback function for updating training progress. Called for all ranks.
        restart_every=-1,  # Time interval in seconds to exit code
):
    # Initialize.
    start_time = time.time()
    device = torch.device('cuda', rank)
    np.random.seed(random_seed * num_gpus + rank)
    torch.manual_seed(random_seed * num_gpus + rank)
    torch.backends.cudnn.benchmark = cudnn_benchmark  # Improves training speed.
    torch.backends.cuda.matmul.allow_tf32 = False  # Improves numerical accuracy.
    torch.backends.cudnn.allow_tf32 = False  # Improves numerical accuracy.
    conv2d_gradfix.enabled = True  # Improves training speed.
    grid_sample_gradfix.enabled = True  # Avoids errors with the augmentation pipe.
    __RESTART__ = torch.tensor(0., device=device)  # will be broadcast to exit the loop
    __CUR_NIMG__ = torch.tensor(resume_kimg * 1000,
                                dtype=torch.long,
                                device=device)
    __CUR_TICK__ = torch.tensor(0, dtype=torch.long, device=device)
    __BATCH_IDX__ = torch.tensor(0, dtype=torch.long, device=device)
    __PL_MEAN__ = torch.zeros([], device=device)
    best_fid = 9999

    # Load training set.
    if rank == 0:
        print('Loading training set...')
    training_set = dnnlib.util.construct_class_by_name(
        **training_set_kwargs)  # subclass of training.dataset.Dataset
    training_set_sampler = misc.InfiniteSampler(dataset=training_set,
                                                rank=rank,
                                                num_replicas=num_gpus,
                                                seed=random_seed)
    training_set_iterator = iter(
        torch.utils.data.DataLoader(dataset=training_set,
                                    sampler=training_set_sampler,
                                    batch_size=batch_size // num_gpus,
                                    **data_loader_kwargs))
    if rank == 0:
        print()
        print('Num images: ', len(training_set))
        print('Image shape:', training_set.image_shape)
        print('Label shape:', training_set.label_shape)
        print()

    # Construct networks.
    if rank == 0:
        print('Constructing networks...')
    common_kwargs = dict(c_dim=training_set.label_dim,
                         img_resolution=training_set.resolution,
                         img_channels=training_set.num_channels)
    G = dnnlib.util.construct_class_by_name(
        **G_kwargs, **common_kwargs).train().requires_grad_(False).to(
            device)  # subclass of torch.nn.Module
    D = dnnlib.util.construct_class_by_name(
        **D_kwargs, **common_kwargs).train().requires_grad_(False).to(
            device)  # subclass of torch.nn.Module
    G_ema = copy.deepcopy(G).eval()

    # Check for existing checkpoint
    ckpt_pkl = None
    if restart_every > 0 and os.path.isfile(misc.get_ckpt_path(run_dir)):
        ckpt_pkl = resume_pkl = misc.get_ckpt_path(run_dir)

    # Resume from existing pickle.
    if (resume_pkl is not None) and (rank == 0):
        print(f'Resuming from "{resume_pkl}"')
        with dnnlib.util.open_url(resume_pkl) as f:
            resume_data = legacy.load_network_pkl(f)
        for name, module in [('G', G), ('D', D), ('G_ema', G_ema)]:
            misc.copy_params_and_buffers(resume_data[name],
                                         module,
                                         require_all=False)

        if ckpt_pkl is not None:  # Load ticks
            __CUR_NIMG__ = resume_data['progress']['cur_nimg'].to(device)
            __CUR_TICK__ = resume_data['progress']['cur_tick'].to(device)
            __BATCH_IDX__ = resume_data['progress']['batch_idx'].to(device)
            __PL_MEAN__ = resume_data['progress'].get('pl_mean', torch.zeros(
                [])).to(device)
            best_fid = resume_data['progress'][
                'best_fid']  # only needed for rank == 0

        del resume_data

    # Print network summary tables.
    if rank == 0:
        z = torch.empty([batch_gpu, G.z_dim], device=device)
        c = torch.empty([batch_gpu, G.c_dim], device=device)
        img = misc.print_module_summary(G, [z, c])
        misc.print_module_summary(D, [img, c])

    # Distribute across GPUs.
    if rank == 0:
        print(f'Distributing across {num_gpus} GPUs...')
    for module in [G, D, G_ema]:
        if module is not None and num_gpus > 1:
            for param in misc.params_and_buffers(module):
                torch.distributed.broadcast(param, src=0)

    # Setup training phases.
    if rank == 0:
        print('Setting up training phases...')
    loss = dnnlib.util.construct_class_by_name(
        device=device, G=G, G_ema=G_ema, D=D,
        **loss_kwargs)  # subclass of training.loss.Loss
    phases = []
    for name, module, opt_kwargs, reg_interval in [
        ('G', G, G_opt_kwargs, G_reg_interval),
        ('D', D, D_opt_kwargs, D_reg_interval)
    ]:
        if reg_interval is None:
            opt = dnnlib.util.construct_class_by_name(
                params=module.parameters(),
                **opt_kwargs)  # subclass of torch.optim.Optimizer
            phases += [
                dnnlib.EasyDict(name=name + 'both',
                                module=module,
                                opt=opt,
                                interval=1)
            ]
        else:  # Lazy regularization.
            mb_ratio = reg_interval / (reg_interval + 1)
            opt_kwargs = dnnlib.EasyDict(opt_kwargs)
            opt_kwargs.lr = opt_kwargs.lr * mb_ratio
            opt_kwargs.betas = [beta**mb_ratio for beta in opt_kwargs.betas]
            opt = dnnlib.util.construct_class_by_name(
                module.parameters(),
                **opt_kwargs)  # subclass of torch.optim.Optimizer
            phases += [
                dnnlib.EasyDict(name=name + 'main',
                                module=module,
                                opt=opt,
                                interval=1)
            ]
            phases += [
                dnnlib.EasyDict(name=name + 'reg',
                                module=module,
                                opt=opt,
                                interval=reg_interval)
            ]
    for phase in phases:
        phase.start_event = None
        phase.end_event = None
        if rank == 0:
            phase.start_event = torch.cuda.Event(enable_timing=True)
            phase.end_event = torch.cuda.Event(enable_timing=True)

    # Export sample images.
    grid_size = None
    grid_z = None
    grid_c = None
    if rank == 0:
        print('Exporting sample images...')
        grid_size, images, labels = setup_snapshot_image_grid(
            training_set=training_set)
        save_image_grid(images,
                        os.path.join(run_dir, 'reals.png'),
                        drange=[0, 255],
                        grid_size=grid_size)

        grid_z = torch.randn([labels.shape[0], G.z_dim],
                             device=device).split(batch_gpu)
        grid_c = torch.from_numpy(labels).to(device).split(batch_gpu)
        images = torch.cat([
            G_ema(z=z, c=c, noise_mode='const').cpu()
            for z, c in zip(grid_z, grid_c)
        ]).numpy()

        save_image_grid(images,
                        os.path.join(run_dir, 'fakes_init.png'),
                        drange=[-1, 1],
                        grid_size=grid_size)

    # Initialize logs.
    if rank == 0:
        print('Initializing logs...')
    stats_collector = training_stats.Collector(regex='.*')
    stats_metrics = dict()
    stats_jsonl = None
    stats_tfevents = None
    if rank == 0:
        stats_jsonl = open(os.path.join(run_dir, 'stats.jsonl'), 'wt')
        try:
            import torch.utils.tensorboard as tensorboard
            stats_tfevents = tensorboard.SummaryWriter(run_dir)
        except ImportError as err:
            print('Skipping tfevents export:', err)

    # Train.
    if rank == 0:
        print(f'Training for {total_kimg} kimg...')
        print()
    if num_gpus > 1:  # broadcast loaded states to all
        torch.distributed.broadcast(__CUR_NIMG__, 0)
        torch.distributed.broadcast(__CUR_TICK__, 0)
        torch.distributed.broadcast(__BATCH_IDX__, 0)
        torch.distributed.broadcast(__PL_MEAN__, 0)
        torch.distributed.barrier()  # ensure all processes received this info
    cur_nimg = __CUR_NIMG__.item()
    cur_tick = __CUR_TICK__.item()
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    maintenance_time = tick_start_time - start_time
    batch_idx = __BATCH_IDX__.item()
    if progress_fn is not None:
        progress_fn(cur_nimg // 1000, total_kimg)
    if hasattr(loss, 'pl_mean'):
        loss.pl_mean.copy_(__PL_MEAN__)
    while True:

        with torch.autograd.profiler.record_function('data_fetch'):
            phase_real_img, phase_real_c = next(training_set_iterator)
            phase_real_img = (
                phase_real_img.to(device).to(torch.float32) / 127.5 -
                1).split(batch_gpu)
            phase_real_c = phase_real_c.to(device).split(batch_gpu)
            all_gen_z = torch.randn([len(phases) * batch_size, G.z_dim],
                                    device=device)
            all_gen_z = [
                phase_gen_z.split(batch_gpu)
                for phase_gen_z in all_gen_z.split(batch_size)
            ]
            all_gen_c = [
                training_set.get_label(np.random.randint(len(training_set)))
                for _ in range(len(phases) * batch_size)
            ]
            all_gen_c = torch.from_numpy(
                np.stack(all_gen_c)).pin_memory().to(device)
            all_gen_c = [
                phase_gen_c.split(batch_gpu)
                for phase_gen_c in all_gen_c.split(batch_size)
            ]

        # Execute training phases.
        for phase, phase_gen_z, phase_gen_c in zip(phases, all_gen_z,
                                                   all_gen_c):
            if batch_idx % phase.interval != 0:
                continue
            if phase.start_event is not None:
                phase.start_event.record(torch.cuda.current_stream(device))

            # Accumulate gradients.
            phase.opt.zero_grad(set_to_none=True)
            phase.module.requires_grad_(True)

            if phase.name in ['Dmain', 'Dboth', 'Dreg']:
                phase.module.feature_network.requires_grad_(False)  # the feature network stays frozen during D phases

            for real_img, real_c, gen_z, gen_c in zip(phase_real_img,
                                                      phase_real_c,
                                                      phase_gen_z,
                                                      phase_gen_c):
                loss.accumulate_gradients(phase=phase.name,
                                          real_img=real_img,
                                          real_c=real_c,
                                          gen_z=gen_z,
                                          gen_c=gen_c,
                                          gain=phase.interval,
                                          cur_nimg=cur_nimg)
            phase.module.requires_grad_(False)

            # Update weights.
            with torch.autograd.profiler.record_function(phase.name + '_opt'):
                params = [
                    param for param in phase.module.parameters()
                    if param.grad is not None
                ]
                if len(params) > 0:
                    flat = torch.cat(
                        [param.grad.flatten() for param in params])
                    if num_gpus > 1:
                        torch.distributed.all_reduce(flat)
                        flat /= num_gpus
                    misc.nan_to_num(flat,
                                    nan=0,
                                    posinf=1e5,
                                    neginf=-1e5,
                                    out=flat)
                    grads = flat.split([param.numel() for param in params])
                    for param, grad in zip(params, grads):
                        param.grad = grad.reshape(param.shape)
                phase.opt.step()

            # Phase done.
            if phase.end_event is not None:
                phase.end_event.record(torch.cuda.current_stream(device))

        # Update G_ema.
        with torch.autograd.profiler.record_function('Gema'):
            ema_nimg = ema_kimg * 1000
            if ema_rampup is not None:
                ema_nimg = min(ema_nimg, cur_nimg * ema_rampup)
            ema_beta = 0.5**(batch_size / max(ema_nimg, 1e-8))
            for p_ema, p in zip(G_ema.parameters(), G.parameters()):
                p_ema.copy_(p.lerp(p_ema, ema_beta))
            for b_ema, b in zip(G_ema.buffers(), G.buffers()):
                b_ema.copy_(b)

        # Update state.
        cur_nimg += batch_size
        batch_idx += 1

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if (not done) and (cur_tick != 0) and (
                cur_nimg < tick_start_nimg + kimg_per_tick * 1000):
            continue

        # Print status line, accumulating the same information in training_stats.
        tick_end_time = time.time()
        fields = []
        fields += [
            f"tick {training_stats.report0('Progress/tick', cur_tick):<5d}"
        ]
        fields += [
            f"kimg {training_stats.report0('Progress/kimg', cur_nimg / 1e3):<8.1f}"
        ]
        fields += [
            f"time {dnnlib.util.format_time(training_stats.report0('Timing/total_sec', tick_end_time - start_time)):<12s}"
        ]
        fields += [
            f"sec/tick {training_stats.report0('Timing/sec_per_tick', tick_end_time - tick_start_time):<7.1f}"
        ]
        fields += [
            f"sec/kimg {training_stats.report0('Timing/sec_per_kimg', (tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg) * 1e3):<7.2f}"
        ]
        fields += [
            f"maintenance {training_stats.report0('Timing/maintenance_sec', maintenance_time):<6.1f}"
        ]
        fields += [
            f"cpumem {training_stats.report0('Resources/cpu_mem_gb', psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}"
        ]
        fields += [
            f"gpumem {training_stats.report0('Resources/peak_gpu_mem_gb', torch.cuda.max_memory_allocated(device) / 2**30):<6.2f}"
        ]
        fields += [
            f"reserved {training_stats.report0('Resources/peak_gpu_mem_reserved_gb', torch.cuda.max_memory_reserved(device) / 2**30):<6.2f}"
        ]
        torch.cuda.reset_peak_memory_stats()
        training_stats.report0('Timing/total_hours',
                               (tick_end_time - start_time) / (60 * 60))
        training_stats.report0('Timing/total_days',
                               (tick_end_time - start_time) / (24 * 60 * 60))
        if rank == 0:
            print(' '.join(fields))

        # Check for abort.
        if (not done) and (abort_fn is not None) and abort_fn():
            done = True
            if rank == 0:
                print()
                print('Aborting...')

        # Check for restart.
        if (rank == 0) and (restart_every > 0) and (time.time() - start_time >
                                                    restart_every):
            print('Restart job...')
            __RESTART__ = torch.tensor(1., device=device)
        if num_gpus > 1:
            torch.distributed.broadcast(__RESTART__, 0)
        if __RESTART__:
            done = True
            print(f'Process {rank} leaving...')
            if num_gpus > 1:
                torch.distributed.barrier()

        # Save image snapshot.
        if (rank == 0) and (image_snapshot_ticks is not None) and (
                done or cur_tick % image_snapshot_ticks == 0):
            images = torch.cat([
                G_ema(z=z, c=c, noise_mode='const').cpu()
                for z, c in zip(grid_z, grid_c)
            ]).numpy()
            save_image_grid(images,
                            os.path.join(run_dir,
                                         f'fakes{cur_nimg//1000:06d}.png'),
                            drange=[-1, 1],
                            grid_size=grid_size)

        # Save network snapshot.
        snapshot_pkl = None
        snapshot_data = None
        if (network_snapshot_ticks
                is not None) and (done
                                  or cur_tick % network_snapshot_ticks == 0):
            snapshot_data = dict(G=G,
                                 D=D,
                                 G_ema=G_ema,
                                 training_set_kwargs=dict(training_set_kwargs))
            for key, value in snapshot_data.items():
                if isinstance(value, torch.nn.Module):
                    snapshot_data[key] = value  # stored as-is: live modules, no CPU deepcopy
                del value  # conserve memory

        # Save Checkpoint if needed
        if (rank == 0) and (restart_every > 0) and (
                network_snapshot_ticks
                is not None) and (done
                                  or cur_tick % network_snapshot_ticks == 0):
            snapshot_pkl = misc.get_ckpt_path(run_dir)
            # save as tensors to avoid errors in multi-GPU runs
            snapshot_data['progress'] = {
                'cur_nimg': torch.LongTensor([cur_nimg]),
                'cur_tick': torch.LongTensor([cur_tick]),
                'batch_idx': torch.LongTensor([batch_idx]),
                'best_fid': best_fid,
            }
            if hasattr(loss, 'pl_mean'):
                snapshot_data['progress']['pl_mean'] = loss.pl_mean.cpu()

            with open(snapshot_pkl, 'wb') as f:
                pickle.dump(snapshot_data, f)

        # Evaluate metrics.
        # if (snapshot_data is not None) and (len(metrics) > 0):
        if cur_tick and (snapshot_data is not None) and (len(metrics) > 0):
            if rank == 0:
                print('Evaluating metrics...')
            for metric in metrics:
                result_dict = metric_main.calc_metric(
                    metric=metric,
                    G=snapshot_data['G_ema'],
                    run_dir=run_dir,
                    cur_nimg=cur_nimg,
                    dataset_kwargs=training_set_kwargs,
                    num_gpus=num_gpus,
                    rank=rank,
                    device=device)
                if rank == 0:
                    metric_main.report_metric(result_dict,
                                              run_dir=run_dir,
                                              snapshot_pkl=snapshot_pkl)
                stats_metrics.update(result_dict.results)

            # save best fid ckpt
            snapshot_pkl = os.path.join(run_dir, 'best_model.pkl')
            cur_nimg_txt = os.path.join(run_dir, 'best_nimg.txt')
            if rank == 0:
                if 'fid50k_full' in stats_metrics and stats_metrics[
                        'fid50k_full'] < best_fid:
                    best_fid = stats_metrics['fid50k_full']

                    with open(snapshot_pkl, 'wb') as f:
                        dill.dump(snapshot_data, f)
                    # save current image count (saving it directly in the pkl causes problems with multi-GPU)
                    with open(cur_nimg_txt, 'w') as f:
                        f.write(str(cur_nimg))

        del snapshot_data  # conserve memory

        # Collect statistics.
        for phase in phases:
            value = []
            if (phase.start_event is not None) and (phase.end_event is not None) and \
                    not (phase.start_event.cuda_event == 0 and phase.end_event.cuda_event == 0):
                # both events may still be uninitialized right after a restart
                phase.end_event.synchronize()
                value = phase.start_event.elapsed_time(phase.end_event)
            training_stats.report0('Timing/' + phase.name, value)
        stats_collector.update()
        stats_dict = stats_collector.as_dict()

        # Update logs.
        timestamp = time.time()
        if stats_jsonl is not None:
            fields = dict(stats_dict, timestamp=timestamp)
            stats_jsonl.write(json.dumps(fields) + '\n')
            stats_jsonl.flush()
        if stats_tfevents is not None:
            global_step = int(cur_nimg / 1e3)
            walltime = timestamp - start_time
            for name, value in stats_dict.items():
                stats_tfevents.add_scalar(name,
                                          value.mean,
                                          global_step=global_step,
                                          walltime=walltime)
            for name, value in stats_metrics.items():
                stats_tfevents.add_scalar(f'Metrics/{name}',
                                          value,
                                          global_step=global_step,
                                          walltime=walltime)
            stats_tfevents.flush()
        if progress_fn is not None:
            progress_fn(cur_nimg // 1000, total_kimg)

        # Update state.
        cur_tick += 1
        tick_start_nimg = cur_nimg
        tick_start_time = time.time()
        maintenance_time = tick_start_time - tick_end_time
        if done:
            break

    # Done.
    if rank == 0:
        print()
        print('Exiting...')
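In the weight update above, per-parameter gradients are flattened into one buffer so multi-GPU averaging costs a single all_reduce. A minimal sketch of that pattern in isolation, assuming torch.distributed has already been initialized elsewhere:

# Sketch of the flatten -> all_reduce -> unflatten gradient sync used in
# the update step; assumes init_process_group was called elsewhere.
import torch
import torch.distributed as dist

def sync_grads(module: torch.nn.Module, num_gpus: int) -> None:
    params = [p for p in module.parameters() if p.grad is not None]
    if num_gpus <= 1 or not params:
        return
    flat = torch.cat([p.grad.flatten() for p in params])  # one contiguous buffer
    dist.all_reduce(flat)                                 # sum across ranks...
    flat /= num_gpus                                      # ...then average
    for p, g in zip(params, flat.split([p.numel() for p in params])):
        p.grad = g.reshape(p.shape)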
Example #27
def generate_style_mix(
    network_pkl: str,
    row_seeds: List[int],
    col_seeds: List[int],
    col_styles: List[int],
    truncation_psi: float,
    noise_mode: str,
    outdir: str,
):
    """Generate images using pretrained network pickle.

    Examples:

    \b
    python style_mixing.py --outdir=out --rows=85,100,75,458,1500 --cols=55,821,1789,293 \\
        --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
    """
    print('Loading networks from "%s"...' % network_pkl)
    device = torch.device("cuda")
    with dnnlib.util.open_url(network_pkl) as f:
        G = legacy.load_network_pkl(f)["G_ema"].to(device)  # type: ignore

    os.makedirs(outdir, exist_ok=True)

    print("Generating W vectors...")
    all_seeds = list(set(row_seeds + col_seeds))
    all_z = np.stack([np.random.RandomState(seed).randn(G.z_dim) for seed in all_seeds])
    all_w = G.mapping(torch.from_numpy(all_z).to(device), None)
    w_avg = G.mapping.w_avg
    all_w = w_avg + (all_w - w_avg) * truncation_psi
    w_dict = {seed: w for seed, w in zip(all_seeds, list(all_w))}

    print("Generating images...")
    all_images = G.synthesis(all_w, noise_mode=noise_mode)
    all_images = (all_images.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).cpu().numpy()
    image_dict = {(seed, seed): image for seed, image in zip(all_seeds, list(all_images))}

    print("Generating style-mixed images...")
    for row_seed in row_seeds:
        for col_seed in col_seeds:
            w = w_dict[row_seed].clone()
            w[col_styles] = w_dict[col_seed][col_styles]
            image = G.synthesis(w[np.newaxis], noise_mode=noise_mode)
            image = (image.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
            image_dict[(row_seed, col_seed)] = image[0].cpu().numpy()

    print("Saving images...")
    for (row_seed, col_seed), image in image_dict.items():
        PIL.Image.fromarray(image, "RGB").save(f"{outdir}/{row_seed}-{col_seed}.png")

    print("Saving image grid...")
    W = G.img_resolution
    H = G.img_resolution
    canvas = PIL.Image.new("RGB", (W * (len(col_seeds) + 1), H * (len(row_seeds) + 1)), "black")
    for row_idx, row_seed in enumerate([0] + row_seeds):
        for col_idx, col_seed in enumerate([0] + col_seeds):
            if row_idx == 0 and col_idx == 0:
                continue
            key = (row_seed, col_seed)
            if row_idx == 0:
                key = (col_seed, col_seed)
            if col_idx == 0:
                key = (row_seed, row_seed)
            canvas.paste(PIL.Image.fromarray(image_dict[key], "RGB"), (W * col_idx, H * row_idx))
    canvas.save(f"{outdir}/grid.png")
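A hypothetical invocation of the routine above, mirroring the seeds in its docstring; all argument values are just examples:

# Example call only; seeds mirror the docstring, and col_styles 0-6 takes
# the coarse layers from the column seeds.
generate_style_mix(
    network_pkl='https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl',
    row_seeds=[85, 100, 75, 458, 1500],
    col_seeds=[55, 821, 1789, 293],
    col_styles=list(range(7)),
    truncation_psi=1.0,
    noise_mode='const',
    outdir='out',
)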
Example #28
def main():
    if a.vector_dir is not None:
        if a.vector_dir.endswith('/') or a.vector_dir.endswith('\\'):
            a.vector_dir = a.vector_dir[:-1]
    os.makedirs(a.out_dir, exist_ok=True)
    device = torch.device('cuda')

    global Gs, use_d, custom

    # setup generator
    Gs_kwargs = dnnlib.EasyDict()
    Gs_kwargs.verbose = a.verbose
    Gs_kwargs.size = a.size
    Gs_kwargs.scale_type = a.scale_type

    # load base or custom network
    pkl_name = osp.splitext(a.model)[0]
    if '.pkl' in a.model.lower():
        custom = False
        print(' .. Gs from pkl ..', basename(a.model))
    else:
        custom = True
        print(' .. Gs custom ..', basename(a.model))
    with dnnlib.util.open_url(pkl_name + '.pkl') as f:
        Gs = legacy.load_network_pkl(f,
                                     custom=custom, **Gs_kwargs)['G_ema'].to(
                                         device)  # type: ignore

    # load directions
    if a.vector_dir is not None:
        directions = []
        vector_list = file_list(a.vector_dir, 'npy')
        for v in vector_list:
            direction = load_latents(v)
            if len(direction.shape) == 2:
                direction = np.expand_dims(direction, 0)
            directions.append(direction)
        directions = np.concatenate(directions)[:, np.newaxis]  # [frm,1,18,512]
    else:
        print(' No vectors found')
        exit()

    # 'direction' is the last vector from the loop above; multi-layer
    # w vectors (shape [18,512] per frame) switch on w-space mode.
    if len(direction[0].shape) > 1 and direction[0].shape[0] > 1:
        use_d = True
    print(' directions', directions.shape, 'using d' if use_d else 'using w')
    directions = torch.from_numpy(directions).to(device)

    # latent direction range
    lrange = [-0.5, 0.5]

    # load saved latents
    if a.base_lat is not None:
        base_latent = load_latents(a.base_lat)
        base_latent = torch.from_numpy(base_latent).to(device)
    else:
        print(' No NPY input given, making random')
        base_latent = torch.randn(1, Gs.z_dim, device=device)  # tensor, not NumPy: Gs.mapping expects torch input
        if use_d:
            base_latent = Gs.mapping(base_latent, None)  # [frm,18,512]

    pbar = ProgressBar(len(directions))
    for i, direction in enumerate(directions):
        make_loop(base_latent, direction, lrange, a.fstep * 2, a.fstep * 2 * i)
        pbar.upd()
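make_loop is defined elsewhere in the script. As an illustration of what stepping one direction over lrange amounts to (not the author's implementation), assuming w-space latents and the loaded Gs:

# Illustrative only: offset the base w latent along one direction and
# render each step; the frame count and noise_mode are assumptions.
def render_direction(base_latent, direction, lrange, frames):
    for amt in np.linspace(lrange[0], lrange[1], frames):
        w = base_latent + float(amt) * direction     # broadcast over [1,18,512]
        img = Gs.synthesis(w, noise_mode='const')
        img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
        yield img[0].cpu().numpy()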
Example #29
def training_loop(
        run_dir='.',  # Output directory.
        training_set_kwargs={},  # Options for training set.
        data_loader_kwargs={},  # Options for torch.utils.data.DataLoader.
        G_kwargs={},  # Options for generator network.
        D_kwargs={},  # Options for discriminator network.
        D2_kwargs={},  # Options for the second (face) discriminator, used when obake is enabled.
        G_opt_kwargs={},  # Options for generator optimizer.
        D_opt_kwargs={},  # Options for discriminator optimizer.
        augment_kwargs=None,  # Options for augmentation pipeline. None = disable.
        loss_kwargs={},  # Options for loss function.
        metrics=[],  # Metrics to evaluate during training.
        random_seed=0,  # Global random seed.
        num_gpus=1,  # Number of GPUs participating in the training.
        rank=0,  # Rank of the current process in [0, num_gpus[.
        batch_size=4,  # Total batch size for one training iteration. Can be larger than batch_gpu * num_gpus.
        batch_gpu=4,  # Number of samples processed at a time by one GPU.
        ema_kimg=10,  # Half-life of the exponential moving average (EMA) of generator weights.
        ema_rampup=None,  # EMA ramp-up coefficient.
        G_reg_interval=4,  # How often to perform regularization for G? None = disable lazy regularization.
        D_reg_interval=16,  # How often to perform regularization for D? None = disable lazy regularization.
        augment_p=0,  # Initial value of augmentation probability.
        ada_target=None,  # ADA target value. None = fixed p.
        ada_interval=4,  # How often to perform ADA adjustment?
        ada_kimg=500,  # ADA adjustment speed, measured in how many kimg it takes for p to increase/decrease by one unit.
        total_kimg=25000,  # Total length of the training, measured in thousands of real images.
        kimg_per_tick=4,  # Progress snapshot interval.
        image_snapshot_ticks=50,  # How often to save image snapshots? None = disable.
        network_snapshot_ticks=50,  # How often to save network snapshots? None = disable.
        resume_pkl=None,  # Network pickle to resume training from.
        cudnn_benchmark=True,  # Enable torch.backends.cudnn.benchmark?
        abort_fn=None,  # Callback function for determining whether to abort training. Must return consistent results across ranks.
        progress_fn=None,  # Callback function for updating training progress. Called for all ranks.
        obake=None,  # Enable obake training with an extra face-based discriminator. None = disable.
):
    # Initialize.
    start_time = time.time()
    device = torch.device('cuda', rank)
    np.random.seed(random_seed * num_gpus + rank)
    torch.manual_seed(random_seed * num_gpus + rank)
    torch.backends.cudnn.benchmark = cudnn_benchmark  # Improves training speed.
    conv2d_gradfix.enabled = True  # Improves training speed.
    grid_sample_gradfix.enabled = True  # Avoids errors with the augmentation pipe.

    # Load training set.
    if rank == 0:
        print('Loading training set...')
    training_set = dnnlib.util.construct_class_by_name(
        **training_set_kwargs)  # subclass of training.dataset.Dataset
    training_set_sampler = misc.InfiniteSampler(dataset=training_set,
                                                rank=rank,
                                                num_replicas=num_gpus,
                                                seed=random_seed)
    training_set_iterator = iter(
        torch.utils.data.DataLoader(dataset=training_set,
                                    sampler=training_set_sampler,
                                    batch_size=batch_size // num_gpus,
                                    **data_loader_kwargs))
    if rank == 0:
        print()
        print('Num images: ', len(training_set))
        print('Image shape:', training_set.image_shape)
        print('Label shape:', training_set.label_shape)
        print()

    # Construct networks.
    if rank == 0:
        print('Constructing networks...')
    common_kwargs = dict(c_dim=training_set.label_dim,
                         img_resolution=training_set.resolution,
                         img_channels=training_set.num_channels)
    G = dnnlib.util.construct_class_by_name(
        **G_kwargs, **common_kwargs).train().requires_grad_(False).to(
            device)  # subclass of torch.nn.Module
    D = dnnlib.util.construct_class_by_name(
        **D_kwargs, **common_kwargs).train().requires_grad_(False).to(
            device)  # subclass of torch.nn.Module
    G_ema = copy.deepcopy(G).eval()
    if obake is not None:
        D_mtcnn = MTCNN(image_size=D2_kwargs.mtcnn_output_size,
                        margin=D2_kwargs.mtcnn_output_margin,
                        thresholds=D2_kwargs.mtcnn_thresholds)
        D_face = InceptionResnetV1(pretrained=D2_kwargs.resnet_type).eval()

    # Resume from existing pickle.
    if (resume_pkl is not None) and (rank == 0):
        print(f'Resuming from "{resume_pkl}"')
        with dnnlib.util.open_url(resume_pkl) as f:
            resume_data = legacy.load_network_pkl(f)
        for name, module in [('G', G), ('D', D), ('G_ema', G_ema)]:
            misc.copy_params_and_buffers(resume_data[name],
                                         module,
                                         require_all=False)

    # Print network summary tables.
    if rank == 0:
        z = torch.empty([batch_gpu, G.z_dim], device=device)
        c = torch.empty([batch_gpu, G.c_dim], device=device)
        img = misc.print_module_summary(G, [z, c])
        misc.print_module_summary(D, [img, c])

    # Setup augmentation.
    if rank == 0:
        print('Setting up augmentation...')
    augment_pipe = None
    ada_stats = None
    if (augment_kwargs is not None) and (augment_p > 0
                                         or ada_target is not None):
        augment_pipe = dnnlib.util.construct_class_by_name(
            **augment_kwargs).train().requires_grad_(False).to(
                device)  # subclass of torch.nn.Module
        augment_pipe.p.copy_(torch.as_tensor(augment_p))
        if ada_target is not None:
            ada_stats = training_stats.Collector(regex='Loss/signs/real')

    # Distribute across GPUs.
    if rank == 0:
        print(f'Distributing across {num_gpus} GPUs...')
    ddp_modules = dict()
    modules = [('G_mapping', G.mapping), ('G_synthesis', G.synthesis),
               ('D', D)]
    if obake is not None:
        modules += [('D_mtcnn', D_mtcnn), ('D_face', D_face)]
    modules += [(None, G_ema), ('augment_pipe', augment_pipe)]
    for name, module in modules:
        if (num_gpus > 1) and (module is not None) and len(
                list(module.parameters())) != 0:
            module.requires_grad_(True)
            module = torch.nn.parallel.DistributedDataParallel(
                module, device_ids=[device], broadcast_buffers=False)
            module.requires_grad_(False)
        if name is not None:
            ddp_modules[name] = module

    # Setup training phases.
    if rank == 0:
        print('Setting up training phases...')
    loss = dnnlib.util.construct_class_by_name(
        device=device, **ddp_modules,
        **loss_kwargs)  # subclass of training.loss.Loss
    phases = []
    for name, module, opt_kwargs, reg_interval in [
        ('G', G, G_opt_kwargs, G_reg_interval),
        ('D', D, D_opt_kwargs, D_reg_interval)
    ]:
        if reg_interval is None:
            opt = dnnlib.util.construct_class_by_name(
                params=module.parameters(),
                **opt_kwargs)  # subclass of torch.optim.Optimizer
            phases += [
                dnnlib.EasyDict(name=name + 'both',
                                module=module,
                                opt=opt,
                                interval=1)
            ]
        else:  # Lazy regularization.
            mb_ratio = reg_interval / (reg_interval + 1)
            opt_kwargs = dnnlib.EasyDict(opt_kwargs)
            opt_kwargs.lr = opt_kwargs.lr * mb_ratio
            opt_kwargs.betas = [beta**mb_ratio for beta in opt_kwargs.betas]
            opt = dnnlib.util.construct_class_by_name(
                module.parameters(),
                **opt_kwargs)  # subclass of torch.optim.Optimizer
            phases += [
                dnnlib.EasyDict(name=name + 'main',
                                module=module,
                                opt=opt,
                                interval=1)
            ]
            phases += [
                dnnlib.EasyDict(name=name + 'reg',
                                module=module,
                                opt=opt,
                                interval=reg_interval)
            ]
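    # On rank 0, CUDA events bracket each phase so its GPU time can be
    # reported under 'Timing/<phase>' in the statistics below.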
    for phase in phases:
        phase.start_event = None
        phase.end_event = None
        if rank == 0:
            phase.start_event = torch.cuda.Event(enable_timing=True)
            phase.end_event = torch.cuda.Event(enable_timing=True)

    # Export sample images.
    grid_size = None
    grid_z = None
    grid_c = None
    if rank == 0:
        print('Exporting sample images...')
        grid_size, images, labels = setup_snapshot_image_grid(
            training_set=training_set)
        save_image_grid(images,
                        os.path.join(run_dir, 'reals.png'),
                        drange=[0, 255],
                        grid_size=grid_size)
        grid_z = torch.randn([labels.shape[0], G.z_dim],
                             device=device).split(batch_gpu)
        grid_c = torch.from_numpy(labels).to(device).split(batch_gpu)
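        # The same latent/label grids are reused for every image snapshot, so
        # successive fakes*.png files show how identical samples evolve.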
        images = torch.cat([
            G_ema(z=z, c=c, noise_mode='const').cpu()
            for z, c in zip(grid_z, grid_c)
        ]).numpy()
        save_image_grid(images,
                        os.path.join(run_dir, 'fakes_init.png'),
                        drange=[-1, 1],
                        grid_size=grid_size)

    # Initialize logs.
    if rank == 0:
        print('Initializing logs...')
    stats_collector = training_stats.Collector(regex='.*')
    stats_metrics = dict()
    stats_jsonl = None
    stats_tfevents = None
    if rank == 0:
        stats_jsonl = open(os.path.join(run_dir, 'stats.jsonl'), 'wt')
        try:
            import torch.utils.tensorboard as tensorboard
            stats_tfevents = tensorboard.SummaryWriter(run_dir)
        except ImportError as err:
            print('Skipping tfevents export:', err)

    # Train.
    if rank == 0:
        print(f'Training for {total_kimg} kimg...')
        print()
    cur_nimg = 0
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    maintenance_time = tick_start_time - start_time
    batch_idx = 0
    if progress_fn is not None:
        progress_fn(0, total_kimg)
    while True:

        # Fetch training data.
        with torch.autograd.profiler.record_function('data_fetch'):
            phase_real_img, phase_real_c = next(training_set_iterator)
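            # Map uint8 pixels from [0, 255] to [-1, 1] and split each full
            # batch into microbatches of batch_gpu samples per GPU.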
            phase_real_img = (
                phase_real_img.to(device).to(torch.float32) / 127.5 -
                1).split(batch_gpu)
            phase_real_c = phase_real_c.to(device).split(batch_gpu)
            all_gen_z = torch.randn([len(phases) * batch_size, G.z_dim],
                                    device=device)
            all_gen_z = [
                phase_gen_z.split(batch_gpu)
                for phase_gen_z in all_gen_z.split(batch_size)
            ]
            all_gen_c = [
                training_set.get_label(np.random.randint(len(training_set)))
                for _ in range(len(phases) * batch_size)
            ]
            all_gen_c = torch.from_numpy(
                np.stack(all_gen_c)).pin_memory().to(device)
            all_gen_c = [
                phase_gen_c.split(batch_gpu)
                for phase_gen_c in all_gen_c.split(batch_size)
            ]

        # Execute training phases.
        for phase, phase_gen_z, phase_gen_c in zip(phases, all_gen_z,
                                                   all_gen_c):
            if batch_idx % phase.interval != 0:
                continue

            # Initialize gradient accumulation.
            if phase.start_event is not None:
                phase.start_event.record(torch.cuda.current_stream(device))
            phase.opt.zero_grad(set_to_none=True)
            phase.module.requires_grad_(True)

            # Accumulate gradients over multiple rounds.
            for round_idx, (real_img, real_c, gen_z, gen_c) in enumerate(
                    zip(phase_real_img, phase_real_c, phase_gen_z,
                        phase_gen_c)):
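                # Gradients are synchronized across GPUs only on the last
                # accumulation round; 'gain' rescales the loss of phases that
                # run only every phase.interval minibatches.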
                sync = (round_idx == batch_size // (batch_gpu * num_gpus) - 1)
                gain = phase.interval
                loss.accumulate_gradients(phase=phase.name,
                                          real_img=real_img,
                                          real_c=real_c,
                                          gen_z=gen_z,
                                          gen_c=gen_c,
                                          sync=sync,
                                          gain=gain)

            # Update weights.
            phase.module.requires_grad_(False)
            with torch.autograd.profiler.record_function(phase.name + '_opt'):
                for param in phase.module.parameters():
                    if param.grad is not None:
                        misc.nan_to_num(param.grad,
                                        nan=0,
                                        posinf=1e5,
                                        neginf=-1e5,
                                        out=param.grad)
                phase.opt.step()
            if phase.end_event is not None:
                phase.end_event.record(torch.cuda.current_stream(device))

        # Update G_ema.
        with torch.autograd.profiler.record_function('Gema'):
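            # G_ema tracks an exponential moving average of G's weights with a
            # half-life of ema_kimg thousand images (optionally ramped up at
            # the start of training); snapshots and metrics use G_ema.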
            ema_nimg = ema_kimg * 1000
            if ema_rampup is not None:
                ema_nimg = min(ema_nimg, cur_nimg * ema_rampup)
            ema_beta = 0.5**(batch_size / max(ema_nimg, 1e-8))
            for p_ema, p in zip(G_ema.parameters(), G.parameters()):
                p_ema.copy_(p.lerp(p_ema, ema_beta))
            for b_ema, b in zip(G_ema.buffers(), G.buffers()):
                b_ema.copy_(b)

        # Update state.
        cur_nimg += batch_size
        batch_idx += 1

        # Execute ADA heuristic.
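        # Adaptive discriminator augmentation: nudge the augmentation
        # probability p so that the sign of D's output on real images stays
        # near ada_target; the step size lets p traverse [0, 1] over
        # ada_kimg thousand images.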
        if (ada_stats is not None) and (batch_idx % ada_interval == 0):
            ada_stats.update()
            adjust = np.sign(ada_stats['Loss/signs/real'] - ada_target) * (
                batch_size * ada_interval) / (ada_kimg * 1000)
            augment_pipe.p.copy_(
                (augment_pipe.p + adjust).max(misc.constant(0, device=device)))

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if (not done) and (cur_tick != 0) and (
                cur_nimg < tick_start_nimg + kimg_per_tick * 1000):
            continue

        # Print status line, accumulating the same information in stats_collector.
        tick_end_time = time.time()
        fields = []
        fields += [
            f"tick {training_stats.report0('Progress/tick', cur_tick):<5d}"
        ]
        fields += [
            f"kimg {training_stats.report0('Progress/kimg', cur_nimg / 1e3):<8.1f}"
        ]
        fields += [
            f"time {dnnlib.util.format_time(training_stats.report0('Timing/total_sec', tick_end_time - start_time)):<12s}"
        ]
        fields += [
            f"sec/tick {training_stats.report0('Timing/sec_per_tick', tick_end_time - tick_start_time):<7.1f}"
        ]
        fields += [
            f"sec/kimg {training_stats.report0('Timing/sec_per_kimg', (tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg) * 1e3):<7.2f}"
        ]
        fields += [
            f"maintenance {training_stats.report0('Timing/maintenance_sec', maintenance_time):<6.1f}"
        ]
        fields += [
            f"cpumem {training_stats.report0('Resources/cpu_mem_gb', psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}"
        ]
        fields += [
            f"gpumem {training_stats.report0('Resources/peak_gpu_mem_gb', torch.cuda.max_memory_allocated(device) / 2**30):<6.2f}"
        ]
        torch.cuda.reset_peak_memory_stats()
        fields += [
            f"augment {training_stats.report0('Progress/augment', float(augment_pipe.p.cpu()) if augment_pipe is not None else 0):.3f}"
        ]
        training_stats.report0('Timing/total_hours',
                               (tick_end_time - start_time) / (60 * 60))
        training_stats.report0('Timing/total_days',
                               (tick_end_time - start_time) / (24 * 60 * 60))
        if rank == 0:
            print(' '.join(fields))

        # Check for abort.
        if (not done) and (abort_fn is not None) and abort_fn():
            done = True
            if rank == 0:
                print()
                print('Aborting...')

        # Save image snapshot.
        if (rank == 0) and (image_snapshot_ticks is not None) and (
                done or cur_tick % image_snapshot_ticks == 0):
            images = torch.cat([
                G_ema(z=z, c=c, noise_mode='const').cpu()
                for z, c in zip(grid_z, grid_c)
            ]).numpy()
            save_image_grid(images,
                            os.path.join(run_dir,
                                         f'fakes{cur_nimg//1000:06d}.png'),
                            drange=[-1, 1],
                            grid_size=grid_size)

        # Save network snapshot.
        snapshot_pkl = None
        snapshot_data = None
        if (network_snapshot_ticks is not None) and (
                done or cur_tick % network_snapshot_ticks == 0):
            snapshot_data = dict(training_set_kwargs=dict(training_set_kwargs))
            for name, module in [('G', G), ('D', D), ('G_ema', G_ema),
                                 ('augment_pipe', augment_pipe)]:
                if module is not None:
                    if num_gpus > 1:
                        misc.check_ddp_consistency(module,
                                                   ignore_regex=r'.*\.w_avg')
                    module = copy.deepcopy(module).eval().requires_grad_(
                        False).cpu()
                snapshot_data[name] = module
                del module  # conserve memory
            snapshot_pkl = os.path.join(
                run_dir, f'network-snapshot-{cur_nimg//1000:06d}.pkl')
            if rank == 0:
                with open(snapshot_pkl, 'wb') as f:
                    pickle.dump(snapshot_data, f)
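            # Snapshots are plain pickles; a saved file can be reloaded later
            # with, e.g., pickle.load(open(snapshot_pkl, 'rb'))['G_ema'].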

        # Evaluate metrics.
        if (snapshot_data is not None) and (len(metrics) > 0):
            if rank == 0:
                print('Evaluating metrics...')
            for metric in metrics:
                result_dict = metric_main.calc_metric(
                    metric=metric,
                    G=snapshot_data['G_ema'],
                    dataset_kwargs=training_set_kwargs,
                    num_gpus=num_gpus,
                    rank=rank,
                    device=device)
                if rank == 0:
                    metric_main.report_metric(result_dict,
                                              run_dir=run_dir,
                                              snapshot_pkl=snapshot_pkl)
                stats_metrics.update(result_dict.results)
        del snapshot_data  # conserve memory

        # Collect statistics.
        for phase in phases:
            value = []
            if (phase.start_event is not None) and (phase.end_event
                                                    is not None):
                phase.end_event.synchronize()
                value = phase.start_event.elapsed_time(phase.end_event)
            training_stats.report0('Timing/' + phase.name, value)
        stats_collector.update()
        stats_dict = stats_collector.as_dict()

        # Update logs.
        timestamp = time.time()
        if stats_jsonl is not None:
            fields = dict(stats_dict, timestamp=timestamp)
            stats_jsonl.write(json.dumps(fields) + '\n')
            stats_jsonl.flush()
        if stats_tfevents is not None:
            global_step = int(cur_nimg / 1e3)
            walltime = timestamp - start_time
            for name, value in stats_dict.items():
                stats_tfevents.add_scalar(name,
                                          value.mean,
                                          global_step=global_step,
                                          walltime=walltime)
            for name, value in stats_metrics.items():
                stats_tfevents.add_scalar(f'Metrics/{name}',
                                          value,
                                          global_step=global_step,
                                          walltime=walltime)
            stats_tfevents.flush()
        if progress_fn is not None:
            progress_fn(cur_nimg // 1000, total_kimg)

        # Update state.
        cur_tick += 1
        tick_start_nimg = cur_nimg
        tick_start_time = time.time()
        maintenance_time = tick_start_time - tick_end_time
        if done:
            break

    # Done.
    if rank == 0:
        print()
        print('Exiting...')

import copy

import cv2
import numpy as np
import torch
import torch.nn as nn

import dnnlib
import legacy


# Custom OpenCV DNN layer that center-crops the HED network's side-outputs to
# the target blob's spatial size (reconstructed following the standard OpenCV
# HED sample).
class CropLayer(object):
    def __init__(self, params, blobs):
        self.xstart = 0
        self.xend = 0
        self.ystart = 0
        self.yend = 0

    # Compute the output shape and remember the centered crop window.
    def getMemoryShapes(self, inputs):
        input_shape, target_shape = inputs[0], inputs[1]
        batch, channels = input_shape[0], input_shape[1]
        height, width = target_shape[2], target_shape[3]
        self.ystart = (input_shape[2] - target_shape[2]) // 2
        self.xstart = (input_shape[3] - target_shape[3]) // 2
        self.yend = self.ystart + height
        self.xend = self.xstart + width
        return [[batch, channels, height, width]]

    def forward(self, inputs):
        return [inputs[0][:, :, self.ystart:self.yend, self.xstart:self.xend]]


cv2.dnn_registerLayer('Crop', CropLayer)
pretrained_model = "models/edge_detection/hed_pretrained_bsds.caffemodel"
model_def = "models/edge_detection/deploy.prototxt"
edge_net = cv2.dnn.readNetFromCaffe(model_def, pretrained_model)
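
# Typical HED inference with edge_net (a minimal sketch; the 500x500 input
# size and BGR mean values follow the OpenCV HED sample, and `img` is assumed
# to be a BGR uint8 frame):
#   blob = cv2.dnn.blobFromImage(img, scalefactor=1.0, size=(500, 500),
#                                mean=(104.00698793, 116.66876762, 122.67891434),
#                                swapRB=False, crop=False)
#   edge_net.setInput(blob)
#   hed = edge_net.forward()[0, 0]  # single-channel edge map in [0, 1]
#   hed = (255 * hed).astype(np.uint8)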

truncation_psi = 0.4
seed = 150
device = torch.device('cuda')
network_pkl = "network-snapshot-000712.pkl"
with dnnlib.util.open_url(network_pkl) as fp:
    G = legacy.load_network_pkl(fp)['G_ema'].requires_grad_(False).to(
        device)  # type: ignore

G = copy.deepcopy(G).eval().requires_grad_(False).to(device)

#z = torch.from_numpy(np.random.RandomState().randn(1, G.z_dim)).to(device)
#img = G(z, None, truncation_psi=truncation_psi)
#img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)

# Load an ImageNet-pretrained VGG16 and replace its final classification
# layer with a 512-dimensional output head.
vgg16 = torch.hub.load('pytorch/vision:v0.6.0', 'vgg16', pretrained=True)
vgg16.classifier[-1] = nn.Linear(4096, 512)
load(vgg16)  # restore trained head weights (`load` is assumed to be defined elsewhere in this script)
vgg16 = vgg16.train().to(device)