def create_blended_model(
    ctx: click.Context,
    lower_res_pkl: str,
    higher_res_pkl: str,
    model_res: Optional[int],
    resolution: Optional[int],
    out: Optional[str],
):
    """Blend the G_ema of a higher-resolution pickle into a lower-resolution one.

    The output pickle keeps ``G`` and ``D`` from ``lower_res_pkl`` unchanged and
    replaces ``G_ema`` with ``blend_models(lo_G_ema, hi_G_ema, model_res, resolution)``.

    Args:
        ctx: Click context; used to report a missing ``out`` as a usage error.
        lower_res_pkl: Path/URL of the pickle supplying G, D and the base G_ema.
        higher_res_pkl: Path/URL of the pickle whose G_ema is blended in.
        model_res: Forwarded unchanged to ``blend_models``.
        resolution: Forwarded unchanged to ``blend_models``.
        out: Destination path of the output pickle.
    """
    if out is None:
        # open(None) would fail with an opaque TypeError; report a CLI error instead.
        ctx.fail('--out is required')

    G_kwargs = dnnlib.EasyDict()

    with dnnlib.util.open_url(lower_res_pkl) as f:
        lo = legacy.load_network_pkl(f, custom=False, **G_kwargs)  # type: ignore
    lo_G, lo_D, lo_G_ema = lo['G'], lo['D'], lo['G_ema']

    with dnnlib.util.open_url(higher_res_pkl) as f:
        hi = legacy.load_network_pkl(f, custom=False, **G_kwargs)['G_ema']  # type: ignore

    model_out = blend_models(lo_G_ema, hi, model_res, resolution)

    data = {'G': lo_G, 'D': lo_D, 'G_ema': model_out}
    with open(out, 'wb') as f:
        pickle.dump(data, f)
def __init__(self, domain, ckpt_path=None, load_encoder=False, device='cuda'):
    """Load the pretrained StyleGAN2-ADA generator for ``domain`` from the NVIDIA CDN.

    Only ``'cifar10'`` is supported. ``ckpt_path`` and ``load_encoder`` are
    accepted for interface compatibility but not used: the generator is always
    downloaded from the fixed URL below and no encoder is loaded.

    Raises:
        ValueError: if ``domain`` is not a supported domain.
    """
    from . import perturb_settings
    import sys
    stylegan_dir = 'resources/stylegan2-ada-pytorch'
    # Avoid adding one duplicate sys.path entry per instantiation.
    if stylegan_dir not in sys.path:
        sys.path.append(stylegan_dir)
    import legacy
    import dnnlib

    network_pkls = {
        'cifar10': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/cifar10.pkl',
    }
    # An explicit exception is used rather than `assert`, which `python -O` strips.
    if domain not in network_pkls:
        raise ValueError(f'unsupported domain {domain!r}; expected one of {sorted(network_pkls)}')
    network_pkl = network_pkls[domain]

    device = torch.device(device)
    with dnnlib.util.open_url(network_pkl) as f:
        G = legacy.load_network_pkl(f)['G_ema'].to(device)  # type: ignore
    G = G.eval()
    self.generator = G
    self.encoder = None  # no pretrained encoder for cifar10 model
    self.device = device
    self.perturb_settings = perturb_settings.stylegan2_cc_settings[domain]
    self.pca_stats = None
    self.cc_mean_w = None  # mean w latent per class
def convert(network_pkl, output_file):
    """Convert an NVIDIA StyleGAN2-ADA ``G_ema`` into a rosinality-style checkpoint.

    Tensors are renamed from the NVIDIA ``mapping.*`` / ``synthesis.b<res>.*``
    layout into ``style.*``, ``input.input``, ``conv1`` / ``convs.*``,
    ``to_rgb1`` / ``to_rgbs.*`` and ``noises.*`` keys. The result is written
    with ``torch.save`` as ``{"g_ema": state_dict}`` to ``output_file``.
    """
    with dnnlib.util.open_url(network_pkl) as f:
        G_nvidia = legacy.load_network_pkl(f)["G_ema"]

    nv = G_nvidia.state_dict()
    n_mapping, n_layers = determine_config(nv)
    ros = {}

    # Mapping network: NVIDIA mapping.fc{k} becomes style.{k+1}.
    for k in range(n_mapping):
        for param in ("weight", "bias"):
            ros[f"style.{k+1}.{param}"] = nv[f"mapping.fc{k}.{param}"]

    for level in range(int(n_layers)):
        block = f"synthesis.b{4 * 2 ** level}"

        if level == 0:
            # First (4x4) block: constant input, single conv, first to_rgb.
            ros["input.input"] = nv[f"{block}.const"].unsqueeze(0)
            convert_conv(ros, nv, "conv1", f"{block}.conv1")
            ros["noises.noise_0"] = nv[f"{block}.conv1.noise_const"].unsqueeze(0).unsqueeze(0)
            convert_to_rgb(ros, nv, "to_rgb1", block)
            continue

        # Later blocks: two convs each, numbered consecutively across blocks.
        for conv_level in range(2):
            convert_conv(ros, nv, f"convs.{2*level-2+conv_level}", f"{block}.conv{conv_level}")
            noise = nv[f"{block}.conv{conv_level}.noise_const"]
            ros[f"noises.noise_{2*level-1+conv_level}"] = noise.unsqueeze(0).unsqueeze(0)
        convert_to_rgb(ros, nv, f"to_rgbs.{level-1}", block)
        convert_blur_kernel(ros, nv, level-1)

    torch.save({"g_ema": ros}, output_file)
def __init__(self, model_path, stylegan_dir, truncation_psi=0.5):
    """Load a StyleGAN ``G_ema`` generator from a local pickle onto the GPU in eval mode.

    Args:
        model_path: Local file path of the network pickle (opened directly, not as a URL).
        stylegan_dir: Checkout providing the ``legacy`` module; inserted at
            ``sys.path[1]`` before importing it.
        truncation_psi: Stored as ``self.truncation``.
    """
    super().__init__()
    sys.path.insert(1, stylegan_dir)
    import legacy

    with open(model_path, 'rb') as pkl_file:
        generator = legacy.load_network_pkl(pkl_file)['G_ema']
    self.G = generator.cuda().eval()
    self.truncation = truncation_psi
def _prepare(self) -> None:
    """Prepare :py:class:`StyleGAN` by importing the NVLabs modules and loading model data.

    ``legacy`` and ``dnnlib`` are imported from
    ``config.nvlabs_stylegan2ada_pytorch_directory``. ``dnnlib`` is bound at
    module level via ``global``. The MetFaces pickle's ``G_ema`` and ``D`` are
    then moved to CUDA.
    """
    super()._prepare()

    # Step 1: Import modules from stylegan repository
    print("importing stylegan2ada from "
          f"'{config.nvlabs_stylegan2ada_pytorch_directory}'")
    nvlabs_directory = str(config.nvlabs_stylegan2ada_pytorch_directory)
    # Only undo our own change. If the directory was already on sys.path, it
    # stays there. The finally clause restores sys.path even if an import fails.
    added_to_path = nvlabs_directory not in sys.path
    if added_to_path:
        sys.path.insert(0, nvlabs_directory)
    global dnnlib
    try:
        import legacy  # This requires torch 1.7.1
        import dnnlib
    finally:
        if added_to_path:
            sys.path.remove(nvlabs_directory)

    # Step 2: Load the networks.
    self._device = torch.device('cuda')
    network_pkl = \
        str(config.nvlabs_stylegan2ada_pytorch_directory / "metfaces.pkl")
    with dnnlib.util.open_url(network_pkl) as f:
        networks = legacy.load_network_pkl(f)
    print(list(networks))
    # e.g. ['G', 'D', 'G_ema', 'training_set_kwargs', 'augment_pipe']
    self._generator = networks['G_ema'].to(self._device)
    self._discriminator = networks['D'].to(self._device)
def run_projection(input_image):
    """Project ``input_image`` into the generator's W space and save the final W.

    Uses the module-level settings ``seed``, ``checkpoint_path``, ``num_steps``
    and ``output_dir``. The image is center-cropped to a square and resized to
    ``G.img_resolution`` before optimization. The last projected W is saved as
    ``<output_dir>/<input file stem>.npz`` under the key ``w``.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Load networks.
    print('Loading networks from "%s"...' % checkpoint_path)
    device = torch.device('cuda')
    with dnnlib.util.open_url(checkpoint_path) as fp:
        G = legacy.load_network_pkl(fp)['G_ema'].requires_grad_(False).to(device)  # type: ignore

    # Load target image.
    target_pil = PIL.Image.open(input_image).convert('RGB')
    w, h = target_pil.size
    s = min(w, h)
    target_pil = target_pil.crop(((w - s) // 2, (h - s) // 2, (w + s) // 2, (h + s) // 2))
    target_pil = target_pil.resize((G.img_resolution, G.img_resolution), PIL.Image.LANCZOS)
    target_uint8 = np.array(target_pil, dtype=np.uint8)

    # Optimize projection.
    projected_w_steps = project(
        G,
        target=torch.tensor(target_uint8.transpose([2, 0, 1]), device=device),  # pylint: disable=not-callable
        num_steps=num_steps,
        device=device,
        verbose=True,
    )

    os.makedirs(output_dir, exist_ok=True)
    projected_w = projected_w_steps[-1]
    # Keep every dot of the stem ("a.b.png" -> "a.b") so differently named
    # inputs cannot collide on one output file. os.path also honours the
    # platform's path separator.
    file_name = os.path.splitext(os.path.basename(input_image))[0]
    np.savez(os.path.join(output_dir, f'{file_name}.npz'), w=projected_w.unsqueeze(0).cpu().numpy())
def generate_images(network_pkl: str,
                    seeds: List[int],
                    truncation_psi: float,
                    noise_mode: str,
                    outdir: str,
                    translate: Tuple[float, float],
                    rotate: float,
                    class_idx: Optional[int]):
    """Generate images using pretrained network pickle.

    Examples:

    \b
    # Generate an image using pre-trained AFHQv2 model ("Ours" in Figure 1, left).
    python gen_images.py --outdir=out --trunc=1 --seeds=2 \\
        --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl

    \b
    # Generate uncurated images with truncation using the MetFaces-U dataset
    python gen_images.py --outdir=out --trunc=0.7 --seeds=600-605 \\
        --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-metfacesu-1024x1024.pkl
    """
    print('Loading networks from "%s"...' % network_pkl)
    device = torch.device('cpu')
    with dnnlib.util.open_url(network_pkl) as f:
        G = legacy.load_network_pkl(f)['G_ema'].to(device)  # type: ignore

    os.makedirs(outdir, exist_ok=True)

    # One-hot class label; stays all zeros for unconditional networks.
    label = torch.zeros([1, G.c_dim], device=device)
    conditional = G.c_dim != 0
    if conditional and class_idx is None:
        raise click.ClickException(
            'Must specify class label with --class when using a conditional network'
        )
    if conditional:
        label[:, class_idx] = 1
    elif class_idx is not None:
        print(
            'warn: --class=lbl ignored when running on an unconditional network'
        )

    total = len(seeds)
    for seed_idx, seed in enumerate(seeds):
        print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, total))
        latent = np.random.RandomState(seed).randn(1, G.z_dim)
        z = torch.from_numpy(latent).to(device).float()

        # Generators whose synthesis network has an `input` layer take the
        # inverse of the rotation/translation matrix; the matrix is inverted
        # here so the network itself never has to.
        if hasattr(G.synthesis, 'input'):
            transform = np.linalg.inv(make_transform(translate, rotate))
            G.synthesis.input.transform.copy_(torch.from_numpy(transform))

        img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode)
        pixels = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
        PIL.Image.fromarray(pixels[0].cpu().numpy(), 'RGB').save(f'{outdir}/seed{seed:04d}.png')
def generate_gif(network_pkl: str, seed: int, num_rows: int, num_cols: int, resolution: int, num_phases: int, transition_frames: int, static_frames: int, truncation_psi: float, noise_mode: str, output: str):
    """Generate gif using pretrained network pickle.

    Examples:

    \b
    python generate_gif.py --output=obama.gif --seed=0 --num-rows=1 --num-cols=8 \\
        --network=https://hanlab.mit.edu/projects/data-efficient-gans/models/DiffAugment-stylegan2-100-shot-obama.pkl
    """
    # NOTE(review): truncation_psi is accepted but never used below.
    print('Loading networks from "%s"...' % network_pkl)
    device = torch.device('cuda')
    with dnnlib.util.open_url(network_pkl) as f:
        G = legacy.load_network_pkl(f)['G_ema'].to(device)  # type: ignore

    # os.makedirs('') raises FileNotFoundError, so only create a directory when
    # the output path has one (a bare "obama.gif" does not).
    output_dir = os.path.dirname(output)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    np.random.seed(seed)

    output_seq = []
    batch_size = num_rows * num_cols
    latent_size = G.z_dim
    latents = [np.random.randn(batch_size, latent_size) for _ in range(num_phases)]

    def to_image_grid(outputs):
        # Tile the batch of images into a num_rows x num_cols grid, then scale it.
        outputs = np.reshape(outputs, [num_rows, num_cols, *outputs.shape[1:]])
        outputs = np.concatenate(outputs, axis=1)
        outputs = np.concatenate(outputs, axis=1)
        # Image.ANTIALIAS was removed in Pillow 10; it was an alias of LANCZOS.
        return Image.fromarray(outputs).resize(
            (resolution * num_cols, resolution * num_rows), Image.LANCZOS)

    def generate(dlatents):
        images = G.synthesis(dlatents, noise_mode=noise_mode)
        images = (images.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(
            torch.uint8).cpu().numpy()
        return to_image_grid(images)

    for i in range(num_phases):
        # latents[i - 1] is the last phase when i == 0, so the animation loops.
        dlatents0 = G.mapping(torch.from_numpy(latents[i - 1]).to(device), None)
        dlatents1 = G.mapping(torch.from_numpy(latents[i]).to(device), None)
        for j in range(transition_frames):
            dlatents = (dlatents0 * (transition_frames - j) + dlatents1 * j) / transition_frames
            output_seq.append(generate(dlatents))
        output_seq.extend([generate(dlatents1)] * static_frames)

    if not output_seq:
        raise ValueError('no frames generated: need num_phases >= 1 and '
                         'transition_frames + static_frames >= 1')

    if not output.endswith('.gif'):
        output += '.gif'
    output_seq[0].save(output, save_all=True, append_images=output_seq[1:],
                       optimize=False, duration=50, loop=0)
def run_projection(
    network_pkl: str,
    target_fname: str,
    outdir: str,
    save_video: bool,
    seed: int,
    num_steps: int
):
    """Project given image to the latent space of pretrained network pickle.

    Examples:

    \b
    python projector.py --outdir=out --target=~/mytargetimg.png \\
        --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl
    """
    # NOTE(review): save_video is accepted but never used in this variant.
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Load networks.
    print('Loading networks from "%s"...' % network_pkl)
    device = torch.device('cuda')
    with dnnlib.util.open_url(network_pkl) as fp:
        G = legacy.load_network_pkl(fp)['G_ema'].requires_grad_(False).to(device)  # type: ignore

    # Load target image: center-crop to a square, resize to the generator resolution.
    target_pil = PIL.Image.open(os.path.expanduser(target_fname)).convert('RGB')
    w, h = target_pil.size
    s = min(w, h)
    target_pil = target_pil.crop(((w - s) // 2, (h - s) // 2, (w + s) // 2, (h + s) // 2))
    target_pil = target_pil.resize((G.img_resolution, G.img_resolution), PIL.Image.LANCZOS)
    target_uint8 = np.array(target_pil, dtype=np.uint8)

    # Optimize projection.
    start_time = perf_counter()
    projected_w_steps = project(
        G,
        target=torch.tensor(target_uint8.transpose([2, 0, 1]), device=device),  # pylint: disable=not-callable
        num_steps=num_steps,
        device=device,
        verbose=True
    )
    print(f'Elapsed: {(perf_counter()-start_time):.1f} s')

    os.makedirs(outdir, exist_ok=True)

    # Outputs are named after the part of the file name after its last '_',
    # minus the extension (".../x_ah.png" -> "ah"). splitext handles
    # extensions of any length, unlike a fixed 4-character slice.
    phoneme_characters = os.path.splitext(target_fname.split('_')[-1])[0]

    # Save the final projected frame.
    projected_w = projected_w_steps[-1]
    synth_image = G.synthesis(projected_w.unsqueeze(0), noise_mode='const')
    synth_image = (synth_image + 1) * (255/2)
    synth_image = synth_image.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
    PIL.Image.fromarray(synth_image, 'RGB').save(f'{outdir}/proj_{phoneme_characters}.png')

    # Return row 1 of the projected W (after adding a leading batch axis), flattened.
    ls_v = projected_w.unsqueeze(0).cpu().numpy()[:, 1, :].flatten()
    return ls_v
def generate_images(
    ctx: click.Context,
    network_pkl: str,
    truncation_psi: float,
    external_truncation_psi: float,
    noise_mode: str,
    num_images: int,
    batch_size: int,
    seed: int,
    outdir: str,
):
    """Generate ``num_images`` random images into ``outdir`` as ``img_XXXXXX.png``.

    ``truncation_psi`` is applied inside ``G.mapping``. When
    ``external_truncation_psi < 1``, each W is additionally interpolated
    toward an average W estimated from 10000 random samples. ``ctx`` is unused.
    """
    print('Loading networks from "%s"...' % network_pkl)
    device = torch.device('cuda')
    with dnnlib.util.open_url(network_pkl) as f:
        G = legacy.load_network_pkl(f)['G_ema'].to(
            device).eval()  # type: ignore

    os.makedirs(outdir, exist_ok=True)

    # Pure inference: skip autograd bookkeeping (the 10000-sample mapping is large).
    with torch.no_grad():
        if external_truncation_psi < 1:
            # Draw the samples for the average W from a dedicated generator
            # seeded with `seed`. This makes the estimate reproducible without
            # consuming the global RNG stream used for the per-image latents.
            avg_gen = torch.Generator(device=device)
            avg_gen.manual_seed(seed)
            z = torch.randn(10000, G.z_dim, device=device, generator=avg_gen)  # [10000, z_dim]
            w = G.mapping(z, None)  # [10000, num_ws, w_dim]
            w_avg_ext = w[:, 0].mean(dim=0, keepdim=True).unsqueeze(1)  # [1, 1, w_dim]

        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

        # Generate images; the last batch may be partially discarded.
        num_batches = (num_images + batch_size - 1) // batch_size
        for batch_idx in tqdm(range(num_batches)):
            z = torch.randn(batch_size, G.z_dim, device=device)  # [batch_size, z_dim]
            w = G.mapping(z, None, truncation_psi=truncation_psi)  # [batch_size, num_ws, w_dim]
            if external_truncation_psi < 1:
                w = (1 - external_truncation_psi) * w_avg_ext + external_truncation_psi * w
            imgs = G.synthesis(w, noise_mode=noise_mode)
            imgs = (imgs.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)

            for i, img in enumerate(imgs):
                image_num = batch_idx * batch_size + i
                if image_num >= num_images:
                    break
                PIL.Image.fromarray(img.cpu().numpy(), 'RGB').save(f'{outdir}/img_{image_num:06d}.png')
def generate_images(
    network_pkl: str,
    latent_vectors: Optional[List[int]],
    truncation_psi: float,
    noise_mode: str,
    outdir: str,
    class_idx: Optional[int],
):
    """Render one image per W vector in ``latent_vectors`` into ``outdir``.

    Existing ``*.png`` files in ``outdir`` are deleted first. Each latent
    vector is flattened and repeated once per synthesis layer (``G.num_ws``
    rows) before being passed to ``G.synthesis``. Images are saved as
    ``seed0000.png``, ``seed0001.png``, ... in input order.

    ``truncation_psi`` is accepted but unused because the vectors bypass the
    mapping network. ``class_idx`` only triggers a warning for unconditional
    networks, since ``G.synthesis`` takes no label.
    """
    # Create the directory first (os.listdir fails on a missing one), then clear
    # stale frames. os.path.join is needed because outdir may lack a trailing '/'.
    os.makedirs(outdir, exist_ok=True)
    for file in os.listdir(outdir):
        if file.endswith('.png'):
            os.remove(os.path.join(outdir, file))

    print('Loading networks from "%s"...' % network_pkl)
    device = torch.device('cuda')
    with dnnlib.util.open_url(network_pkl) as f:
        G = legacy.load_network_pkl(f)['G_ema'].to(device)  # type: ignore

    if G.c_dim == 0 and class_idx is not None:
        print(
            'warn: --class=lbl ignored when running on an unconditional network'
        )

    if latent_vectors is None:
        latent_vectors = []

    for counter, latent_vector in enumerate(latent_vectors):
        flat = np.asarray(latent_vector).reshape(-1)
        # Repeat the single W vector for every synthesis layer: [1, num_ws, w_dim].
        ws = np.tile(flat, (G.num_ws, 1))[np.newaxis]
        z = torch.from_numpy(ws).to(device)
        img = G.synthesis(z, noise_mode=noise_mode)
        img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
        PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/seed{counter:04d}.png')
def generate_images(
    network_pkl: str,
    seeds: List[int],
    shuffle_seed: Optional[int],
    truncation_psi: float,
    grid: Tuple[int, int],
    num_keyframes: Optional[int],
    w_frames: int,
    output: str,
    class_idx: Optional[int],
):
    """Render a latent vector interpolation video.

    Examples:

    \b
    # Render a 4x2 grid of interpolations for seeds 0 through 31.
    python gen_video.py --output=lerp.mp4 --trunc=1 --seeds=0-31 --grid=4x2 \\
        --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl

    Animation length and seed keyframes:

    The animation length is either determined based on the --seeds value or explicitly
    specified using the --num-keyframes option.

    When num keyframes is specified with --num-keyframes, the output video length
    will be 'num_keyframes*w_frames' frames.

    If --num-keyframes is not specified, the number of seeds given with
    --seeds must be divisible by grid size W*H (--grid). In this case the
    output video length will be '# seeds/(w*h)*w_frames' frames.
    """
    # Load G_ema on the GPU, then hand every option through to gen_interp_video.
    print('Loading networks from "%s"...' % network_pkl)
    device = torch.device('cuda')
    with dnnlib.util.open_url(network_pkl) as f:
        G = legacy.load_network_pkl(f)['G_ema'].to(device)  # type: ignore

    video_options = dict(
        mp4=output,
        bitrate='12M',
        grid_dims=grid,
        num_keyframes=num_keyframes,
        w_frames=w_frames,
        seeds=seeds,
        shuffle_seed=shuffle_seed,
        psi=truncation_psi,
        class_idx=class_idx,
    )
    gen_interp_video(G=G, **video_options)
def main(pkl: str, psi: float, radius_large: float, radius_small: float, step1: float, step2: float, seed: Optional[int], video_length: float = 1.0, size: int = None, seeds: int = None, scale_type: str = 'pad'):
    """Render a looping latent walk from ``pkl`` and write it to ``./circular-<timestamp>.mp4``.

    A truthy ``size`` loads the generator in custom mode, with ``size`` and
    ``scale_type`` forwarded to ``legacy.load_network_pkl``. Frames come from
    ``generate_from_generator_adaptive`` and are encoded with libx264 at 15M bitrate.
    """
    custom = bool(size)
    if custom:
        print('render custom size: ', size)
        print('padding method:', scale_type)

    G_kwargs = dnnlib.EasyDict(size=size, scale_type=scale_type)

    print('Loading networks from "%s"...' % pkl)
    device = torch.device('cuda')
    with dnnlib.util.open_url(pkl) as f:
        G = legacy.load_network_pkl(f, custom=custom, **G_kwargs)['G_ema'].to(device)  # type: ignore

    rendered = generate_from_generator_adaptive(psi, radius_large, radius_small, step1, step2,
                                                video_length, seed, seeds, G, device)
    clip = moviepy.editor.ImageSequenceClip(rendered, fps=30)

    # Name the file after the current local date and time.
    timestamp = datetime.now().strftime("%d-%m-%Y-%I-%M-%S-%p")
    mp4_file = './circular-' + timestamp + '.mp4'
    # NOTE(review): the clip is built at 30 fps but written at 24 fps (20 was an
    # earlier value); confirm the mismatch is intended.
    clip.write_videofile(mp4_file, fps=24, codec='libx264', bitrate='15M')
def generate_video(network_pkl, seeds, steps, fps, output_filename, outdir):
    """Render a linear latent walk through ``seeds`` and encode it with ffmpeg.

    Each consecutive seed pair contributes ``steps`` frames, written to
    ``outdir`` as ``frame-<n>.png``. Within a pair, the frames go from
    ``v1 + step`` to ``v2``, so the vector of ``seeds[0]`` itself is never
    rendered. Invalid arguments print a message and return without rendering.
    """
    import subprocess

    if len(seeds) < 2:
        print('Need more than one seed')
        return

    if steps < 1:
        print('At least one step required')
        return

    print('Loading networks from "%s"...' % network_pkl)
    device = torch.device('cuda')
    with dnnlib.util.open_url(network_pkl) as f:
        G = legacy.load_network_pkl(f)['G_ema'].to(device)  # type: ignore

    label = torch.zeros([1, G.c_dim], device=device)
    os.makedirs(outdir, exist_ok=True)

    # Generate the images for the video.
    idx = 0
    for i in range(len(seeds) - 1):
        v1 = seed2vec(G, seeds[i])
        v2 = seed2vec(G, seeds[i + 1])
        step = (v2 - v1) / steps
        current = v1.copy()
        for _ in tqdm(range(steps), desc=f"Seed {seeds[i]}"):
            current = current + step
            z = torch.from_numpy(current).to(device)
            img = G(z, label, noise_mode='const')
            img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
            PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(
                os.path.join(outdir, f'frame-{idx}.png'))
            idx += 1

    # Pass arguments as a list (no shell) so paths with spaces or shell
    # metacharacters survive intact. As with os.system, a failing ffmpeg does not raise.
    cmd = [
        'ffmpeg', '-y',
        '-i', os.path.join(outdir, 'frame-%d.png'),
        '-r', str(fps),
        '-c:v', 'libx264', '-crf', '18', '-preset', 'medium', '-pix_fmt', 'yuv420p',
        output_filename,
    ]
    subprocess.run(cmd, check=False)
def generate_video(network_pkl, seed, fps, output_filename, wav_filename, outdir):
    """Render a beat-synchronised latent walk and mux it with ``wav_filename`` via ffmpeg.

    ``get_beats`` supplies beat positions and the total length; both are
    multiplied by ``fps`` and rounded, so presumably they are in seconds
    (TODO confirm). Between consecutive beat frames, the latent moves from
    ``seed2vec(seed)`` to ``seed2vec(seed + 1)``. The progress is eased by the
    normalised cumulative sum of ``1 - sin`` over ``[0, pi]``, and ``seed``
    advances once per segment. Frames are written as ``frame-<n>.png``.
    """
    import subprocess

    print('Loading networks from "%s"...' % network_pkl)
    device = torch.device('cuda')
    with dnnlib.util.open_url(network_pkl) as f:
        G = legacy.load_network_pkl(f)['G_ema'].to(device)  # type: ignore

    label = torch.zeros([1, G.c_dim], device=device)
    os.makedirs(outdir, exist_ok=True)

    beats, total_len = get_beats(wav_filename)
    beats = [int(np.round(fps * x)) for x in beats]

    # Add the first and the last frames to beats array
    if 0 not in beats:
        beats = [0] + beats
    last = int(np.round(fps * total_len))
    if last not in beats:
        beats.append(last)

    for i in tqdm(range(len(beats) - 1)):
        v1 = seed2vec(G, seed)
        seed += 1
        v2 = seed2vec(G, seed)
        diff = v2 - v1
        n_frames = beats[i + 1] - beats[i]
        if n_frames <= 0:
            # Beats closer than one frame round to the same index. Such a
            # segment has no frames, and y[-1] below would raise IndexError.
            continue
        x = np.linspace(0, np.pi, n_frames)
        y = np.cumsum(1 - np.sin(x))
        y /= y[-1]
        for j in range(n_frames):
            current = v1 + diff * y[j]
            z = torch.from_numpy(current).to(device)
            img = G(z, label, noise_mode='const')
            img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
            PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(
                os.path.join(outdir, 'frame-{}.png'.format(beats[i] + j)))

    # Pass arguments as a list (no shell) so paths with spaces or shell metacharacters survive.
    cmd = [
        'ffmpeg', '-y',
        '-i', os.path.join(outdir, 'frame-%d.png'),
        '-i', wav_filename,
        '-r', str(fps),
        '-c:v', 'libx264', '-crf', '18', '-preset', 'medium', '-pix_fmt', 'yuv420p',
        output_filename,
    ]
    subprocess.run(cmd, check=False)
def run_approach(
    network_pkl: str,
    outdir: str,
    save_video: bool,
    save_ws: bool,
    seed: int,
    num_steps: int,
    text: str,
    lr: float,
    inf: float,
    nf: float,
    w: str,
    psi: float,
    noise_opt: bool
):
    """Descend on StyleGAN2 w vector value using CLIP, tuning an image with given text prompt.

    Example:

    \b
    python3 approach.py --network network-snapshot-ffhq.pkl --outdir project --num-steps 100 \\
        --text 'an image of a girl with a face resembling Paul Krugman' --psi 0.8 --seed 12345
    """
    # The global RNGs are fixed to 1; `seed` itself is forwarded to approach() below.
    np.random.seed(1)
    torch.manual_seed(1)

    # Snapshot the arguments before any other local is bound: locals() here holds
    # exactly the parameters. Their SHA-1 (over the UTF-16-BE encoded JSON) names
    # every output of this run, so this computation must stay byte-stable.
    local_args = dict(locals())
    params = [{name: value} for name, value in local_args.items()]
    hashname = str(hashlib.sha1((json.dumps(params)).encode('utf-16be')).hexdigest())
    print('run hash', hashname)

    ws = None
    if w is not None:
        print('loading w from file', w, 'ignoring seed and psi')
        ws = np.load(w)['w']

    print('Loading networks from "%s"...' % network_pkl)
    device = torch.device('cuda')
    with dnnlib.util.open_url(network_pkl) as fp:
        G = legacy.load_network_pkl(fp)['G_ema'].requires_grad_(False).to(device)  # type: ignore

    projected_w_steps = approach(
        G,
        num_steps=num_steps,
        device=device,
        initial_learning_rate=lr,
        psi=psi,
        seed=seed,
        initial_noise_factor=inf,
        noise_floor=nf,
        text=text,
        ws=ws,
        noise_opt=noise_opt
    )

    os.makedirs(outdir, exist_ok=True)

    # Save video of the optimization; `with` closes the writer even on error.
    if save_video:
        with imageio.get_writer(f'{outdir}/out-{hashname}.mp4', mode='I', fps=10, codec='libx264', bitrate='16M') as video:
            print(f'Saving optimization progress video "{outdir}/out-{hashname}.mp4"')
            for projected_w in projected_w_steps:
                synth_image = G.synthesis(projected_w.unsqueeze(0), noise_mode='const')
                synth_image = (synth_image + 1) * (255/2)
                synth_image = synth_image.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
                video.append_data(synth_image)

    # Save every intermediate w.
    if save_ws:
        print('Saving optimization progress ws')
        for step, projected_w in enumerate(projected_w_steps):
            np.savez(f'{outdir}/w-{hashname}-{step}.npz', w=projected_w.unsqueeze(0).cpu().numpy())

    # Save the result and the final w.
    print('Saving finals')
    projected_w = projected_w_steps[-1]
    synth_image = G.synthesis(projected_w.unsqueeze(0), noise_mode='const')
    synth_image = (synth_image + 1) * (255/2)
    synth_image = synth_image.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
    PIL.Image.fromarray(synth_image, 'RGB').save(f'{outdir}/out-{hashname}.png')
    np.savez(f'{outdir}/w-{hashname}-final.npz', w=projected_w.unsqueeze(0).cpu().numpy())

    # Save the parameters that produced the hash.
    with open(f'{outdir}/params-{hashname}.txt', 'w') as outfile:
        json.dump(params, outfile)
device = torch.device('cuda')
eigvec = torch.load(args.factor)["eigvec"].to(device)
index = args.index
seeds = args.seeds

# Generator loading options forwarded to legacy.load_network_pkl.
custom = False
G_kwargs = dnnlib.EasyDict()
G_kwargs.size = None
G_kwargs.scale_type = 'symm'

print('Loading networks from "%s"...' % args.ckpt)
with dnnlib.util.open_url(args.ckpt) as f:
    G = legacy.load_network_pkl(f, custom=custom, **G_kwargs)['G_ema'].to(device)  # type: ignore

# exist_ok avoids the check-then-create race of a separate os.path.exists test.
os.makedirs(args.output, exist_ok=True)

label = torch.zeros([1, G.c_dim], device=device)  # assume no class label
noise_mode = "const"  # default
truncation_psi = args.truncation

latents = []
mode = "random"
log_str = ""
index_list_of_eigenvalues = []
def generate():
    """Render a latent-animation frame sequence with the generator named by ``a.model``.

    All settings come from the module-level ``a`` namespace. Latents can be
    blended by splitting the frame into ``a.nXY`` tiles, or by an external
    mask (``a.latmask``: one image, or a directory holding one image per
    frame). Frames are written to ``a.out_dir`` as ``%06d.png`` (4-channel
    output) or ``%06d.jpg``. Mapped dlatents are optionally saved when
    ``a.save_lat`` is True. Exits with status 1 if ``a.latmask`` does not exist.
    """
    os.makedirs(a.out_dir, exist_ok=True)
    np.random.seed(seed=696)
    device = torch.device('cuda')

    # setup generator
    Gs_kwargs = dnnlib.EasyDict()
    Gs_kwargs.verbose = a.verbose
    Gs_kwargs.size = a.size
    Gs_kwargs.scale_type = a.scale_type

    # mask/blend latents with external latmask or by splitting the frame
    if a.latmask is None:
        nHW = [int(s) for s in a.nXY.split('-')][::-1]
        if len(nHW) != 2:
            raise ValueError(' Wrong count nXY: %d (must be 2)' % len(nHW))
        n_mult = nHW[0] * nHW[1]
        if a.verbose is True and n_mult > 1:
            print(' Latent blending w/split frame %d x %d' % (nHW[1], nHW[0]))
        lmask = np.tile(np.asarray([[[[1]]]]), (1, n_mult, 1, 1))
        Gs_kwargs.countHW = nHW
        Gs_kwargs.splitfine = a.splitfine
    else:
        if a.verbose is True:
            print(' Latent blending with mask', a.latmask)
        n_mult = 2
        if os.path.isfile(a.latmask):  # single file
            lmask = np.asarray([[img_read(a.latmask)[:, :, 0] / 255.]])  # [1,1,h,w]
        elif os.path.isdir(a.latmask):  # directory with frame sequence
            # One [1,h,w] entry per frame gives [frm,1,h,w], so the concatenation
            # below yields the [frm,2,h,w] layout the render loop indexes per frame.
            lmask = np.asarray([[img_read(f)[:, :, 0] / 255.] for f in img_list(a.latmask)])
        else:
            print(' !! Blending mask not found:', a.latmask)
            raise SystemExit(1)
        lmask = np.concatenate((lmask, 1 - lmask), 1)  # [frm,2,h,w]
    lmask = torch.from_numpy(lmask).to(device)

    # load base or custom network
    pkl_name = osp.splitext(a.model)[0]
    if '.pkl' in a.model.lower():
        custom = False
        print(' .. Gs from pkl ..', basename(a.model))
    else:
        custom = True
        print(' .. Gs custom ..', basename(a.model))
    with dnnlib.util.open_url(pkl_name + '.pkl') as f:
        Gs = legacy.load_network_pkl(f, custom=custom, **Gs_kwargs)['G_ema'].to(device)  # type: ignore

    if a.verbose is True:
        print(' out shape', Gs.output_shape[1:])

    if a.verbose is True:
        print(' making timeline..')
    lats = []  # list of [frm,1,512]
    for i in range(n_mult):
        lat_tmp = latent_anima((1, Gs.z_dim), a.frames, a.fstep, cubic=a.cubic,
                               gauss=a.gauss, verbose=False)  # [frm,1,512]
        lats.append(lat_tmp)
    latents = np.concatenate(lats, 1)  # [frm,X,512]
    print(' latents', latents.shape)
    latents = torch.from_numpy(latents).to(device)
    frame_count = latents.shape[0]

    # distort image by tweaking initial const layer
    if a.digress > 0:
        # Fall back to the default 4x4 initial layer size when Gs has no init_res.
        init_res = getattr(Gs, 'init_res', (4, 4))
        dconst = []
        for i in range(n_mult):
            dc_tmp = a.digress * latent_anima([1, Gs.z_dim, *init_res], a.frames, a.fstep,
                                              cubic=True, verbose=False)
            dconst.append(dc_tmp)
        dconst = np.concatenate(dconst, 1)
    else:
        dconst = np.zeros([frame_count, 1, 1, 1, 1])
    dconst = torch.from_numpy(dconst).to(device)

    # labels / conditions
    label_size = Gs.c_dim
    if label_size > 0:
        labels = torch.zeros((frame_count, n_mult, label_size), device=device)  # [frm,X,lbl]
        if a.labels is None:
            label_ids = []
            for i in range(n_mult):
                label_ids.append(random.randint(0, label_size - 1))
        else:
            label_ids = [int(x) for x in a.labels.split('-')]
            label_ids = label_ids[:n_mult]  # ensure we have enough labels
        for i, l in enumerate(label_ids):
            labels[:, i, l] = 1
    else:
        labels = [None]

    # generate images from latent timeline
    pbar = ProgressBar(frame_count)
    for i in range(frame_count):
        latent = latents[i]  # [X,512]
        label = labels[i % len(labels)]
        latmask = lmask[i % len(lmask)] if lmask is not None else [None]  # [X,h,w]
        dc = dconst[i % len(dconst)]  # [X,512,4,4]

        # generate multi-latent result
        if custom:
            output = Gs(latent, label, latmask, dc, truncation_psi=a.trunc, noise_mode='const')
        else:
            output = Gs(latent, label, truncation_psi=a.trunc, noise_mode='const')
        output = (output.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).cpu().numpy()

        # save image
        ext = 'png' if output.shape[3] == 4 else 'jpg'
        filename = osp.join(a.out_dir, "%06d.%s" % (i, ext))
        imsave(filename, output[0])
        pbar.upd()

    # convert latents to dlatents, save them
    if a.save_lat is True:
        latents = latents.squeeze(1)  # [frm,512]
        dlatents = Gs.mapping(latents, label)  # [frm,18,512]
        if a.size is None:
            a.size = [''] * 2
        filename = '{}-{}-{}.npy'.format(basename(a.model), a.size[1], a.size[0])
        filename = osp.join(osp.dirname(a.out_dir), filename)
        dlatents = dlatents.cpu().numpy()
        np.save(filename, dlatents)
        print('saved dlatents', dlatents.shape, 'to', filename)
def main():
    """Animate saved dlatents (W vectors) with the generator named by ``a.model``.

    Key dlatents are read from ``a.dlatents``, either a single file or a
    directory of ``.npy`` files. Layers 5 onward are optionally overwritten
    from ``a.style_dlat``. The keys are then interpolated by ``latent_anima``
    over ``len(keys) * a.fstep`` frames. Each frame goes through
    ``Gs.synthesis`` and is saved to ``a.out_dir`` as ``%06d.png`` (4-channel
    output) or ``%06d.jpg``. All settings come from the module-level ``a``
    namespace.

    Exits with status 1 when no input dlatents are found.
    """
    os.makedirs(a.out_dir, exist_ok=True)
    device = torch.device('cuda')

    # setup generator
    Gs_kwargs = dnnlib.EasyDict()
    Gs_kwargs.verbose = a.verbose
    Gs_kwargs.size = a.size
    Gs_kwargs.scale_type = a.scale_type

    # load base or custom network
    pkl_name = osp.splitext(a.model)[0]
    if '.pkl' in a.model.lower():
        custom = False
        print(' .. Gs from pkl ..', basename(a.model))
    else:
        custom = True
        print(' .. Gs custom ..', basename(a.model))
    with dnnlib.util.open_url(pkl_name + '.pkl') as f:
        Gs = legacy.load_network_pkl(f, custom=custom, **Gs_kwargs)['G_ema'].to(device)  # type: ignore

    dlat_shape = (1, Gs.num_ws, Gs.w_dim)  # [1,18,512]

    # read saved latents
    key_dlatents = None
    if a.dlatents is not None and osp.isfile(a.dlatents):
        key_dlatents = load_latents(a.dlatents)
        if len(key_dlatents.shape) == 2:
            key_dlatents = np.expand_dims(key_dlatents, 0)
    elif a.dlatents is not None and osp.isdir(a.dlatents):
        loaded = []
        for npy in file_list(a.dlatents, 'npy'):
            key_dlatent = load_latents(npy)
            if len(key_dlatent.shape) == 2:
                key_dlatent = np.expand_dims(key_dlatent, 0)
            loaded.append(key_dlatent)
        if loaded:  # np.concatenate([]) would raise on a directory with no .npy files
            key_dlatents = np.concatenate(loaded)  # [frm,18,512]
    if key_dlatents is None:
        print(' No input dlatents found')
        # Exit with a failure status; a bare exit() would report success (0).
        raise SystemExit(1)
    key_dlatents = key_dlatents[:, np.newaxis]  # [frm,1,18,512]
    print(' key dlatents', key_dlatents.shape)

    # replace higher layers with single (style) latent
    if a.style_dlat is not None:
        print(' styling with dlatent', a.style_dlat)
        style_dlatent = load_latents(a.style_dlat)
        while len(style_dlatent.shape) < 4:
            style_dlatent = np.expand_dims(style_dlatent, 0)
        # try replacing 5 by other value, less than Gs.num_ws
        key_dlatents[:, :, range(5, Gs.num_ws), :] = style_dlatent[:, :, range(5, Gs.num_ws), :]

    frames = key_dlatents.shape[0] * a.fstep
    dlatents = latent_anima(dlat_shape, frames, a.fstep, key_latents=key_dlatents,
                            cubic=a.cubic, verbose=True)  # [frm,1,512]
    print(' dlatents', dlatents.shape)
    frame_count = dlatents.shape[0]
    dlatents = torch.from_numpy(dlatents).to(device)

    # distort image by tweaking initial const layer
    if a.digress > 0:
        # Fall back to the default 4x4 initial layer size when Gs has no init_res.
        init_res = getattr(Gs, 'init_res', (4, 4))
        dconst = a.digress * latent_anima([1, Gs.z_dim, *init_res], frame_count, a.fstep,
                                          cubic=True, verbose=False)
    else:
        dconst = np.zeros([frame_count, 1, 1, 1, 1])
    dconst = torch.from_numpy(dconst).to(device)

    # generate images from latent timeline
    pbar = ProgressBar(frame_count)
    for i in range(frame_count):
        # generate multi-latent result
        if custom:
            output = Gs.synthesis(dlatents[i], None, dconst[i], noise_mode='const')
        else:
            output = Gs.synthesis(dlatents[i], noise_mode='const')
        output = (output.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).cpu().numpy()

        ext = 'png' if output.shape[3] == 4 else 'jpg'
        filename = osp.join(a.out_dir, "%06d.%s" % (i, ext))
        imsave(filename, output[0])
        pbar.upd()
def generate_images(
    ctx: click.Context,
    easing: str,
    interpolation: str,
    increment: Optional[float],
    network_pkl: str,
    process: str,
    random_seed: Optional[int],
    diameter: Optional[float],
    seeds: Optional[List[int]],
    space: str,
    fps: Optional[int],
    frames: Optional[int],
    truncation_psi: float,
    noise_mode: str,
    outdir: str,
    class_idx: Optional[int],
    projected_w: Optional[str],
    start: Optional[float],
    stop: Optional[float],
):
    """Generate images using pretrained network pickle.

    Examples:

    \b
    # Generate curated MetFaces images without truncation (Fig.10 left)
    python generate.py --outdir=out --trunc=1 --seeds=85,265,297,849 \\
        --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl

    \b
    # Generate uncurated MetFaces images with truncation (Fig.12 upper left)
    python generate.py --outdir=out --trunc=0.7 --seeds=600-605 \\
        --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl

    \b
    # Generate class conditional CIFAR-10 images (Fig.17 left, Car)
    python generate.py --outdir=out --seeds=0-35 --class=1 \\
        --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/cifar10.pkl

    \b
    # Render an image from projected W
    python generate.py --outdir=out --projected_w=projected_w.npz \\
        --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
    """
    print('Loading networks from "%s"...' % network_pkl)
    device = torch.device('cuda')
    with dnnlib.util.open_url(network_pkl) as f:
        G = legacy.load_network_pkl(f)['G_ema'].to(device)  # type: ignore

    os.makedirs(outdir, exist_ok=True)

    # Synthesize the result of a W projection.
    if (process == 'image') and projected_w is not None:
        if seeds is not None:
            print('Warning: --seeds is ignored when using --projected-w')
        print(f'Generating images from projected W "{projected_w}"')
        ws = np.load(projected_w)['w']
        ws = torch.tensor(ws, device=device)  # pylint: disable=not-callable
        assert ws.shape[1:] == (G.num_ws, G.w_dim)
        for idx, w in enumerate(ws):
            img = G.synthesis(w.unsqueeze(0), noise_mode=noise_mode)
            img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
            PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/proj{idx:02d}.png')
        return

    # Labels.
    label = torch.zeros([1, G.c_dim], device=device)
    if G.c_dim != 0:
        if class_idx is None:
            ctx.fail('Must specify class label with --class when using a conditional network')
        label[:, class_idx] = 1
    else:
        if class_idx is not None:
            print('warn: --class=lbl ignored when running on an unconditional network')

    if process == 'image':
        if seeds is None:
            ctx.fail('--seeds option is required when not using --projected-w')

        # Generate images.
        for seed_idx, seed in enumerate(seeds):
            print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds)))
            z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device)
            img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode)
            img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
            PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/seed{seed:04d}.png')

    elif process == 'interpolation':
        # create path for frames
        dirpath = os.path.join(outdir, 'frames')
        os.makedirs(dirpath, exist_ok=True)

        # autogenerate video name: not great!
        if seeds is not None:
            seedstr = '_'.join([str(seed) for seed in seeds])
            vidname = f'{process}-{interpolation}-seeds_{seedstr}-{fps}fps'
        else:
            # Seedless runs (e.g. noiseloop / circularloop) are named after the
            # loop diameter and random seed.
            vidname = f'{process}-{interpolation}-{diameter}dia-seed_{random_seed}-{fps}fps'

        interpolate(G, device, projected_w, seeds, random_seed, space, truncation_psi, label,
                    frames, noise_mode, dirpath, interpolation, easing, diameter)

        _frames_to_mp4(fps, dirpath, outdir, vidname)

    elif process == 'truncation':
        if seeds is None or (len(seeds) > 1):
            ctx.fail('truncation requires a single seed value')

        # create path for frames
        dirpath = os.path.join(outdir, 'frames')
        os.makedirs(dirpath, exist_ok=True)

        seed = seeds[0]
        vidname = f'{process}-seed_{seed}-start_{start}-stop_{stop}-inc_{increment}-{fps}fps'

        # generate frames
        truncation_traversal(G, device, seeds, label, start, stop, increment, noise_mode, dirpath)

        _frames_to_mp4(fps, dirpath, outdir, vidname)


def _frames_to_mp4(fps, dirpath, outdir, vidname):
    """Encode ``dirpath/frame%04d.png`` into ``outdir/vidname.mp4`` (libx264, yuv420p) with ffmpeg."""
    # An argument list (no shell) keeps paths with spaces or shell metacharacters intact.
    cmd = ['ffmpeg', '-y', '-r', str(fps), '-i', f'{dirpath}/frame%04d.png',
           '-vcodec', 'libx264', '-pix_fmt', 'yuv420p', f'{outdir}/{vidname}.mp4']
    subprocess.call(cmd)
def run_projection(
    network_pkl: str,
    target_fname: str,
    outdir: str,
    save_video: bool,
    seed: int,
    num_steps: int
):
    """Project given image to the latent space of pretrained network pickle.

    Examples:

    \b
    python projector.py --outdir=out --target=~/mytargetimg.png \\
        --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl
    """
    # Outputs written to `outdir`: target.png, proj.png, projected_w.npz and,
    # when `save_video` is set, proj.mp4 showing target | projection per step.
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Load networks.
    print('Loading networks from "%s"...' % network_pkl)
    device = torch.device('cuda')
    with dnnlib.util.open_url(network_pkl) as fp:
        G = legacy.load_network_pkl(fp)['G_ema'].requires_grad_(False).to(device)  # type: ignore

    # Load target image: center-crop to a square, then resize to the generator's
    # (possibly non-square) output resolution.
    target_pil = PIL.Image.open(target_fname).convert('RGB')
    w, h = target_pil.size
    s = min(w, h)
    target_pil = target_pil.crop(((w - s) // 2, (h - s) // 2, (w + s) // 2, (h + s) // 2))
    target_pil = target_pil.resize((G.img_resolution_w, G.img_resolution_h), PIL.Image.LANCZOS)
    target_uint8 = np.array(target_pil, dtype=np.uint8)

    # Optimize projection.
    start_time = perf_counter()
    projected_w_steps = project(
        G,
        target=torch.tensor(target_uint8.transpose([2, 0, 1]), device=device),  # pylint: disable=not-callable
        num_steps=num_steps,
        device=device,
        verbose=True
    )
    print(f'Elapsed: {(perf_counter()-start_time):.1f} s')

    def synth_to_uint8(w_latent):
        # Render one W latent with constant noise, map the output from roughly
        # [-1, 1] to [0, 255], and return the first image as an HWC uint8 array.
        synth_image = G.synthesis(w_latent.unsqueeze(0), noise_mode='const')
        synth_image = (synth_image + 1) * (255/2)
        return synth_image.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()

    # Render debug output: optional video and projected image and W vector.
    os.makedirs(outdir, exist_ok=True)
    if save_video:
        video = imageio.get_writer(f'{outdir}/proj.mp4', mode='I', fps=10, codec='libx264', bitrate='16M')
        print(f'Saving optimization progress video "{outdir}/proj.mp4"')
        try:
            for projected_w in projected_w_steps:
                video.append_data(np.concatenate([target_uint8, synth_to_uint8(projected_w)], axis=1))
        finally:
            # Always finalize the file (and the encoder) even if rendering fails.
            video.close()

    # Save final projected frame and W vector.
    target_pil.save(f'{outdir}/target.png')
    projected_w = projected_w_steps[-1]
    PIL.Image.fromarray(synth_to_uint8(projected_w), 'RGB').save(f'{outdir}/proj.png')
    np.savez(f'{outdir}/projected_w.npz', w=projected_w.unsqueeze(0).cpu().numpy())
def calc_metrics(ctx, network_pkl, metrics, data, mirror, gpus, verbose):
    """Calculate quality metrics for previous training run or pretrained network pickle.

    Examples:

    \b
    # Previous training run: look up options automatically, save result to JSONL file.
    python calc_metrics.py --metrics=pr50k3_full \\
        --network=~/training-runs/00000-ffhq10k-res64-auto1/network-snapshot-000000.pkl

    \b
    # Pre-trained network pickle: specify dataset explicitly, print result to stdout.
    python calc_metrics.py --metrics=fid50k_full --data=~/datasets/ffhq.zip --mirror=1 \\
        --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl

    Available metrics:

    \b
      ADA paper:
        fid50k_full  Frechet inception distance against the full dataset.
        kid50k_full  Kernel inception distance against the full dataset.
        pr50k3_full  Precision and recall againt the full dataset.
        is50k        Inception score for CIFAR-10.

    \b
      StyleGAN and StyleGAN2 papers:
        fid50k       Frechet inception distance against 50k real images.
        kid50k       Kernel inception distance against 50k real images.
        pr50k3       Precision and recall against 50k real images.
        ppl2_wend    Perceptual path length in W at path endpoints against full image.
        ppl_zfull    Perceptual path length in Z for full paths against cropped image.
        ppl_wfull    Perceptual path length in W for full paths against cropped image.
        ppl_zend     Perceptual path length in Z at path endpoints against cropped image.
        ppl_wend     Perceptual path length in W at path endpoints against cropped image.
    """
    dnnlib.util.Logger(should_flush=True)

    # --- Argument validation (ctx.fail raises, aborting the command). ---
    args = dnnlib.EasyDict(metrics=metrics, num_gpus=gpus, network_pkl=network_pkl, verbose=verbose)
    if any(not metric_main.is_valid_metric(name) for name in args.metrics):
        allowed = ['--metrics can only contain the following values:'] + metric_main.list_valid_metrics()
        ctx.fail('\n'.join(allowed))
    if not (args.num_gpus >= 1):
        ctx.fail('--gpus must be at least 1')

    # --- Network loading. ---
    is_reachable = dnnlib.util.is_url(network_pkl, allow_file_urls=True) or os.path.isfile(network_pkl)
    if not is_reachable:
        ctx.fail('--network must point to a file or URL')
    if args.verbose:
        print(f'Loading network from "{network_pkl}"...')
    with dnnlib.util.open_url(network_pkl, verbose=args.verbose) as pkl_file:
        network_dict = legacy.load_network_pkl(pkl_file)
    generator = network_dict['G_ema']  # subclass of torch.nn.Module
    args.G = generator

    # --- Dataset options: explicit --data wins, else reuse the pickle's training options. ---
    if data is not None:
        args.dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset', path=data)
    elif network_dict['training_set_kwargs'] is not None:
        args.dataset_kwargs = dnnlib.EasyDict(network_dict['training_set_kwargs'])
    else:
        ctx.fail('Could not look up dataset options; please specify --data')

    # Match the dataset to the generator's output size and conditioning.
    dataset_kwargs = args.dataset_kwargs
    dataset_kwargs.resolution_h = generator.img_resolution_h
    dataset_kwargs.resolution_w = generator.img_resolution_w
    dataset_kwargs.use_labels = (generator.c_dim != 0)
    if mirror is not None:
        dataset_kwargs.xflip = mirror

    if args.verbose:
        print('Dataset options:')
        print(json.dumps(dataset_kwargs, indent=2))

    # --- Run directory: only when the pickle sits next to training_options.json. ---
    run_dir = None
    if os.path.isfile(network_pkl):
        pkl_parent = os.path.dirname(network_pkl)
        if os.path.isfile(os.path.join(pkl_parent, 'training_options.json')):
            run_dir = pkl_parent
    args.run_dir = run_dir

    # --- Launch one process per GPU (in-process for a single GPU). ---
    if args.verbose:
        print('Launching processes...')
    torch.multiprocessing.set_start_method('spawn')
    with tempfile.TemporaryDirectory() as temp_dir:
        if args.num_gpus != 1:
            torch.multiprocessing.spawn(fn=subprocess_fn, args=(args, temp_dir), nprocs=args.num_gpus)
        else:
            subprocess_fn(rank=0, args=args, temp_dir=temp_dir)
def generate_interpolation_video(network_pkl=None, grid_size=[1, 1], png_sequence=False, image_zoom=1, duration_sec=60.0, smoothing_sec=1.0, truncation_psi=1, noise_mode=False, filename=None, mp4_fps=30, mp4_codec='libx264', mp4_bitrate='16M', random_seed=1000, outdir='./videos'):
    """Render a smooth random latent-space walk from a network pickle to an MP4.

    One z latent per grid cell is drawn for every frame. The latents are smoothed
    over time with a Gaussian filter in 'wrap' mode, so the end blends into the
    start. They are rescaled to unit RMS, rendered with the pickle's ``G_ema``,
    tiled with ``create_image_grid``, and written via imageio to
    ``{outdir}/seed{random_seed:04d}.mp4``.

    Args:
        network_pkl: Path/URL of the network pickle. Exits the process if None.
        grid_size: Grid layout passed to ``create_image_grid``; its product is the
            number of images rendered per frame.
        png_sequence: Unused by the current implementation.
        image_zoom: Integer nearest-neighbour upscale factor applied to each frame.
        duration_sec: Video length in seconds.
        smoothing_sec: Temporal Gaussian sigma, in seconds.
        truncation_psi: Truncation passed to the generator.
        noise_mode: Passed to the generator. NOTE(review): upstream StyleGAN2-ADA
            only accepts 'const', 'random' or 'none'; confirm the default False
            is valid for this fork.
        filename: Computed from the pickle name when None. NOTE(review): the
            imageio writer below does not use it.
        mp4_fps, mp4_codec, mp4_bitrate: Video encoder settings.
        random_seed: Seed for the latent sequence; also names the output file.
        outdir: Output directory, created if missing.
    """
    if network_pkl is None:
        print('ERROR: Please enter pkl path.')
        sys.exit(1)

    num_frames = int(np.rint(duration_sec * mp4_fps))
    random_state = np.random.RandomState(random_seed)
    if filename is None:
        filename = get_id_string_for_network_pkl(network_pkl) + '-seed-' + str(
            random_seed)

    print('Loading networks from "%s"...' % network_pkl)
    device = torch.device('cuda')
    with dnnlib.util.open_url(network_pkl) as f:
        G = legacy.load_network_pkl(f)['G_ema'].to(device)  # type: ignore

    print('Generating latent vectors...')
    shape = [num_frames, np.prod(grid_size)] + [G.z_dim]  # [frame, image, component]
    print(shape)
    all_latents = random_state.randn(*shape).astype(np.float32)
    # Smooth along the frame axis only; 'wrap' makes the walk periodic.
    all_latents = scipy.ndimage.gaussian_filter(
        all_latents, [smoothing_sec * mp4_fps] + [0] * (len(shape) - 1),
        mode='wrap')
    all_latents /= np.sqrt(np.mean(np.square(all_latents)))

    print("Rendering...\ntruncation_psi =", truncation_psi, ", noise_mode =",
          noise_mode)
    os.makedirs(outdir, exist_ok=True)

    import imageio  # pip install imageio

    # One (zero) label row per image in the grid; a single row would not match
    # the z batch for conditional networks when the grid has more than 1 cell.
    label = torch.zeros([all_latents.shape[1], G.c_dim], device=device)
    video = imageio.get_writer(f'{outdir}/seed{random_seed:04d}.mp4',
                               mode='I',
                               fps=mp4_fps,
                               codec=mp4_codec,
                               bitrate=mp4_bitrate)
    try:
        with torch.no_grad():  # inference only; no autograd graph needed
            for frame_idx in range(num_frames):
                z = torch.from_numpy(all_latents[frame_idx]).to(device)
                img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode)
                img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8).cpu().numpy()
                grid = create_image_grid(img, grid_size)
                if image_zoom > 1:
                    grid = scipy.ndimage.zoom(grid, [image_zoom, image_zoom, 1], order=0)
                video.append_data(grid)
    finally:
        # Finalize the MP4 even if rendering raises part-way through.
        video.close()
def training_loop(
    run_dir='.',                # Output directory.
    training_set_kwargs={},     # Options for training set.
    data_loader_kwargs={},      # Options for torch.utils.data.DataLoader.
    G_kwargs={},                # Options for generator network.
    D_kwargs={},                # Options for discriminator network.
    G_opt_kwargs={},            # Options for generator optimizer.
    D_opt_kwargs={},            # Options for discriminator optimizer.
    augment_kwargs=None,        # Options for augmentation pipeline. None = disable.
    loss_kwargs={},             # Options for loss function.
    savenames=None,             # Two file-name prefixes: [0] full snapshot (G, D, G_ema, augment_pipe), [1] G_ema-only snapshot.
    random_seed=0,              # Global random seed.
    num_gpus=1,                 # Number of GPUs participating in the training.
    rank=0,                     # Rank of the current process in [0, num_gpus[.
    batch_size=4,               # Total batch size for one training iteration. Can be larger than batch_gpu * num_gpus.
    batch_gpu=4,                # Number of samples processed at a time by one GPU.
    ema_kimg=10,                # Half-life of the exponential moving average (EMA) of generator weights.
    ema_rampup=None,            # EMA ramp-up coefficient. None = no rampup.
    G_reg_interval=4,           # How often to perform regularization for G? None = disable lazy regularization.
    D_reg_interval=16,          # How often to perform regularization for D? None = disable lazy regularization.
    augment_p=0,                # Initial value of augmentation probability.
    ada_target=None,            # ADA target value. None = fixed p.
    ada_interval=4,             # How often to perform ADA adjustment?
    ada_kimg=500,               # ADA adjustment speed, measured in how many kimg it takes for p to increase/decrease by one unit.
    total_kimg=25000,           # Total length of the training, measured in thousands of real images.
    kimg_per_tick=4,            # Progress snapshot interval.
    image_snapshot_ticks=50,    # How often to save image snapshots? None = disable.
    network_snapshot_ticks=50,  # How often to save network snapshots? None = disable.
    resume_pkl=None,            # Network pickle (or directory to search) to resume training from.
    resume_kimg=0.0,            # Progress of the resumed pickle; here only printed and uploaded as metadata. NOTE(review): cur_nimg still starts at 0 — confirm intended.
    cudnn_benchmark=True,       # Enable torch.backends.cudnn.benchmark?
    abort_fn=None,              # Callback function for determining whether to abort training. Must return consistent results across ranks.
    progress_fn=None,           # Callback function for updating training progress. Called for all ranks.
    **kwargs                    # Extra options are accepted and ignored.
):
    """Run the GAN training loop for one process (``rank``) of a possibly multi-GPU job.

    The function builds the dataset, G/D/G_ema, the optional augmentation pipe, the
    optimizers and the loss, and optionally resumes from ``resume_pkl`` (rank 0 only).
    It then trains until ``total_kimg`` thousand real images have been consumed or
    ``abort_fn`` returns True.

    Once per tick it prints a status line and saves image and network snapshots to
    ``run_dir``. It appends collected stats to ``stats.jsonl`` (rank 0) and posts
    run metadata (once) and per-tick stats to a remote reporting endpoint on a
    best-effort basis (rank 0).
    """

    def _post_report(report_url, payload):
        # Best-effort telemetry: an unreachable or hanging reporting endpoint
        # must not abort or stall a long training run.
        try:
            requests.post(report_url, data=payload, timeout=30)
        except requests.RequestException as err:
            print(f'Warning: failed to report to {report_url}: {err}')

    # Initialize.
    start_time = time.time()
    device = torch.device('cuda', rank)
    np.random.seed(random_seed * num_gpus + rank)
    torch.manual_seed(random_seed * num_gpus + rank)
    torch.backends.cudnn.benchmark = cudnn_benchmark  # Improves training speed.
    conv2d_gradfix.enabled = True  # Improves training speed.
    grid_sample_gradfix.enabled = True  # Avoids errors with the augmentation pipe.
    run_id = str(int(time.time()))[3:-2]
    hostname = str(socket.gethostname())

    # Load training set.
    training_set = dnnlib.util.construct_class_by_name(**training_set_kwargs)  # subclass of training.dataset.Dataset
    training_set_sampler = misc.InfiniteSampler(dataset=training_set, rank=rank, num_replicas=num_gpus, seed=random_seed)
    training_set_iterator = iter(torch.utils.data.DataLoader(dataset=training_set, sampler=training_set_sampler,
                                                             batch_size=batch_size // num_gpus, **data_loader_kwargs))
    if rank == 0:
        print('Image shape:', training_set.image_shape)
        print('Label shape:', training_set.label_shape)
    # Snapshot image format: PNG when image_shape[0] (presumably the channel count) is 4, else JPG.
    ext = 'png' if training_set.image_shape[0] == 4 else 'jpg'  # !!!

    # Construct networks.
    if rank == 0:
        print('Constructing networks...')
    common_kwargs = dict(c_dim=training_set.label_dim, img_resolution=training_set.resolution,
                         img_channels=training_set.num_channels)
    G = dnnlib.util.construct_class_by_name(**G_kwargs, **common_kwargs).train().requires_grad_(False).to(
        device)  # subclass of torch.nn.Module
    D = dnnlib.util.construct_class_by_name(**D_kwargs, **common_kwargs).train().requires_grad_(False).to(
        device)  # subclass of torch.nn.Module
    G_ema = copy.deepcopy(G).eval()

    # Resume from existing pickle (rank 0 only; weights reach other ranks via DDP below).
    if (resume_pkl is not None) and (rank == 0):  # !!!
        if os.path.isdir(resume_pkl):
            resume_pkl, resume_kimg = locate_latest_pkl(resume_pkl)
        print(' Resuming from "%s", kimg %.3g' % (resume_pkl, resume_kimg))
        with dnnlib.util.open_url(resume_pkl) as f:
            resume_data = legacy.load_network_pkl(f)
        for name, module in [('G', G), ('D', D), ('G_ema', G_ema)]:
            misc.copy_params_and_buffers(resume_data[name], module, require_all=False)

    # Setup augmentation.
    augment_pipe = None
    ada_stats = None
    if (augment_kwargs is not None) and (augment_p > 0 or ada_target is not None):
        augment_pipe = dnnlib.util.construct_class_by_name(**augment_kwargs).train().requires_grad_(False).to(
            device)  # subclass of torch.nn.Module
        augment_pipe.p.copy_(torch.as_tensor(augment_p))
        if ada_target is not None:
            ada_stats = training_stats.Collector(regex='Loss/signs/real')

    # Distribute across GPUs.
    ddp_modules = dict()
    for name, module in [('G_mapping', G.mapping), ('G_synthesis', G.synthesis), ('D', D), (None, G_ema),
                         ('augment_pipe', augment_pipe)]:
        if (num_gpus > 1) and (module is not None) and len(list(module.parameters())) != 0:
            module.requires_grad_(True)
            module = torch.nn.parallel.DistributedDataParallel(module, device_ids=[device], broadcast_buffers=False)
            module.requires_grad_(False)
        if name is not None:
            ddp_modules[name] = module

    # Setup training phases.
    loss = dnnlib.util.construct_class_by_name(device=device, **ddp_modules, **loss_kwargs)  # subclass of training.loss.Loss
    phases = []
    for name, module, opt_kwargs, reg_interval in [('G', G, G_opt_kwargs, G_reg_interval),
                                                   ('D', D, D_opt_kwargs, D_reg_interval)]:
        if reg_interval is None:
            opt = dnnlib.util.construct_class_by_name(params=module.parameters(), **opt_kwargs)  # subclass of torch.optim.Optimizer
            phases += [dnnlib.EasyDict(name=name + 'both', module=module, opt=opt, interval=1)]
        else:  # Lazy regularization: rescale lr/betas to compensate for the extra reg steps.
            mb_ratio = reg_interval / (reg_interval + 1)
            opt_kwargs = dnnlib.EasyDict(opt_kwargs)
            opt_kwargs.lr = opt_kwargs.lr * mb_ratio
            opt_kwargs.betas = [beta ** mb_ratio for beta in opt_kwargs.betas]
            opt = dnnlib.util.construct_class_by_name(module.parameters(), **opt_kwargs)  # subclass of torch.optim.Optimizer
            phases += [dnnlib.EasyDict(name=name + 'main', module=module, opt=opt, interval=1)]
            phases += [dnnlib.EasyDict(name=name + 'reg', module=module, opt=opt, interval=reg_interval)]
    for phase in phases:
        phase.start_event = None
        phase.end_event = None
        if rank == 0:
            phase.start_event = torch.cuda.Event(enable_timing=True)
            phase.end_event = torch.cuda.Event(enable_timing=True)

    # Export sample images.
    grid_size = None
    grid_z = None
    grid_c = None
    if rank == 0:
        grid_size, images, labels = setup_snapshot_image_grid(training_set=training_set)
        save_image_grid(images, os.path.join(run_dir, '_reals.' + ext), drange=[0, 255], grid_size=grid_size)
        grid_z = torch.randn([labels.shape[0], G.z_dim], device=device).split(batch_gpu)
        grid_c = torch.from_numpy(labels).to(device).split(batch_gpu)

    # Initialize logs.
    stats_collector = training_stats.Collector(regex='.*')

    # Only upload data once, not for every process
    if rank == 0:
        url = (f'https://mnr6yzqr22jgywm-adw2.adb.eu-frankfurt-1.oraclecloudapps.com/ords/thesisproject/sg_d/'
               f'report/{run_id}/{hostname}')
        # Copy rather than alias: updating training_set_kwargs in place would leak these
        # run options into the dataset kwargs saved in every snapshot pickle below.
        metadata = dict(training_set_kwargs)
        metadata.update({
            'run_dir': run_dir,
            'data_loader_kwargs': data_loader_kwargs,
            'G_kwargs': G_kwargs,
            'D_kwargs': D_kwargs,
            'G_opt_kwargs': G_opt_kwargs,
            'D_opt_kwargs': D_opt_kwargs,
            'augment_kwargs': augment_kwargs,
            'loss_kwargs': loss_kwargs,
            'savenames': savenames,
            'random_seed': random_seed,
            'num_gpus': num_gpus,
            'batch_size': batch_size,
            'batch_gpu': batch_gpu,
            'ema_kimg': ema_kimg,
            'ema_rampup': ema_rampup,
            'G_reg_interval': G_reg_interval,
            'D_reg_interval': D_reg_interval,
            'augment_p': augment_p,
            'ada_target': ada_target,
            'ada_interval': ada_interval,
            'ada_kimg': ada_kimg,
            'total_kimg': total_kimg,
            'kimg_per_tick': kimg_per_tick,
            'image_snapshot_ticks': image_snapshot_ticks,
            'network_snapshot_ticks': network_snapshot_ticks,
            'resume_pkl': resume_pkl,
            'resume_kimg': resume_kimg,
            'cudnn_benchmark': cudnn_benchmark,
        })
        _post_report(url, json.dumps(metadata))

    stats_jsonl = None
    if rank == 0:
        stats_jsonl = open(os.path.join(run_dir, 'stats.jsonl'), 'wt')

    # Train.
    if rank == 0:
        print(f'Training for {total_kimg} kimg...')
        print()
    cur_nimg = 0
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    batch_idx = 0
    if progress_fn is not None:
        progress_fn(0, total_kimg)
    while True:

        # Fetch training data.
        with torch.autograd.profiler.record_function('data_fetch'):
            phase_real_img, phase_real_c = next(training_set_iterator)
            phase_real_img = (phase_real_img.to(device).to(torch.float32) / 127.5 - 1).split(batch_gpu)
            phase_real_c = phase_real_c.to(device).split(batch_gpu)
            all_gen_z = torch.randn([len(phases) * batch_size, G.z_dim], device=device)
            all_gen_z = [phase_gen_z.split(batch_gpu) for phase_gen_z in all_gen_z.split(batch_size)]
            all_gen_c = [training_set.get_label(np.random.randint(len(training_set))) for _ in
                         range(len(phases) * batch_size)]
            all_gen_c = torch.from_numpy(np.stack(all_gen_c)).pin_memory().to(device)
            all_gen_c = [phase_gen_c.split(batch_gpu) for phase_gen_c in all_gen_c.split(batch_size)]

        # Execute training phases.
        for phase, phase_gen_z, phase_gen_c in zip(phases, all_gen_z, all_gen_c):
            if batch_idx % phase.interval != 0:
                continue

            # Initialize gradient accumulation.
            if phase.start_event is not None:
                phase.start_event.record(torch.cuda.current_stream(device))
            phase.opt.zero_grad(set_to_none=True)
            phase.module.requires_grad_(True)

            # Accumulate gradients over multiple rounds.
            for round_idx, (real_img, real_c, gen_z, gen_c) in enumerate(
                    zip(phase_real_img, phase_real_c, phase_gen_z, phase_gen_c)):
                sync = (round_idx == batch_size // (batch_gpu * num_gpus) - 1)
                gain = phase.interval
                loss.accumulate_gradients(phase=phase.name, real_img=real_img, real_c=real_c, gen_z=gen_z,
                                          gen_c=gen_c, sync=sync, gain=gain)

            # Update weights.
            phase.module.requires_grad_(False)
            with torch.autograd.profiler.record_function(phase.name + '_opt'):
                for param in phase.module.parameters():
                    if param.grad is not None:
                        misc.nan_to_num(param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad)
                phase.opt.step()
            if phase.end_event is not None:
                phase.end_event.record(torch.cuda.current_stream(device))

        # Update G_ema.
        with torch.autograd.profiler.record_function('Gema'):
            ema_nimg = ema_kimg * 1000
            if ema_rampup is not None:
                ema_nimg = min(ema_nimg, cur_nimg * ema_rampup)
            ema_beta = 0.5 ** (batch_size / max(ema_nimg, 1e-8))
            for p_ema, p in zip(G_ema.parameters(), G.parameters()):
                p_ema.copy_(p.lerp(p_ema, ema_beta))
            for b_ema, b in zip(G_ema.buffers(), G.buffers()):
                b_ema.copy_(b)

        # Update state.
        cur_nimg += batch_size
        batch_idx += 1

        # Execute ADA heuristic.
        if (ada_stats is not None) and (batch_idx % ada_interval == 0):
            ada_stats.update()
            adjust = np.sign(ada_stats['Loss/signs/real'] - ada_target) * (batch_size * ada_interval) / (
                    ada_kimg * 1000)
            augment_pipe.p.copy_((augment_pipe.p + adjust).max(misc.constant(0, device=device)))

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if (not done) and (cur_tick != 0) and (cur_nimg < tick_start_nimg + kimg_per_tick * 1000):
            continue

        tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
        tick_end_time = time.time()
        total_time = tick_end_time - start_time
        tick_time = tick_end_time - tick_start_time
        # Estimated remaining time, extrapolated from the speed of the last tick.
        left_kimg = total_kimg - cur_nimg / 1000
        left_sec = left_kimg * tick_time / tick_kimg
        finaltime = time.asctime(time.localtime(tick_end_time + left_sec))
        msg_final = ' %ss left till %s ' % (shortime(left_sec), finaltime[11:16])  # [11:16] = "HH:MM"

        # Print status line, accumulating the same information in stats_collector.
        fields = []
        fields += [f"tick {training_stats.report0('Progress/tick', cur_tick):<4d}"]
        fields += [f"kimg {training_stats.report0('Progress/kimg', cur_nimg / 1e3):<6.1f}"]
        fields += [f"time {dnnlib.util.format_time(training_stats.report0('Timing/total_sec', total_time)):<8s}"]
        fields += [msg_final]
        fields += [f"min/tick {training_stats.report0('Timing/sec_per_tick', tick_time / 60):<6.3g}"]
        fields += [f"sec/kimg {training_stats.report0('Timing/sec_per_kimg', tick_time / tick_kimg):<7.3g}"]
        fields += [f"cpumem {training_stats.report0('Resources/cpu_mem_gb', psutil.Process(os.getpid()).memory_info().rss / 2 ** 30):<4.1f}"]
        fields += [f"gpumem {training_stats.report0('Resources/peak_gpu_mem_gb', torch.cuda.max_memory_allocated(device) / 2 ** 30):<4.1f}"]
        torch.cuda.reset_peak_memory_stats()
        fields += [f"aug {training_stats.report0('Progress/augment', float(augment_pipe.p.cpu()) if augment_pipe is not None else 0):.3f}"]
        # !!!
        fields += ["lr %.2g" % training_stats.report0('Progress/G_lrate', G_opt_kwargs.lr)]
        fields += ["%.2g" % training_stats.report0('Progress/D_lrate', D_opt_kwargs.lr)]
        training_stats.report0('Timing/total_hours', total_time / (60 * 60))
        training_stats.report0('Timing/total_days', total_time / (24 * 60 * 60))
        if rank == 0:
            print(' '.join(fields))

        # Check for abort.
        if (not done) and (abort_fn is not None) and abort_fn():
            done = True
            if rank == 0:
                print()
                print('Aborting...')

        # Save image snapshot.
        if (rank == 0) and (image_snapshot_ticks is not None) and (done or cur_tick % image_snapshot_ticks == 0):
            images = torch.cat([G_ema(z=z, c=c, noise_mode='const').cpu() for z, c in zip(grid_z, grid_c)]).numpy()
            save_image_grid(images, os.path.join(run_dir, 'fake-%04d.%s' % (cur_nimg // 1000, ext)),
                            drange=[-1, 1], grid_size=grid_size)

        # Save network snapshot: a full pickle and a compact G_ema-only pickle.
        if (network_snapshot_ticks is not None) and (done or cur_tick % network_snapshot_ticks == 0):
            snapshot_data = dict(training_set_kwargs=dict(training_set_kwargs))
            compact_data = dict(training_set_kwargs=dict(training_set_kwargs))
            for name, module in [('G', G), ('D', D), ('G_ema', G_ema), ('augment_pipe', augment_pipe)]:
                if module is not None:
                    if num_gpus > 1:
                        misc.check_ddp_consistency(module, ignore_regex=r'.*\.w_avg')
                    module = copy.deepcopy(module).eval().requires_grad_(False).cpu()
                snapshot_data[name] = module
                if name == 'G_ema':
                    compact_data[name] = module  # !!!
                del module  # conserve memory
            # !!!
            pkl_snap = os.path.join(run_dir, '%s-%04d.pkl' % (savenames[0], cur_nimg // 1000))
            pkl_run = os.path.join(run_dir, '%s-%04d.pkl' % (savenames[1], cur_nimg // 1000))
            if rank == 0:
                with open(pkl_snap, 'wb') as f:
                    pickle.dump(snapshot_data, f)
                with open(pkl_run, 'wb') as f:
                    pickle.dump(compact_data, f)
            del snapshot_data, compact_data  # conserve memory

        # Metric evaluation during training is disabled in this version.

        # Collect statistics.
        for phase in phases:
            value = []
            if (phase.start_event is not None) and (phase.end_event is not None):
                phase.end_event.synchronize()
                value = phase.start_event.elapsed_time(phase.end_event)
            training_stats.report0('Timing/' + phase.name, value)
        stats_collector.update()
        stats_dict = stats_collector.as_dict()

        # Report stats to database
        if rank == 0:
            url = (f'https://mnr6yzqr22jgywm-adw2.adb.eu-frankfurt-1.oraclecloudapps.com/ords/thesisproject/sg2/'
                   f'report/{run_id}/{hostname}')
            _post_report(url, json.dumps(stats_dict))

        # Update logs.
        timestamp = time.time()
        if stats_jsonl is not None:
            fields = dict(stats_dict, timestamp=timestamp)
            stats_jsonl.write(json.dumps(fields) + '\n')
            stats_jsonl.flush()
        if progress_fn is not None:
            progress_fn(cur_nimg // 1000, total_kimg)

        # Update state.
        cur_tick += 1
        tick_start_nimg = cur_nimg
        tick_start_time = time.time()
        if done:
            break

    # Done.
    if stats_jsonl is not None:
        stats_jsonl.close()
    if rank == 0:
        print()
        print('Exiting...')
def expand_seed(seeds, vector_size):
    """Return a list with one standard-normal array of shape (1, vector_size) per seed.

    Each array is drawn from a fresh ``np.random.RandomState(seed)``, so the
    result is fully determined by the seeds and their order.
    """
    return [np.random.RandomState(seed).randn(1, vector_size) for seed in seeds]


# Interactive latent-walk setup: ask for a network pickle, an output directory,
# a list of seeds and the number of interpolation steps.
print("\nLATENT VECTOR WALK VIDEO GENERATOR made by JUSTIN GALLAGHER")
print("DO NOT INCLUDE QUOTATIONS IN ANY FILE OR DIRECTORY PATH NAMES!\n")
pkl_file = input('What is the .pkl file path: ')
save_path = input('Save video to which directory: ')

print(f'Loading networks from "{pkl_file}"...')
device = torch.device('cuda')
with dnnlib.util.open_url(pkl_file) as f:
    G = legacy.load_network_pkl(f)['G_ema'].to(device)
vector_size = G.z_dim
seeds = expand_seed([8192 + 1, 8192 + 9], vector_size)

seed_num = int(input("How many seeds would you like to input: "))
print('Begin inputting seeds now...')
seed_list = []
for i in range(seed_num):
    seed = int(input(f'Seed #{i + 1}: '))
    seed_list.append(seed)
steps = int(input("Number of steps between seeds: "))

print('Creating temp directory for images...')
temp = os.path.join(save_path, "temp-images")
def training_loop(
    run_dir='.',  # Output directory.
    training_set_kwargs={},  # Options for training set.
    data_loader_kwargs={},  # Options for torch.utils.data.DataLoader.
    G_kwargs={},  # Options for generator network.
    D_kwargs={},  # Options for discriminator network.
    G_opt_kwargs={},  # Options for generator optimizer.
    D_opt_kwargs={},  # Options for discriminator optimizer.
    loss_kwargs={},  # Options for loss function.
    metrics=[],  # Metrics to evaluate during training.
    random_seed=0,  # Global random seed.
    num_gpus=1,  # Number of GPUs participating in the training.
    rank=0,  # Rank of the current process in [0, num_gpus[.
    batch_size=4,  # Total batch size for one training iteration. Can be larger than batch_gpu * num_gpus.
    batch_gpu=4,  # Number of samples processed at a time by one GPU.
    ema_kimg=10,  # Half-life of the exponential moving average (EMA) of generator weights.
    ema_rampup=0.05,  # EMA ramp-up coefficient. None = no rampup.
    G_reg_interval=None,  # How often to perform regularization for G? None = disable lazy regularization.
    D_reg_interval=16,  # How often to perform regularization for D? None = disable lazy regularization.
    total_kimg=25000,  # Total length of the training, measured in thousands of real images.
    kimg_per_tick=4,  # Progress snapshot interval.
    image_snapshot_ticks=50,  # How often to save image snapshots? None = disable.
    network_snapshot_ticks=50,  # How often to save network snapshots? None = disable.
    resume_pkl=None,  # Network pickle to resume training from.
    resume_kimg=0,  # First kimg to report when resuming training.
    cudnn_benchmark=True,  # Enable torch.backends.cudnn.benchmark?
    abort_fn=None,  # Callback function for determining whether to abort training. Must return consistent results across ranks.
    progress_fn=None,  # Callback function for updating training progress. Called for all ranks.
    restart_every=-1,  # Time interval in seconds to exit code
):
    """Run the GAN training loop for one process of a (possibly multi-GPU) job.

    The function builds the dataset, G, D and an EMA copy of G, and optionally
    resumes from ``resume_pkl``. It then alternates the G/D optimisation phases
    until ``total_kimg`` thousand real images have been shown or ``abort_fn``
    asks to stop. Once per tick it:

    - prints a status line,
    - writes image grids,
    - evaluates ``metrics``, and
    - logs statistics to ``stats.jsonl`` and TensorBoard.

    Restart support: when ``restart_every > 0``, rank 0 writes a checkpoint
    (networks plus the progress counters ``cur_nimg``, ``cur_tick``,
    ``batch_idx``, ``best_fid`` and, if the loss has one, ``pl_mean``) to
    ``misc.get_ckpt_path(run_dir)``. This happens at every network-snapshot
    tick and when the loop ends. The loop ends at the first tick boundary after
    ``restart_every`` seconds. If that checkpoint file exists on the next
    launch, it takes precedence over ``resume_pkl``, the counters are restored,
    and ``stats.jsonl`` is appended to instead of truncated.

    Whenever the ``fid50k_full`` metric improves, rank 0 writes the current
    snapshot to ``best_model.pkl`` (via ``dill``) and the image count to
    ``best_nimg.txt``.

    Multi-GPU: parameters and buffers are broadcast from rank 0 once, and
    gradients are all-reduced and averaged in every phase. Files written
    directly by this function are written by rank 0 only.
    """
    # Initialize.
    start_time = time.time()
    device = torch.device('cuda', rank)
    np.random.seed(random_seed * num_gpus + rank)
    torch.manual_seed(random_seed * num_gpus + rank)
    torch.backends.cudnn.benchmark = cudnn_benchmark  # Improves training speed.
    torch.backends.cuda.matmul.allow_tf32 = False  # Improves numerical accuracy.
    torch.backends.cudnn.allow_tf32 = False  # Improves numerical accuracy.
    conv2d_gradfix.enabled = True  # Improves training speed.
    grid_sample_gradfix.enabled = True  # Avoids errors with the augmentation pipe.
    # Progress state is kept in device tensors so it can be broadcast from
    # rank 0 (which is the only rank that reads the checkpoint).
    __RESTART__ = torch.tensor(0., device=device)  # will be broadcasted to exit loop
    __CUR_NIMG__ = torch.tensor(resume_kimg * 1000, dtype=torch.long, device=device)
    __CUR_TICK__ = torch.tensor(0, dtype=torch.long, device=device)
    __BATCH_IDX__ = torch.tensor(0, dtype=torch.long, device=device)
    __PL_MEAN__ = torch.zeros([], device=device)
    best_fid = 9999

    # Load training set.
    if rank == 0:
        print('Loading training set...')
    training_set = dnnlib.util.construct_class_by_name(**training_set_kwargs)  # subclass of training.dataset.Dataset
    training_set_sampler = misc.InfiniteSampler(dataset=training_set, rank=rank, num_replicas=num_gpus, seed=random_seed)
    training_set_iterator = iter(
        torch.utils.data.DataLoader(dataset=training_set, sampler=training_set_sampler,
                                    batch_size=batch_size // num_gpus, **data_loader_kwargs))
    if rank == 0:
        print()
        print('Num images: ', len(training_set))
        print('Image shape:', training_set.image_shape)
        print('Label shape:', training_set.label_shape)
        print()

    # Construct networks.
    if rank == 0:
        print('Constructing networks...')
    common_kwargs = dict(c_dim=training_set.label_dim, img_resolution=training_set.resolution,
                         img_channels=training_set.num_channels)
    G = dnnlib.util.construct_class_by_name(**G_kwargs, **common_kwargs).train().requires_grad_(False).to(device)  # subclass of torch.nn.Module
    D = dnnlib.util.construct_class_by_name(**D_kwargs, **common_kwargs).train().requires_grad_(False).to(device)  # subclass of torch.nn.Module
    G_ema = copy.deepcopy(G).eval()

    # Check for existing checkpoint
    ckpt_pkl = None
    if restart_every > 0 and os.path.isfile(misc.get_ckpt_path(run_dir)):
        # A restart checkpoint overrides any user-supplied resume_pkl.
        ckpt_pkl = resume_pkl = misc.get_ckpt_path(run_dir)

    # Resume from existing pickle.
    if (resume_pkl is not None) and (rank == 0):
        print(f'Resuming from "{resume_pkl}"')
        with dnnlib.util.open_url(resume_pkl) as f:
            resume_data = legacy.load_network_pkl(f)
        for name, module in [('G', G), ('D', D), ('G_ema', G_ema)]:
            misc.copy_params_and_buffers(resume_data[name], module, require_all=False)

        if ckpt_pkl is not None:  # Load ticks
            __CUR_NIMG__ = resume_data['progress']['cur_nimg'].to(device)
            __CUR_TICK__ = resume_data['progress']['cur_tick'].to(device)
            __BATCH_IDX__ = resume_data['progress']['batch_idx'].to(device)
            __PL_MEAN__ = resume_data['progress'].get('pl_mean', torch.zeros([])).to(device)
            best_fid = resume_data['progress']['best_fid']  # only needed for rank == 0

        del resume_data

    # Print network summary tables.
    if rank == 0:
        z = torch.empty([batch_gpu, G.z_dim], device=device)
        c = torch.empty([batch_gpu, G.c_dim], device=device)
        img = misc.print_module_summary(G, [z, c])
        misc.print_module_summary(D, [img, c])

    # Distribute across GPUs.
    if rank == 0:
        print(f'Distributing across {num_gpus} GPUs...')
    for module in [G, D, G_ema]:
        if module is not None and num_gpus > 1:
            for param in misc.params_and_buffers(module):
                torch.distributed.broadcast(param, src=0)

    # Setup training phases.
    if rank == 0:
        print('Setting up training phases...')
    loss = dnnlib.util.construct_class_by_name(device=device, G=G, G_ema=G_ema, D=D, **loss_kwargs)  # subclass of training.loss.Loss
    phases = []
    for name, module, opt_kwargs, reg_interval in [('G', G, G_opt_kwargs, G_reg_interval),
                                                   ('D', D, D_opt_kwargs, D_reg_interval)]:
        if reg_interval is None:
            opt = dnnlib.util.construct_class_by_name(params=module.parameters(), **opt_kwargs)  # subclass of torch.optim.Optimizer
            phases += [dnnlib.EasyDict(name=name + 'both', module=module, opt=opt, interval=1)]
        else:  # Lazy regularization.
            # Rescale lr/betas so the optimiser behaves as if the reg phase
            # ran every iteration.
            mb_ratio = reg_interval / (reg_interval + 1)
            opt_kwargs = dnnlib.EasyDict(opt_kwargs)
            opt_kwargs.lr = opt_kwargs.lr * mb_ratio
            opt_kwargs.betas = [beta**mb_ratio for beta in opt_kwargs.betas]
            opt = dnnlib.util.construct_class_by_name(module.parameters(), **opt_kwargs)  # subclass of torch.optim.Optimizer
            phases += [dnnlib.EasyDict(name=name + 'main', module=module, opt=opt, interval=1)]
            phases += [dnnlib.EasyDict(name=name + 'reg', module=module, opt=opt, interval=reg_interval)]
    for phase in phases:
        phase.start_event = None
        phase.end_event = None
        if rank == 0:
            phase.start_event = torch.cuda.Event(enable_timing=True)
            phase.end_event = torch.cuda.Event(enable_timing=True)

    # Export sample images.
    grid_size = None
    grid_z = None
    grid_c = None
    if rank == 0:
        print('Exporting sample images...')
        grid_size, images, labels = setup_snapshot_image_grid(training_set=training_set)
        save_image_grid(images, os.path.join(run_dir, 'reals.png'), drange=[0, 255], grid_size=grid_size)
        grid_z = torch.randn([labels.shape[0], G.z_dim], device=device).split(batch_gpu)
        grid_c = torch.from_numpy(labels).to(device).split(batch_gpu)
        images = torch.cat([G_ema(z=z, c=c, noise_mode='const').cpu() for z, c in zip(grid_z, grid_c)]).numpy()
        save_image_grid(images, os.path.join(run_dir, 'fakes_init.png'), drange=[-1, 1], grid_size=grid_size)

    # Initialize logs.
    if rank == 0:
        print('Initializing logs...')
    stats_collector = training_stats.Collector(regex='.*')
    stats_metrics = dict()
    stats_jsonl = None
    stats_tfevents = None
    if rank == 0:
        # When continuing from a restart checkpoint, append so the statistics
        # logged before the restart are kept instead of truncated.
        stats_mode = 'at' if ckpt_pkl is not None else 'wt'
        stats_jsonl = open(os.path.join(run_dir, 'stats.jsonl'), stats_mode)
        try:
            import torch.utils.tensorboard as tensorboard
            stats_tfevents = tensorboard.SummaryWriter(run_dir)
        except ImportError as err:
            print('Skipping tfevents export:', err)

    # Train.
    if rank == 0:
        print(f'Training for {total_kimg} kimg...')
        print()
    if num_gpus > 1:  # broadcast loaded states to all
        torch.distributed.broadcast(__CUR_NIMG__, 0)
        torch.distributed.broadcast(__CUR_TICK__, 0)
        torch.distributed.broadcast(__BATCH_IDX__, 0)
        torch.distributed.broadcast(__PL_MEAN__, 0)
        torch.distributed.barrier()  # ensure all processes received this info
    cur_nimg = __CUR_NIMG__.item()
    cur_tick = __CUR_TICK__.item()
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    maintenance_time = tick_start_time - start_time
    batch_idx = __BATCH_IDX__.item()
    if progress_fn is not None:
        progress_fn(cur_nimg // 1000, total_kimg)
    if hasattr(loss, 'pl_mean'):
        loss.pl_mean.copy_(__PL_MEAN__)

    while True:
        with torch.autograd.profiler.record_function('data_fetch'):
            phase_real_img, phase_real_c = next(training_set_iterator)
            # uint8 pixels [0, 255] -> float [-1, 1], split into per-GPU chunks.
            phase_real_img = (phase_real_img.to(device).to(torch.float32) / 127.5 - 1).split(batch_gpu)
            phase_real_c = phase_real_c.to(device).split(batch_gpu)
            all_gen_z = torch.randn([len(phases) * batch_size, G.z_dim], device=device)
            all_gen_z = [phase_gen_z.split(batch_gpu) for phase_gen_z in all_gen_z.split(batch_size)]
            all_gen_c = [training_set.get_label(np.random.randint(len(training_set)))
                         for _ in range(len(phases) * batch_size)]
            all_gen_c = torch.from_numpy(np.stack(all_gen_c)).pin_memory().to(device)
            all_gen_c = [phase_gen_c.split(batch_gpu) for phase_gen_c in all_gen_c.split(batch_size)]

        # Execute training phases.
        for phase, phase_gen_z, phase_gen_c in zip(phases, all_gen_z, all_gen_c):
            if batch_idx % phase.interval != 0:
                continue
            if phase.start_event is not None:
                phase.start_event.record(torch.cuda.current_stream(device))

            # Accumulate gradients.
            phase.opt.zero_grad(set_to_none=True)
            phase.module.requires_grad_(True)

            # The discriminator's feature network is kept frozen.
            if phase.name in ['Dmain', 'Dboth', 'Dreg']:
                phase.module.feature_network.requires_grad_(False)

            for real_img, real_c, gen_z, gen_c in zip(phase_real_img, phase_real_c, phase_gen_z, phase_gen_c):
                loss.accumulate_gradients(phase=phase.name, real_img=real_img, real_c=real_c, gen_z=gen_z,
                                          gen_c=gen_c, gain=phase.interval, cur_nimg=cur_nimg)
            phase.module.requires_grad_(False)

            # Update weights.
            with torch.autograd.profiler.record_function(phase.name + '_opt'):
                params = [param for param in phase.module.parameters() if param.grad is not None]
                if len(params) > 0:
                    # One flat buffer -> a single all_reduce for all gradients.
                    flat = torch.cat([param.grad.flatten() for param in params])
                    if num_gpus > 1:
                        torch.distributed.all_reduce(flat)
                        flat /= num_gpus
                    misc.nan_to_num(flat, nan=0, posinf=1e5, neginf=-1e5, out=flat)
                    grads = flat.split([param.numel() for param in params])
                    for param, grad in zip(params, grads):
                        param.grad = grad.reshape(param.shape)
                phase.opt.step()

            # Phase done.
            if phase.end_event is not None:
                phase.end_event.record(torch.cuda.current_stream(device))

        # Update G_ema.
        with torch.autograd.profiler.record_function('Gema'):
            ema_nimg = ema_kimg * 1000
            if ema_rampup is not None:
                ema_nimg = min(ema_nimg, cur_nimg * ema_rampup)
            ema_beta = 0.5**(batch_size / max(ema_nimg, 1e-8))
            for p_ema, p in zip(G_ema.parameters(), G.parameters()):
                p_ema.copy_(p.lerp(p_ema, ema_beta))
            for b_ema, b in zip(G_ema.buffers(), G.buffers()):
                b_ema.copy_(b)

        # Update state.
        cur_nimg += batch_size
        batch_idx += 1

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if (not done) and (cur_tick != 0) and (cur_nimg < tick_start_nimg + kimg_per_tick * 1000):
            continue

        # Print status line, accumulating the same information in training_stats.
        tick_end_time = time.time()
        fields = []
        fields += [f"tick {training_stats.report0('Progress/tick', cur_tick):<5d}"]
        fields += [f"kimg {training_stats.report0('Progress/kimg', cur_nimg / 1e3):<8.1f}"]
        fields += [f"time {dnnlib.util.format_time(training_stats.report0('Timing/total_sec', tick_end_time - start_time)):<12s}"]
        fields += [f"sec/tick {training_stats.report0('Timing/sec_per_tick', tick_end_time - tick_start_time):<7.1f}"]
        fields += [f"sec/kimg {training_stats.report0('Timing/sec_per_kimg', (tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg) * 1e3):<7.2f}"]
        fields += [f"maintenance {training_stats.report0('Timing/maintenance_sec', maintenance_time):<6.1f}"]
        fields += [f"cpumem {training_stats.report0('Resources/cpu_mem_gb', psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}"]
        fields += [f"gpumem {training_stats.report0('Resources/peak_gpu_mem_gb', torch.cuda.max_memory_allocated(device) / 2**30):<6.2f}"]
        fields += [f"reserved {training_stats.report0('Resources/peak_gpu_mem_reserved_gb', torch.cuda.max_memory_reserved(device) / 2**30):<6.2f}"]
        torch.cuda.reset_peak_memory_stats()
        training_stats.report0('Timing/total_hours', (tick_end_time - start_time) / (60 * 60))
        training_stats.report0('Timing/total_days', (tick_end_time - start_time) / (24 * 60 * 60))
        if rank == 0:
            print(' '.join(fields))

        # Check for abort.
        if (not done) and (abort_fn is not None) and abort_fn():
            done = True
            if rank == 0:
                print()
                print('Aborting...')

        # Check for restart.
        # Only rank 0 measures the elapsed time; the decision is broadcast so
        # that all ranks leave the loop on the same tick.
        if (rank == 0) and (restart_every > 0) and (time.time() - start_time > restart_every):
            print('Restart job...')
            __RESTART__ = torch.tensor(1., device=device)
        if num_gpus > 1:
            torch.distributed.broadcast(__RESTART__, 0)
        if __RESTART__:
            done = True
            print(f'Process {rank} leaving...')
            if num_gpus > 1:
                torch.distributed.barrier()

        # Save image snapshot.
        if (rank == 0) and (image_snapshot_ticks is not None) and (done or cur_tick % image_snapshot_ticks == 0):
            images = torch.cat([G_ema(z=z, c=c, noise_mode='const').cpu() for z, c in zip(grid_z, grid_c)]).numpy()
            save_image_grid(images, os.path.join(run_dir, f'fakes{cur_nimg//1000:06d}.png'), drange=[-1, 1],
                            grid_size=grid_size)

        # Save network snapshot.
        snapshot_pkl = None
        snapshot_data = None
        if (network_snapshot_ticks is not None) and (done or cur_tick % network_snapshot_ticks == 0):
            snapshot_data = dict(G=G, D=D, G_ema=G_ema, training_set_kwargs=dict(training_set_kwargs))
            for key, value in snapshot_data.items():
                if isinstance(value, torch.nn.Module):
                    snapshot_data[key] = value
                del value  # conserve memory

        # Save Checkpoint if needed
        if (rank == 0) and (restart_every > 0) and (network_snapshot_ticks is not None) and (
                done or cur_tick % network_snapshot_ticks == 0):
            snapshot_pkl = misc.get_ckpt_path(run_dir)
            # save as tensors to avoid error for multi GPU
            snapshot_data['progress'] = {
                'cur_nimg': torch.LongTensor([cur_nimg]),
                'cur_tick': torch.LongTensor([cur_tick]),
                'batch_idx': torch.LongTensor([batch_idx]),
                'best_fid': best_fid,
            }
            if hasattr(loss, 'pl_mean'):
                snapshot_data['progress']['pl_mean'] = loss.pl_mean.cpu()

            with open(snapshot_pkl, 'wb') as f:
                pickle.dump(snapshot_data, f)

        # Evaluate metrics.
        # if (snapshot_data is not None) and (len(metrics) > 0):
        if cur_tick and (snapshot_data is not None) and (len(metrics) > 0):
            if rank == 0:
                print('Evaluating metrics...')
            for metric in metrics:
                result_dict = metric_main.calc_metric(metric=metric, G=snapshot_data['G_ema'], run_dir=run_dir,
                                                      cur_nimg=cur_nimg, dataset_kwargs=training_set_kwargs,
                                                      num_gpus=num_gpus, rank=rank, device=device)
                if rank == 0:
                    metric_main.report_metric(result_dict, run_dir=run_dir, snapshot_pkl=snapshot_pkl)
                stats_metrics.update(result_dict.results)

            # save best fid ckpt
            snapshot_pkl = os.path.join(run_dir, f'best_model.pkl')
            cur_nimg_txt = os.path.join(run_dir, f'best_nimg.txt')
            if rank == 0:
                if 'fid50k_full' in stats_metrics and stats_metrics['fid50k_full'] < best_fid:
                    best_fid = stats_metrics['fid50k_full']

                    with open(snapshot_pkl, 'wb') as f:
                        dill.dump(snapshot_data, f)
                    # save curr iteration number (directly saving it to pkl leads to problems with multi GPU)
                    with open(cur_nimg_txt, 'w') as f:
                        f.write(str(cur_nimg))

        del snapshot_data  # conserve memory

        # Collect statistics.
        for phase in phases:
            value = []
            if (phase.start_event is not None) and (phase.end_event is not None) and \
                    not (phase.start_event.cuda_event == 0 and phase.end_event.cuda_event == 0):
                # Both events were not initialized yet, can happen with restart
                phase.end_event.synchronize()
                value = phase.start_event.elapsed_time(phase.end_event)
            training_stats.report0('Timing/' + phase.name, value)
        stats_collector.update()
        stats_dict = stats_collector.as_dict()

        # Update logs.
        timestamp = time.time()
        if stats_jsonl is not None:
            fields = dict(stats_dict, timestamp=timestamp)
            stats_jsonl.write(json.dumps(fields) + '\n')
            stats_jsonl.flush()
        if stats_tfevents is not None:
            global_step = int(cur_nimg / 1e3)
            walltime = timestamp - start_time
            for name, value in stats_dict.items():
                stats_tfevents.add_scalar(name, value.mean, global_step=global_step, walltime=walltime)
            for name, value in stats_metrics.items():
                stats_tfevents.add_scalar(f'Metrics/{name}', value, global_step=global_step, walltime=walltime)
            stats_tfevents.flush()
        if progress_fn is not None:
            progress_fn(cur_nimg // 1000, total_kimg)

        # Update state.
        cur_tick += 1
        tick_start_nimg = cur_nimg
        tick_start_time = time.time()
        maintenance_time = tick_start_time - tick_end_time
        if done:
            break

    # Done.
    if rank == 0:
        print()
        print('Exiting...')

    # Release the log handles opened above (rank 0 only; None elsewhere).
    if stats_jsonl is not None:
        stats_jsonl.close()
    if stats_tfevents is not None:
        stats_tfevents.close()
def generate_style_mix(
    network_pkl: str,
    row_seeds: List[int],
    col_seeds: List[int],
    col_styles: List[int],
    truncation_psi: float,
    noise_mode: str,
    outdir: str,
):
    """Generate images using pretrained network pickle.

    Every seed in ``row_seeds`` and ``col_seeds`` is turned into a W latent by
    ``G.mapping`` and truncated towards ``G.mapping.w_avg`` by
    ``truncation_psi``. For each (row, col) pair, the entries of the row latent
    indexed by ``col_styles`` (along its first axis, presumably one per
    synthesis layer) are replaced by the column latent's entries, and the
    result is rendered.

    Outputs written to ``outdir``:

    - ``{row}-{col}.png`` for each mixed image; the unmixed image of a seed is
      ``{seed}-{seed}.png``.
    - ``grid.png``: the first row shows the column seeds, the first column
      shows the row seeds, the top-left cell stays black, and each remaining
      cell holds the corresponding mixed image.

    Examples:

    \b
    python style_mixing.py --outdir=out --rows=85,100,75,458,1500 --cols=55,821,1789,293 \\
        --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
    """
    print('Loading networks from "%s"...' % network_pkl)
    device = torch.device("cuda")
    with dnnlib.util.open_url(network_pkl) as f:
        G = legacy.load_network_pkl(f)["G_ema"].to(device)  # type: ignore

    os.makedirs(outdir, exist_ok=True)

    # Inference only: without no_grad every forward pass would record an
    # autograd graph and keep its activations alive, raising peak GPU memory.
    with torch.no_grad():
        print("Generating W vectors...")
        all_seeds = list(set(row_seeds + col_seeds))
        all_z = np.stack([np.random.RandomState(seed).randn(G.z_dim) for seed in all_seeds])
        all_w = G.mapping(torch.from_numpy(all_z).to(device), None)
        # Truncation: interpolate each latent towards G.mapping.w_avg.
        w_avg = G.mapping.w_avg
        all_w = w_avg + (all_w - w_avg) * truncation_psi
        w_dict = {seed: w for seed, w in zip(all_seeds, list(all_w))}

        print("Generating images...")
        all_images = G.synthesis(all_w, noise_mode=noise_mode)
        # NCHW float (presumably in [-1, 1]) -> NHWC uint8 in [0, 255].
        all_images = (all_images.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).cpu().numpy()
        image_dict = {(seed, seed): image for seed, image in zip(all_seeds, list(all_images))}

        print("Generating style-mixed images...")
        for row_seed in row_seeds:
            for col_seed in col_seeds:
                w = w_dict[row_seed].clone()
                w[col_styles] = w_dict[col_seed][col_styles]
                image = G.synthesis(w[np.newaxis], noise_mode=noise_mode)
                image = (image.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
                image_dict[(row_seed, col_seed)] = image[0].cpu().numpy()

    print("Saving images...")
    for (row_seed, col_seed), image in image_dict.items():
        PIL.Image.fromarray(image, "RGB").save(f"{outdir}/{row_seed}-{col_seed}.png")

    print("Saving image grid...")
    W = G.img_resolution
    H = G.img_resolution
    canvas = PIL.Image.new("RGB", (W * (len(col_seeds) + 1), H * (len(row_seeds) + 1)), "black")
    # Index 0 on either axis is the header row/column; its seed value (0) is a
    # placeholder that is never used as a lookup key.
    for row_idx, row_seed in enumerate([0] + row_seeds):
        for col_idx, col_seed in enumerate([0] + col_seeds):
            if row_idx == 0 and col_idx == 0:
                continue
            key = (row_seed, col_seed)
            if row_idx == 0:
                key = (col_seed, col_seed)
            if col_idx == 0:
                key = (row_seed, row_seed)
            canvas.paste(PIL.Image.fromarray(image_dict[key], "RGB"), (W * col_idx, H * row_idx))
    canvas.save(f"{outdir}/grid.png")
def main():
    """Render a latent-walk loop along every saved direction vector.

    Options come from the module-level ``a`` namespace. The function:

    1. Loads the generator named by ``a.model`` into the global ``Gs``.
    2. Stacks every ``.npy`` direction found in ``a.vector_dir``.
    3. Picks a base latent: ``a.base_lat`` if given, otherwise a random one.
    4. Calls ``make_loop`` once per direction, sweeping it over ``lrange``
       with ``a.fstep * 2`` frames.

    Side effects:

    - sets the globals ``Gs`` and ``custom``, and sets ``use_d`` to True when
      the directions have more than one row per frame;
    - creates ``a.out_dir``;
    - exits the process when no direction vectors are available.
    """
    import sys  # sys.exit(): the bare exit() builtin is only injected by `site`

    global Gs, use_d, custom

    if a.vector_dir is not None:
        # Drop a single trailing path separator.
        if a.vector_dir.endswith(('/', '\\')):
            a.vector_dir = a.vector_dir[:-1]
    os.makedirs(a.out_dir, exist_ok=True)
    device = torch.device('cuda')

    # setup generator
    Gs_kwargs = dnnlib.EasyDict()
    Gs_kwargs.verbose = a.verbose
    Gs_kwargs.size = a.size
    Gs_kwargs.scale_type = a.scale_type

    # load base or custom network
    pkl_name = osp.splitext(a.model)[0]
    if '.pkl' in a.model.lower():
        custom = False
        print(' .. Gs from pkl ..', basename(a.model))
    else:
        custom = True
        print(' .. Gs custom ..', basename(a.model))
    with dnnlib.util.open_url(pkl_name + '.pkl') as f:
        Gs = legacy.load_network_pkl(f, custom=custom, **Gs_kwargs)['G_ema'].to(device)  # type: ignore

    # load directions
    directions = []
    if a.vector_dir is not None:
        for v in file_list(a.vector_dir, 'npy'):
            direction = load_latents(v)
            if len(direction.shape) == 2:
                direction = np.expand_dims(direction, 0)
            directions.append(direction)
    # Covers both "no vector_dir given" and "vector_dir contains no .npy files"
    # (the latter previously crashed inside np.concatenate).
    if not directions:
        print(' No vectors found')
        sys.exit()
    directions = np.concatenate(directions)[:, np.newaxis]  # [frm,1,18,512]

    # More than one row per frame (axis 2 of the stacked array) -> the
    # directions are combined with a mapped base latent (see use_d below).
    if directions.ndim >= 4 and directions.shape[2] > 1:
        use_d = True
    print(' directions', directions.shape, 'using d' if use_d else 'using w')
    directions = torch.from_numpy(directions).to(device)

    # latent direction range
    lrange = [-0.5, 0.5]

    # load saved latents
    if a.base_lat is not None:
        base_latent = load_latents(a.base_lat)
        base_latent = torch.from_numpy(base_latent).to(device)
    else:
        print(' No NPY input given, making random')
        # Match the loaded-latent branch (a torch tensor on `device`); a raw
        # numpy array cannot be combined with the torch `directions` or passed
        # to Gs.mapping.
        base_latent = torch.from_numpy(np.random.randn(1, Gs.z_dim)).float().to(device)
    if use_d:
        base_latent = Gs.mapping(base_latent, None)  # [frm,18,512]

    pbar = ProgressBar(len(directions))
    for i, direction in enumerate(directions):
        make_loop(base_latent, direction, lrange, a.fstep * 2, a.fstep * 2 * i)
        pbar.upd()
def training_loop( run_dir='.', # Output directory. training_set_kwargs={}, # Options for training set. data_loader_kwargs={}, # Options for torch.utils.data.DataLoader. G_kwargs={}, # Options for generator network. D_kwargs={}, # Options for discriminator network. D2_kwargs={}, # Options for discriminator network. G_opt_kwargs={}, # Options for generator optimizer. D_opt_kwargs={}, # Options for discriminator optimizer. augment_kwargs=None, # Options for augmentation pipeline. None = disable. loss_kwargs={}, # Options for loss function. metrics=[], # Metrics to evaluate during training. random_seed=0, # Global random seed. num_gpus=1, # Number of GPUs participating in the training. rank=0, # Rank of the current process in [0, num_gpus[. batch_size=4, # Total batch size for one training iteration. Can be larger than batch_gpu * num_gpus. batch_gpu=4, # Number of samples processed at a time by one GPU. ema_kimg=10, # Half-life of the exponential moving average (EMA) of generator weights. ema_rampup=None, # EMA ramp-up coefficient. G_reg_interval=4, # How often to perform regularization for G? None = disable lazy regularization. D_reg_interval=16, # How often to perform regularization for D? None = disable lazy regularization. augment_p=0, # Initial value of augmentation probability. ada_target=None, # ADA target value. None = fixed p. ada_interval=4, # How often to perform ADA adjustment? ada_kimg=500, # ADA adjustment speed, measured in how many kimg it takes for p to increase/decrease by one unit. total_kimg=25000, # Total length of the training, measured in thousands of real images. kimg_per_tick=4, # Progress snapshot interval. image_snapshot_ticks=50, # How often to save image snapshots? None = disable. network_snapshot_ticks=50, # How often to save network snapshots? None = disable. resume_pkl=None, # Network pickle to resume training from. cudnn_benchmark=True, # Enable torch.backends.cudnn.benchmark? 
abort_fn=None, # Callback function for determining whether to abort training. Must return consistent results across ranks. progress_fn=None, # Callback function for updating training progress. Called for all ranks. obake=None, # Obake training: <bool>, default = False ): # Initialize. start_time = time.time() device = torch.device('cuda', rank) np.random.seed(random_seed * num_gpus + rank) torch.manual_seed(random_seed * num_gpus + rank) torch.backends.cudnn.benchmark = cudnn_benchmark # Improves training speed. conv2d_gradfix.enabled = True # Improves training speed. grid_sample_gradfix.enabled = True # Avoids errors with the augmentation pipe. # Load training set. if rank == 0: print('Loading training set...') training_set = dnnlib.util.construct_class_by_name( **training_set_kwargs) # subclass of training.dataset.Dataset training_set_sampler = misc.InfiniteSampler(dataset=training_set, rank=rank, num_replicas=num_gpus, seed=random_seed) training_set_iterator = iter( torch.utils.data.DataLoader(dataset=training_set, sampler=training_set_sampler, batch_size=batch_size // num_gpus, **data_loader_kwargs)) if rank == 0: print() print('Num images: ', len(training_set)) print('Image shape:', training_set.image_shape) print('Label shape:', training_set.label_shape) print() # Construct networks. 
if rank == 0: print('Constructing networks...') common_kwargs = dict(c_dim=training_set.label_dim, img_resolution=training_set.resolution, img_channels=training_set.num_channels) G = dnnlib.util.construct_class_by_name( **G_kwargs, **common_kwargs).train().requires_grad_(False).to( device) # subclass of torch.nn.Module D = dnnlib.util.construct_class_by_name( **D_kwargs, **common_kwargs).train().requires_grad_(False).to( device) # subclass of torch.nn.Module G_ema = copy.deepcopy(G).eval() if obake is not None: D_mtcnn = MTCNN(image_size=D2_kwargs.mtcnn_output_size, margin=D2_kwargs.mtcnn_output_margin, thresholds=D2_kwargs.mtcnn_thresholds) D_face = InceptionResnetV1(pretrained=D2_kwargs.resnet_type).eval() # Resume from existing pickle. if (resume_pkl is not None) and (rank == 0): print(f'Resuming from "{resume_pkl}"') with dnnlib.util.open_url(resume_pkl) as f: resume_data = legacy.load_network_pkl(f) for name, module in [('G', G), ('D', D), ('G_ema', G_ema)]: misc.copy_params_and_buffers(resume_data[name], module, require_all=False) # Print network summary tables. if rank == 0: z = torch.empty([batch_gpu, G.z_dim], device=device) c = torch.empty([batch_gpu, G.c_dim], device=device) img = misc.print_module_summary(G, [z, c]) misc.print_module_summary(D, [img, c]) # Setup augmentation. if rank == 0: print('Setting up augmentation...') augment_pipe = None ada_stats = None if (augment_kwargs is not None) and (augment_p > 0 or ada_target is not None): augment_pipe = dnnlib.util.construct_class_by_name( **augment_kwargs).train().requires_grad_(False).to( device) # subclass of torch.nn.Module augment_pipe.p.copy_(torch.as_tensor(augment_p)) if ada_target is not None: ada_stats = training_stats.Collector(regex='Loss/signs/real') # Distribute across GPUs. 
if rank == 0: print(f'Distributing across {num_gpus} GPUs...') ddp_modules = dict() if obake is not None: for name, module in [('G_mapping', G.mapping), ('G_synthesis', G.synthesis), ('D', D), ('D_mtcnn', D_mtcnn), ('D_face', D_face), (None, G_ema), ('augment_pipe', augment_pipe)]: if (num_gpus > 1) and (module is not None) and len( list(module.parameters())) != 0: module.requires_grad_(True) module = torch.nn.parallel.DistributedDataParallel( module, device_ids=[device], broadcast_buffers=False) module.requires_grad_(False) if name is not None: ddp_modules[name] = module else: for name, module in [('G_mapping', G.mapping), ('G_synthesis', G.synthesis), ('D', D), (None, G_ema), ('augment_pipe', augment_pipe)]: if (num_gpus > 1) and (module is not None) and len( list(module.parameters())) != 0: module.requires_grad_(True) module = torch.nn.parallel.DistributedDataParallel( module, device_ids=[device], broadcast_buffers=False) module.requires_grad_(False) if name is not None: ddp_modules[name] = module # Setup training phases. if rank == 0: print('Setting up training phases...') loss = dnnlib.util.construct_class_by_name( device=device, **ddp_modules, **loss_kwargs) # subclass of training.loss.Loss phases = [] for name, module, opt_kwargs, reg_interval in [ ('G', G, G_opt_kwargs, G_reg_interval), ('D', D, D_opt_kwargs, D_reg_interval) ]: if reg_interval is None: opt = dnnlib.util.construct_class_by_name( params=module.parameters(), **opt_kwargs) # subclass of torch.optim.Optimizer phases += [ dnnlib.EasyDict(name=name + 'both', module=module, opt=opt, interval=1) ] else: # Lazy regularization. 
mb_ratio = reg_interval / (reg_interval + 1) opt_kwargs = dnnlib.EasyDict(opt_kwargs) opt_kwargs.lr = opt_kwargs.lr * mb_ratio opt_kwargs.betas = [beta**mb_ratio for beta in opt_kwargs.betas] opt = dnnlib.util.construct_class_by_name( module.parameters(), **opt_kwargs) # subclass of torch.optim.Optimizer phases += [ dnnlib.EasyDict(name=name + 'main', module=module, opt=opt, interval=1) ] phases += [ dnnlib.EasyDict(name=name + 'reg', module=module, opt=opt, interval=reg_interval) ] for phase in phases: phase.start_event = None phase.end_event = None if rank == 0: phase.start_event = torch.cuda.Event(enable_timing=True) phase.end_event = torch.cuda.Event(enable_timing=True) # Export sample images. grid_size = None grid_z = None grid_c = None if rank == 0: print('Exporting sample images...') grid_size, images, labels = setup_snapshot_image_grid( training_set=training_set) save_image_grid(images, os.path.join(run_dir, 'reals.png'), drange=[0, 255], grid_size=grid_size) grid_z = torch.randn([labels.shape[0], G.z_dim], device=device).split(batch_gpu) grid_c = torch.from_numpy(labels).to(device).split(batch_gpu) images = torch.cat([ G_ema(z=z, c=c, noise_mode='const').cpu() for z, c in zip(grid_z, grid_c) ]).numpy() save_image_grid(images, os.path.join(run_dir, 'fakes_init.png'), drange=[-1, 1], grid_size=grid_size) # Initialize logs. if rank == 0: print('Initializing logs...') stats_collector = training_stats.Collector(regex='.*') stats_metrics = dict() stats_jsonl = None stats_tfevents = None if rank == 0: stats_jsonl = open(os.path.join(run_dir, 'stats.jsonl'), 'wt') try: import torch.utils.tensorboard as tensorboard stats_tfevents = tensorboard.SummaryWriter(run_dir) except ImportError as err: print('Skipping tfevents export:', err) # Train. 
if rank == 0: print(f'Training for {total_kimg} kimg...') print() cur_nimg = 0 cur_tick = 0 tick_start_nimg = cur_nimg tick_start_time = time.time() maintenance_time = tick_start_time - start_time batch_idx = 0 if progress_fn is not None: progress_fn(0, total_kimg) while True: # Fetch training data. with torch.autograd.profiler.record_function('data_fetch'): phase_real_img, phase_real_c = next(training_set_iterator) phase_real_img = ( phase_real_img.to(device).to(torch.float32) / 127.5 - 1).split(batch_gpu) phase_real_c = phase_real_c.to(device).split(batch_gpu) all_gen_z = torch.randn([len(phases) * batch_size, G.z_dim], device=device) all_gen_z = [ phase_gen_z.split(batch_gpu) for phase_gen_z in all_gen_z.split(batch_size) ] all_gen_c = [ training_set.get_label(np.random.randint(len(training_set))) for _ in range(len(phases) * batch_size) ] all_gen_c = torch.from_numpy( np.stack(all_gen_c)).pin_memory().to(device) all_gen_c = [ phase_gen_c.split(batch_gpu) for phase_gen_c in all_gen_c.split(batch_size) ] # Execute training phases. for phase, phase_gen_z, phase_gen_c in zip(phases, all_gen_z, all_gen_c): if batch_idx % phase.interval != 0: continue # Initialize gradient accumulation. if phase.start_event is not None: phase.start_event.record(torch.cuda.current_stream(device)) phase.opt.zero_grad(set_to_none=True) phase.module.requires_grad_(True) # Accumulate gradients over multiple rounds. for round_idx, (real_img, real_c, gen_z, gen_c) in enumerate( zip(phase_real_img, phase_real_c, phase_gen_z, phase_gen_c)): sync = (round_idx == batch_size // (batch_gpu * num_gpus) - 1) gain = phase.interval print(phase.name) loss.accumulate_gradients(phase=phase.name, real_img=real_img, real_c=real_c, gen_z=gen_z, gen_c=gen_c, sync=sync, gain=gain) # Update weights. 
phase.module.requires_grad_(False) with torch.autograd.profiler.record_function(phase.name + '_opt'): for param in phase.module.parameters(): if param.grad is not None: misc.nan_to_num(param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad) phase.opt.step() if phase.end_event is not None: phase.end_event.record(torch.cuda.current_stream(device)) # Update G_ema. with torch.autograd.profiler.record_function('Gema'): ema_nimg = ema_kimg * 1000 if ema_rampup is not None: ema_nimg = min(ema_nimg, cur_nimg * ema_rampup) ema_beta = 0.5**(batch_size / max(ema_nimg, 1e-8)) for p_ema, p in zip(G_ema.parameters(), G.parameters()): p_ema.copy_(p.lerp(p_ema, ema_beta)) for b_ema, b in zip(G_ema.buffers(), G.buffers()): b_ema.copy_(b) # Update state. cur_nimg += batch_size batch_idx += 1 # Execute ADA heuristic. if (ada_stats is not None) and (batch_idx % ada_interval == 0): ada_stats.update() adjust = np.sign(ada_stats['Loss/signs/real'] - ada_target) * ( batch_size * ada_interval) / (ada_kimg * 1000) augment_pipe.p.copy_( (augment_pipe.p + adjust).max(misc.constant(0, device=device))) # Perform maintenance tasks once per tick. done = (cur_nimg >= total_kimg * 1000) if (not done) and (cur_tick != 0) and ( cur_nimg < tick_start_nimg + kimg_per_tick * 1000): continue # Print status line, accumulating the same information in stats_collector. 
tick_end_time = time.time() fields = [] fields += [ f"tick {training_stats.report0('Progress/tick', cur_tick):<5d}" ] fields += [ f"kimg {training_stats.report0('Progress/kimg', cur_nimg / 1e3):<8.1f}" ] fields += [ f"time {dnnlib.util.format_time(training_stats.report0('Timing/total_sec', tick_end_time - start_time)):<12s}" ] fields += [ f"sec/tick {training_stats.report0('Timing/sec_per_tick', tick_end_time - tick_start_time):<7.1f}" ] fields += [ f"sec/kimg {training_stats.report0('Timing/sec_per_kimg', (tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg) * 1e3):<7.2f}" ] fields += [ f"maintenance {training_stats.report0('Timing/maintenance_sec', maintenance_time):<6.1f}" ] fields += [ f"cpumem {training_stats.report0('Resources/cpu_mem_gb', psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}" ] fields += [ f"gpumem {training_stats.report0('Resources/peak_gpu_mem_gb', torch.cuda.max_memory_allocated(device) / 2**30):<6.2f}" ] torch.cuda.reset_peak_memory_stats() fields += [ f"augment {training_stats.report0('Progress/augment', float(augment_pipe.p.cpu()) if augment_pipe is not None else 0):.3f}" ] training_stats.report0('Timing/total_hours', (tick_end_time - start_time) / (60 * 60)) training_stats.report0('Timing/total_days', (tick_end_time - start_time) / (24 * 60 * 60)) if rank == 0: print(' '.join(fields)) # Check for abort. if (not done) and (abort_fn is not None) and abort_fn(): done = True if rank == 0: print() print('Aborting...') # Save image snapshot. if (rank == 0) and (image_snapshot_ticks is not None) and ( done or cur_tick % image_snapshot_ticks == 0): images = torch.cat([ G_ema(z=z, c=c, noise_mode='const').cpu() for z, c in zip(grid_z, grid_c) ]).numpy() save_image_grid(images, os.path.join(run_dir, f'fakes{cur_nimg//1000:06d}.png'), drange=[-1, 1], grid_size=grid_size) # Save network snapshot. 
snapshot_pkl = None snapshot_data = None if (network_snapshot_ticks is not None) and (done or cur_tick % network_snapshot_ticks == 0): snapshot_data = dict(training_set_kwargs=dict(training_set_kwargs)) for name, module in [('G', G), ('D', D), ('G_ema', G_ema), ('augment_pipe', augment_pipe)]: if module is not None: if num_gpus > 1: misc.check_ddp_consistency(module, ignore_regex=r'.*\.w_avg') module = copy.deepcopy(module).eval().requires_grad_( False).cpu() snapshot_data[name] = module del module # conserve memory snapshot_pkl = os.path.join( run_dir, f'network-snapshot-{cur_nimg//1000:06d}.pkl') if rank == 0: with open(snapshot_pkl, 'wb') as f: pickle.dump(snapshot_data, f) # Evaluate metrics. if (snapshot_data is not None) and (len(metrics) > 0): if rank == 0: print('Evaluating metrics...') for metric in metrics: result_dict = metric_main.calc_metric( metric=metric, G=snapshot_data['G_ema'], dataset_kwargs=training_set_kwargs, num_gpus=num_gpus, rank=rank, device=device) if rank == 0: metric_main.report_metric(result_dict, run_dir=run_dir, snapshot_pkl=snapshot_pkl) stats_metrics.update(result_dict.results) del snapshot_data # conserve memory # Collect statistics. for phase in phases: value = [] if (phase.start_event is not None) and (phase.end_event is not None): phase.end_event.synchronize() value = phase.start_event.elapsed_time(phase.end_event) training_stats.report0('Timing/' + phase.name, value) stats_collector.update() stats_dict = stats_collector.as_dict() # Update logs. 
timestamp = time.time() if stats_jsonl is not None: fields = dict(stats_dict, timestamp=timestamp) stats_jsonl.write(json.dumps(fields) + '\n') stats_jsonl.flush() if stats_tfevents is not None: global_step = int(cur_nimg / 1e3) walltime = timestamp - start_time for name, value in stats_dict.items(): stats_tfevents.add_scalar(name, value.mean, global_step=global_step, walltime=walltime) for name, value in stats_metrics.items(): stats_tfevents.add_scalar(f'Metrics/{name}', value, global_step=global_step, walltime=walltime) stats_tfevents.flush() if progress_fn is not None: progress_fn(cur_nimg // 1000, total_kimg) # Update state. cur_tick += 1 tick_start_nimg = cur_nimg tick_start_time = time.time() maintenance_time = tick_start_time - tick_end_time if done: break # Done. if rank == 0: print() print('Exiting...')
def forward(self, inputs):
    """Return the first input blob cropped to this layer's stored window.

    NOTE(review): this is presumably the ``forward`` method of the
    ``CropLayer`` class registered with OpenCV just below; the crop bounds
    ``self.ystart``/``self.yend``/``self.xstart``/``self.xend`` are set
    elsewhere in that class, not here.

    Args:
        inputs: sequence of input blobs; only ``inputs[0]`` is used and is
            sliced on its last two axes (assumed N, C, H, W — TODO confirm).

    Returns:
        A one-element list holding the cropped blob (OpenCV custom layers
        return their outputs as a list).
    """
    return [inputs[0][:, :, self.ystart:self.yend, self.xstart:self.xend]]


# Make the custom crop layer available to OpenCV's Caffe importer under the
# layer type name 'Crop'.
cv2.dnn_registerLayer('Crop', CropLayer)

# HED edge detector, loaded from its Caffe definition + weights.
pretrained_model = "models/edge_detection/hed_pretrained_bsds.caffemodel"
model_def = "models/edge_detection/deploy.prototxt"
edge_net = cv2.dnn.readNetFromCaffe(model_def, pretrained_model)

truncation_psi = 0.4
seed = 150
device = torch.device('cuda')

# Load the EMA generator from a training snapshot, then freeze it and put it
# in eval mode. The loaded module is used directly: nothing else references
# it, so the former copy.deepcopy only duplicated the weights on the GPU.
network_pkl = "network-snapshot-000712.pkl"
with dnnlib.util.open_url(network_pkl) as fp:
    G = legacy.load_network_pkl(fp)['G_ema']  # type: ignore
G = G.eval().requires_grad_(False).to(device)

# Example (disabled): sample one image with truncation.
#   z = torch.from_numpy(np.random.RandomState().randn(1, G.z_dim)).to(device)
#   img = G(z, None, truncation_psi=truncation_psi)
#   img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)

# Load VGG16 feature detector and replace its last classifier layer with a
# 4096 -> 512 linear head.
vgg16 = torch.hub.load('pytorch/vision:v0.6.0', 'vgg16', pretrained=True)
vgg16.classifier[-1] = nn.Linear(4096, 512)
# `load` is defined elsewhere in the file; presumably it restores trained
# weights into vgg16 — TODO confirm.
load(vgg16)
vgg16 = vgg16.train().to(device)